A detailed walkthrough of the deep learning training and prediction process [using the LeNet model and the CIFAR10 dataset as an example]
2022-07-25 13:07:00 【1 + 1= Wang】
Introduction to the model and dataset
Model: LeNet
LeNet is a 7-layer neural network (not counting the input layer), containing 3 convolutional layers, 2 pooling layers, and 2 fully connected layers. In the PyTorch implementation below, the last convolutional layer (C5) is replaced with a fully connected layer, as is common practice.
The PyTorch implementation is as follows:
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)    # C1
        self.pool1 = nn.MaxPool2d(2, 2)     # S2
        self.conv2 = nn.Conv2d(16, 32, 5)   # C3
        self.pool2 = nn.MaxPool2d(2, 2)     # S4
        self.fc1 = nn.Linear(32*5*5, 120)   # C5 (replaced with a fully connected layer)
        self.fc2 = nn.Linear(120, 84)       # F6
        self.fc3 = nn.Linear(84, 10)        # F7 (output layer)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # input (3, 32, 32), output (16, 28, 28)
        x = self.pool1(x)           # output (16, 14, 14)
        x = F.relu(self.conv2(x))   # output (32, 10, 10)
        x = self.pool2(x)           # output (32, 5, 5)
        x = x.view(-1, 32*5*5)      # flatten to (32*5*5)
        x = F.relu(self.fc1(x))     # output (120)
        x = F.relu(self.fc2(x))     # output (84)
        x = self.fc3(x)             # output (10)
        return x

model = LeNet()
print(model)
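As a quick sanity check (not part of the original code), you can push a dummy batch through the network and confirm that the output shape matches the comments above:

import torch

dummy = torch.randn(4, 3, 32, 32)  # a dummy batch of 4 CIFAR10-sized images
out = model(dummy)
print(out.shape)  # torch.Size([4, 10]): one logit per class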

Dataset: CIFAR10
Download address: https://tensorflow.google.cn/datasets/catalog/cifar10
The CIFAR10 dataset contains 60,000 color images in total: 50,000 for training (5 training batches of 10,000 images each) and 10,000 for testing.
Each image is 3x32x32, and the images fall into 10 classes with 6,000 images per class.
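For reference, the 10 class names ship with the torchvision dataset; a minimal sketch to print them (download=True fetches the data on first use):

import torchvision

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
print(train_set.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']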
Training process
Training a model can be divided into the following steps:
- Dataset loading
- Model loading
- Iterative training
- Validation
The following walks through each step alongside the code.
1. Loading and normalizing the data
import torch
import torchvision
import torchvision.transforms as transforms

# Data normalization: convert to tensors, then map each channel from [0, 1] to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the training dataset
# On first use, set download=True to download the dataset automatically
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=True, num_workers=0)

# Load the test/validation dataset
# batch_size=5000 yields 5,000 of the 10,000 test images per batch;
# only the first batch is used for validation below
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
                                         shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)

# val_image and val_label are the images and their corresponding labels (classes);
# the training loop unpacks train_loader batches the same way
val_image, val_label = next(val_data_iter)  # the old .next() method no longer exists
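If you want to eyeball a few validation images, here is a small sketch (matplotlib assumed available) that undoes the normalization before plotting:

import matplotlib.pyplot as plt
import numpy as np

img = val_image[0] / 2 + 0.5  # undo Normalize((0.5, ...), (0.5, ...))
plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))  # CHW -> HWC for imshow
plt.title(val_set.classes[val_label[0].item()])
plt.show()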
2. Loading the model
import torch.nn as nn
import torch.optim as optim

# Instantiate the model
net = LeNet()
# Define the loss function
loss_function = nn.CrossEntropyLoss()
# Define the optimizer, passing in the model parameters and the learning rate lr
optimizer = optim.Adam(net.parameters(), lr=0.001)
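Note that nn.CrossEntropyLoss combines LogSoftmax and NLLLoss internally, which is why forward() returns raw logits without a softmax layer.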
3. Iterative training
for epoch in range(100):  # train for 100 epochs
    running_loss = 0.0    # accumulated training loss, reset each epoch
    for step, data in enumerate(train_loader, start=0):
        # inputs and labels are the images and their corresponding labels (classes)
        inputs, labels = data
        # Zero the gradients for every batch
        # (calling optimizer.zero_grad() only once every several batches
        #  is equivalent to training with a larger batch_size)
        optimizer.zero_grad()
        # Feed the images into the model to get predictions
        outputs = net(inputs)
        # Compute the loss from the predictions and the ground-truth labels
        loss = loss_function(outputs, labels)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
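The zero_grad comment above describes gradient accumulation. A minimal sketch of that pattern, where accum_steps is a name introduced here for illustration:

accum_steps = 4  # effective batch size becomes 36 * 4
optimizer.zero_grad()
for step, (inputs, labels) in enumerate(train_loader):
    outputs = net(inputs)
    loss = loss_function(outputs, labels) / accum_steps  # scale so gradients average out
    loss.backward()   # gradients accumulate across backward() calls
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()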
4. Validation
for epoch in range(100):  # train for 100 epochs
    running_loss = 0.0    # accumulated training loss, reset each epoch
    for step, data in enumerate(train_loader, start=0):
        # inputs and labels are the images and their corresponding labels (classes)
        inputs, labels = data
        # Zero the gradients for every batch
        # (calling optimizer.zero_grad() only once every several batches
        #  is equivalent to training with a larger batch_size)
        optimizer.zero_grad()
        # Feed the images into the model to get predictions
        outputs = net(inputs)
        # Compute the loss from the predictions and the ground-truth labels
        loss = loss_function(outputs, labels)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()

        # Accumulate the total training loss
        running_loss += loss.item()
        # Everything above is the training process; validation starts here
        ##########################################
        if step % 500 == 499:  # validate every 500 batches (hence running_loss / 500 below)
            with torch.no_grad():  # stop computing gradients during validation
                # Feed the validation images into the model to get predictions
                outputs = net(val_image)
                # The index of the largest logit along dim=1 is the predicted class
                predict_y = torch.max(outputs, dim=1)[1]
                # Count the predictions that match the labels and divide by the
                # total to get the accuracy
                accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)
                # Print the training loss and the accuracy
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                      (epoch + 1, step + 1, running_loss / 500, accuracy))
                # Reset the loss to 0 and continue training
                running_loss = 0.0

print('Finished Training')

# Training is done; save the model weights
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)
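The title also mentions prediction. A minimal inference sketch that loads the weights saved above and classifies a single image with the same preprocessing ('plane.jpg' is a placeholder path):

from PIL import Image

# Rebuild the model and load the trained weights
net = LeNet()
net.load_state_dict(torch.load('./Lenet.pth'))
net.eval()

# Preprocess the image the same way as the training data (resize to 32x32 first)
infer_transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = infer_transform(Image.open('plane.jpg').convert('RGB')).unsqueeze(0)  # add a batch dimension

with torch.no_grad():
    outputs = net(img)
    predict = torch.max(outputs, dim=1)[1].item()
print(val_set.classes[predict])  # map the class index back to a name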