1.6 Example: CIFAR-10 classification
2022-07-29 03:23:00 【smiling0927】
CIFAR-10 classification proceeds in the following steps:
1) Use torchvision to load and preprocess the CIFAR-10 dataset
2) Define the network
3) Define the loss function and optimizer
4) Train the network and update its parameters
5) Test the network
CIFAR-10 data loading and preprocessing
CIFAR-10 is a commonly used color-image dataset with 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck. Every image is 3×32×32, that is, a 3-channel color image with a resolution of 32×32 pixels.
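To make the 3×32×32 layout and the preprocessing defined in the code below more concrete, here is a plain-Python sketch of what transforms.ToTensor() followed by transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) does to an image. ToTensor rescales 0-255 pixel values to [0, 1] and reorders height×width×channel data into channel×height×width. Normalize then computes (x - mean) / std per channel, which with mean = std = 0.5 maps [0, 1] onto [-1, 1]. The function name to_normalized_chw and the tiny 2×2 image are illustrative assumptions, not part of torchvision.

```python
def to_normalized_chw(hwc, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Mimic ToTensor + Normalize on a nested-list image of shape H x W x C (values 0-255)."""
    height, width, channels = len(hwc), len(hwc[0]), len(hwc[0][0])
    chw = []
    for c in range(channels):                        # channel becomes the first axis
        plane = []
        for h in range(height):
            row = []
            for w in range(width):
                x = hwc[h][w][c] / 255.0             # ToTensor: scale to [0, 1]
                row.append((x - mean[c]) / std[c])   # Normalize: (x - mean) / std
            plane.append(row)
        chw.append(plane)
    return chw

# A hypothetical 2x2 RGB image: black, white / pure red, dark grey.
image = [
    [[0, 0, 0], [255, 255, 255]],
    [[255, 0, 0], [51, 51, 51]],
]
out = to_normalized_chw(image)
print(len(out), len(out[0]), len(out[0][0]))  # 3 2 2  (C, H, W)
print(out[0][0])                              # [-1.0, 1.0]: red channel, first row
print(3 * 32 * 32)                            # 3072 values per CIFAR-10 image
```

With mean and std both 0.5, black pixels end up at -1 and white pixels at +1, so the network sees inputs centered around zero.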
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()  # converts a Tensor to a PIL Image for easy visualization
# On the first run, torchvision downloads CIFAR-10 automatically
# (about 160 MB), which takes some time.
# If CIFAR-10 has already been downloaded, point the root parameter at it.
# Define the data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize
])
# Training set
trainset = tv.datasets.CIFAR10(
    root='/B/CIFAR-10data/',
    train=True,
    download=True,
    transform=transform
)
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)
# Test set
testset = tv.datasets.CIFAR10(
    root='/B/CIFAR-10data/',
    train=False,
    download=True,
    transform=transform
)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2
)
classes=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Define the network
import torch.nn as nn
import torch.nn.functional as F
import time
start = time.time()  # start timing
# Define the network structure
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)  # flatten to (batch, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
print(net)
# Define the loss function and optimizer
loss_func = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Train the network
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters
        running_loss += loss.item()
        if i % 2000 == 1999:   # print the average loss every 2000 mini-batches
            print('epoch:', epoch + 1, '|i:', i + 1, '|loss:%.3f' % (running_loss / 2000))
            running_loss = 0.0
end = time.time()
time_using = end - start
print('finish training')
print('time:', time_using)
# Test the network
correct = 0  # number of correctly predicted images
total = 0    # total number of images
with t.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predict = t.max(outputs, 1)  # index of the highest score = predicted class
        total += labels.size(0)
        correct += (predict == labels).sum().item()  # .item() keeps correct a plain int
print(' The accuracy in the test set is :%d%%' % (100 * correct / total))
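The 16 * 5 * 5 input size of fc1 in the network above follows from the image size: each unpadded 5×5 convolution with stride 1 shrinks a side by 4, and each 2×2 max-pool halves it, so 32 → 28 → 14 → 10 → 5. The sketch below applies the usual output-size formulas for Conv2d (no padding, dilation 1) and max_pool2d (stride defaulting to the kernel size); the helper names are illustrative.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of Conv2d with dilation 1: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool2d_out(size, kernel):
    """Output side length of max_pool2d when stride defaults to kernel_size."""
    return (size - kernel) // kernel + 1

side = 32                                     # CIFAR-10 images are 32x32
side = maxpool2d_out(conv2d_out(side, 5), 2)  # conv1 (5x5) -> 28, pool -> 14
side = maxpool2d_out(conv2d_out(side, 5), 2)  # conv2 (5x5) -> 10, pool -> 5
flat_features = 16 * side * side              # conv2 has 16 output channels
print(side, flat_features)                    # 5 400
```

The result, 400, matches in_features=400 for fc1 in the printed network structure, and it is the width that x.view(x.size()[0], -1) produces.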
DataLoader is an iterable object. It combines the individual samples returned by the dataset into batches, and it also provides multi-process loading and data shuffling. Once the program has traversed all the data in the dataset, one full iteration over the DataLoader is complete.
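As a rough mental model of what DataLoader(..., batch_size=4, shuffle=True) does in each epoch, the generator below shuffles the sample indices and yields them in groups of four. It is a simplified, single-process sketch (no worker processes, no stacking of samples into tensors, and a hypothetical iterate_batches name), not PyTorch's implementation. With 50,000 training images and a batch size of 4 it yields 12,500 batches per epoch, which is why the training log prints at i = 2000, 4000, ..., 12000.

```python
import random

def iterate_batches(num_samples, batch_size, shuffle=True, seed=None):
    """Yield lists of sample indices, one list per mini-batch; one full pass is one epoch."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)     # new random order for this pass
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]  # the last batch may be smaller

batches = list(iterate_batches(50_000, 4, seed=0))
print(len(batches))      # 12500 mini-batches per epoch
print(len(batches[0]))   # 4 samples in each batch
```

Like DataLoader's default behavior, this sketch keeps a final partial batch instead of dropping it; 50,000 divides evenly by 4, so no partial batch appears here.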
Here the network is trained for only 2 epochs (one complete pass over the dataset is called an epoch), just to check whether it learns anything. To test it, the test images are fed into the network, the predicted label of each image is computed, and it is compared with the true label.
Output:
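The comparison step in the test loop, t.max(outputs, 1) followed by (predict == labels).sum(), amounts to taking the index of the largest score in each row and counting how many match the true labels. A plain-Python sketch with made-up scores for a batch of four images (the numbers and function names are illustrative, not real network outputs):

```python
def predict(scores):
    """Index of the highest score in each row, like t.max(outputs, 1)[1]."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]

def accuracy_percent(scores, labels):
    """Percentage of rows whose predicted class equals the true label."""
    predictions = predict(scores)
    correct = sum(p == y for p, y in zip(predictions, labels))
    return 100 * correct / len(labels)

# Hypothetical scores for 4 images over 3 classes (CIFAR-10 would have 10 columns).
scores = [
    [2.0, 0.1, -1.0],   # predicts class 0
    [0.3, 1.5, 0.2],    # predicts class 1
    [0.0, 0.2, 0.9],    # predicts class 2
    [1.1, 0.4, 0.8],    # predicts class 0
]
labels = [0, 1, 1, 0]
print(predict(scores))                   # [0, 1, 2, 0]
print(accuracy_percent(scores, labels))  # 75.0
```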
Files already downloaded and verified
Files already downloaded and verified
Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
epoch: 1 |i: 2000 |loss:2.156
epoch: 1 |i: 4000 |loss:1.826
epoch: 1 |i: 6000 |loss:1.693
epoch: 1 |i: 8000 |loss:1.574
epoch: 1 |i: 10000 |loss:1.527
epoch: 1 |i: 12000 |loss:1.476
epoch: 2 |i: 2000 |loss:1.391
epoch: 2 |i: 4000 |loss:1.360
epoch: 2 |i: 6000 |loss:1.349
epoch: 2 |i: 8000 |loss:1.326
epoch: 2 |i: 10000 |loss:1.300
epoch: 2 |i: 12000 |loss:1.291
finish training
time: 217.39351534843445
The accuracy in the test set is :54%
The network reaches 54% accuracy on the 10,000 test images.
That is far better than random guessing (about 10% accuracy for 10 classes), which shows that the network has indeed learned something.
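Two quick baselines put the log above in context. A classifier that guesses uniformly among 10 classes is right about 1 time in 10, and a network that assigns probability 1/10 to every class has a cross-entropy loss of -ln(1/10) = ln 10 ≈ 2.303. nn.CrossEntropyLoss applies log-softmax followed by the negative log-likelihood, so the first logged average loss of 2.156 is already a little below that chance level and keeps falling. The sketch computes the cross-entropy of a single example with a numerically stable log-sum-exp; the function name is illustrative.

```python
import math
import random

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# Equal scores for all 10 classes give uniform probabilities of 1/10.
print(round(cross_entropy([0.0] * 10, target=3), 3))   # 2.303

# Random guessing over 10 classes: roughly 10% accuracy.
rng = random.Random(0)
labels = [rng.randrange(10) for _ in range(10_000)]
guesses = [rng.randrange(10) for _ in range(10_000)]
hits = sum(g == y for g, y in zip(guesses, labels))
print(hits / len(labels))   # close to 0.1
```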
Training on the GPU
Just as a Tensor can be moved from the CPU to the GPU, a model can be moved from the CPU to the GPU as well.
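if t.cuda.is_available():
    net.cuda()  # moves the model's parameters to the GPU in place
    images = images.cuda()
    labels = labels.cuda()
    # Tensors can be passed directly; Variable wrappers are no longer needed
    output = net(images)
    loss = loss_func(output, labels)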
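A common alternative to the cuda() calls is to choose a torch.device once and move the model and every batch with .to(device), so the same script runs on machines with or without a GPU. A minimal sketch, assuming the same Net architecture as above (re-declared here so the block runs on its own) and using random tensors as a stand-in for one real CIFAR-10 batch from trainloader:

```python
import torch as t
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
net = Net().to(device)  # moves the parameters and returns the module
loss_func = nn.CrossEntropyLoss()
optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Stand-in batch: 4 random "images" and labels (real code would iterate trainloader).
images = t.randn(4, 3, 32, 32).to(device)
labels = t.randint(0, 10, (4,)).to(device)

# One training step, with model and data on the same device.
outputs = net(images)
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(outputs.shape, outputs.device.type, loss.item())
```

The key constraint is that the model and every input tensor must live on the same device; moving each batch inside the training loop with .to(device) keeps that true.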