Problems encountered when implementing LeNet
2022-07-03 09:09:00 【weixin_ thirty-seven million six hundred and eighty-two thousan】
First, the implementation of the LeNet network class:
from torchvision import datasets, transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv_layers = nn.Sequential(
            # input: [b, 3, 32, 32] (batch, 3 RGB channels, height, width)
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
            # conv: [b, 3, 32, 32] -> [b, 6, 28, 28]
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # pool: [b, 6, 28, 28] -> [b, 6, 14, 14]
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            # conv: [b, 6, 14, 14] -> [b, 16, 10, 10]
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
            # pool: [b, 16, 10, 10] -> [b, 16, 5, 5]
        )
        # flatten (done in forward): [b, 16, 5, 5] -> [b, 16*5*5]
        # fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=16*5*5, out_features=120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        batchsz = x.size(0)
        x = self.conv_layers(x)
        # flatten: [b, 16, 5, 5] -> [b, 16*5*5]
        x = x.reshape(batchsz, 16*5*5)
        logit = self.fc_layers(x)
        return logit
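The shape comments in the network can be checked by hand. The sketch below is plain Python, not part of the original code; it assumes square inputs, no dilation, and the usual output-size formula floor((n + 2p - k) / s) + 1 for both the convolution and the pooling layers. The helper name out_size is made up for this illustration.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer for an n x n input."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel_size, stride) of conv1, pool1, conv2, pool2 in the LeNet class
layers = [(5, 1), (2, 2), (5, 1), (2, 2)]

size = 32            # input images are 32x32
sizes = [size]
for kernel, stride in layers:
    size = out_size(size, kernel, stride)
    sizes.append(size)

# The last conv layer has 16 output channels, so the flattened length is 16 * 5 * 5
flat_features = 16 * size * size
print(sizes, flat_features)
```

The final value is what goes into in_features of the first nn.Linear layer.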
The problems I ran into along the way:
- Forgot to write super(LeNet, self).__init__(). Reason: I did not understand what super does. super(LeNet, self) returns a proxy that looks up methods on LeNet's parent class, nn.Module, so super(LeNet, self).__init__() runs nn.Module's __init__() on the new LeNet object. That initialization has to happen before any layers are assigned to self.
- Could not work out the input size when connecting the last pooling layer to the fc layer. Reason: I did not realize that the batch dimension is separate from the network's other hyperparameters. The layer sizes in the network are all defined for a single 32x32 image, and batch is just how many images are processed at once. The last pooling layer outputs 16 channels of 5x5 each, so going into the 120-unit fc layer means flattening [16, 5, 5] to [16*5*5], or, with the batch dimension, [batchsize, 16, 5, 5] -> [batchsize, 16*5*5].
- To get the length of the first dimension (the batch size) of a tensor at any layer, use x.size(0).
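The first point can be illustrated without PyTorch. The classes below are a toy stand-in, not the real nn.Module: Base, Good, Bad and register are hypothetical names, and the only assumption is that the parent's __init__() creates some internal state the subclass depends on, which is roughly the situation with nn.Module.

```python
class Base:
    """Toy stand-in for a parent class such as nn.Module (hypothetical)."""

    def __init__(self):
        # The parent initializer creates the registry that subclasses rely on.
        self._modules = {}

    def register(self, name, layer):
        self._modules[name] = layer


class Good(Base):
    def __init__(self):
        super().__init__()  # runs Base.__init__ on this instance first
        self.register("fc", "linear-layer")


class Bad(Base):
    def __init__(self):
        # super().__init__() was forgotten, so self._modules never gets created
        self.register("fc", "linear-layer")


good = Good()

try:
    Bad()
    bad_error = None
except AttributeError as exc:
    bad_error = exc

print(good._modules)
print(type(bad_error).__name__)
```

In Python 3, super().__init__() is equivalent to the longer super(LeNet, self).__init__() form used in the network class.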
The following is the training code
from torchvision import datasets, transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from lenet import LeNet

batchsz = 32
cifar_train = datasets.CIFAR10('../cifar', train=True, transform=transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = datasets.CIFAR10('../cifar', train=False, transform=transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

model = LeNet()
# device = torch.device('cuda')
# model = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(model)

for epoch in range(10):
    # switch back to training mode, since the test step below calls model.eval()
    model.train()
    for batchidx, (x, label) in enumerate(cifar_train):
        # x: [b, 3, 32, 32], label: [b]
        # x, label = x.to(device), label.to(device)
        # Note: CrossEntropyLoss() already includes softmax,
        # so its input must be the raw logits, not probabilities
        logits = model(x)
        # logits: [b, 10], label: [b]
        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # loss of the last batch in this epoch
    print("epoch", epoch, "loss:", loss.item())

    # test
    model.eval()
    correct_num = 0
    total_num = 0
    with torch.no_grad():
        for (x, label) in cifar_test:
            # x: [b, 3, 32, 32], label: [b]
            logits = model(x)
            # logits: [b, 10]
            pred = torch.argmax(logits, dim=1)
            correct_num += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)
    print('test accuarcy:', correct_num/total_num)
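The accuracy bookkeeping in the test step comes down to taking the index of the largest logit per sample and comparing it with the label. Here is a small plain-Python sketch of the same idea on made-up numbers; argmax, accuracy and the toy data are illustrative only and not part of the training script.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits_batch, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    preds = [argmax(row) for row in logits_batch]
    correct = sum(1 for p, y in zip(preds, labels) if p == y)
    return correct / len(labels)


# Three samples, three classes (made-up values)
toy_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
toy_labels = [1, 0, 0]
toy_acc = accuracy(toy_logits, toy_labels)
print(toy_acc)
```

In the script, torch.argmax(logits, dim=1) and torch.eq(pred, label) do the same thing for a whole batch at once.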
Problems encountered:
- The import was written as import lenet. In that case the model has to be created with model = lenet.LeNet(). With from lenet import LeNet, you can write model = LeNet() directly.
- Wrote loss = criterion(label, logits) by mistake. The argument order of CrossEntropyLoss() matters: the predictions (logits) come first, then the labels.
- In the test step, with torch.no_grad() was mistakenly written as with model.zero_grad(). zero_grad() only resets gradients; it does not disable gradient tracking.
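To see why the loss expects raw logits first and the label second, it helps to write out what cross-entropy computes for a single sample: the negative log of the softmax probability assigned to the true class. The function below is a plain-Python sketch of that formula, not PyTorch's implementation; the name cross_entropy_from_logits and the example values are made up.

```python
import math


def cross_entropy_from_logits(logits, label):
    """Cross-entropy for one sample: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


logits = [2.0, 0.5, -1.0]  # raw scores for 3 classes, no softmax applied
loss_if_label_0 = cross_entropy_from_logits(logits, 0)  # label matches the top score
loss_if_label_2 = cross_entropy_from_logits(logits, 2)  # label is the lowest score
print(loss_if_label_0, loss_if_label_2)
```

The two arguments play different roles (a vector of scores versus a class index), which is why swapping them in criterion(logits, label) does not work. Because the softmax is already inside the loss, applying softmax to the model output beforehand would apply it twice.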
Finally, the results of the run:
D:\Anaconda\envs\pytorch\python.exe "D:/PycharmProjects/pythonProjecet_pytorch/LeNet & ResNet/LeNet/train.py"
Files already downloaded and verified
Files already downloaded and verified
LeNet(
(conv_layers): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): AvgPool2d(kernel_size=2, stride=2, padding=0)
(2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(3): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(fc_layers): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU()
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU()
(4): Linear(in_features=84, out_features=10, bias=True)
)
)
epoch 0 loss: 1.37169349193573
test accuarcy: 0.4542
epoch 1 loss: 1.9433141946792603
test accuarcy: 0.4866
epoch 2 loss: 2.1775083541870117
test accuarcy: 0.5228
epoch 3 loss: 1.9912444353103638
test accuarcy: 0.5261
epoch 4 loss: 1.410996913909912
test accuarcy: 0.5381
epoch 5 loss: 0.9067156314849854
test accuarcy: 0.5481
epoch 6 loss: 1.1418999433517456
test accuarcy: 0.53
epoch 7 loss: 1.032881498336792
test accuarcy: 0.5518
epoch 8 loss: 1.4319336414337158
test accuarcy: 0.5499
epoch 9 loss: 0.9463621377944946
test accuarcy: 0.5431
Process finished with exit code 0