当前位置:网站首页>22.卷积神经网络实战-Lenet5
22.卷积神经网络实战-Lenet5
2022-08-02 00:14:00 【派大星的最爱海绵宝宝】
CIFAR10数据集介绍
10类,每一类有6000张照片,50000张training,10000张test。
实例
从datasets包中加载数据集,使用transforms包进行变换,通过resize获取图片维度,再把图片转换成tensor,因为pytorch的数据类型都是tensor。
cifar_train一次加载一张,我们需要使用DataLoader加载一次一批。直接覆盖写即可。
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
通过iter方法得到DataLoader的迭代器,再使用迭代器的next方法得到一个batch。
接下来新建一个类,lenet5,卷积神经网络的最简单的一个版本。
第一层是卷积层,第一个卷积层输入维度是照片的维度,cifar是彩色照片。
第二层是subsampling,是一个pooling层。
又一个卷积层,pooling不改变channel,输入依然是6。
全连接层时,输入维度是4维的,我们需要打平,但是pytorch中没有自带的FLatten函数,但是Sequential中需要写既有的类,所以我们写两个unit。
self.conv_unit=nn.Sequential(
# x:[b,3,32,32]->[b,6,]
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
#
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
#
)
打平
打平后,看结构图是120层,全连接层是Linear。激活函数一般选择sigmod和relu,而sigmod会出现梯度离散现象,所有选择relu。
我们计算一些输入输出值,使用一个例子tmp,送入第一个unit运行。
self.fc_unit=nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
# x:[b,3,32,32]
tmp=torch.randn(2,3,32,32)
out=self.conv_unit(tmp)
# x:[b,16,5,5]
print('conv_out:',out.shape)

每一个网络结构都需要一个forward前向计算,且不需要backward,自动会有。
def forward(self,x):
batchsz=x.size(0)
# x:[b,3,32,32]->[b,16,5,5]
x=self.conv_unit(x)
# [b,16,5,5]->[b,16*5*5]
x=x.view(batchsz,16*5*5)
# [b,16*5*5]->[b,10]
logits=self.fc_unit(x)
# pred=F.softmax(logits,dim=1)
#y是我们的输出,需要另外引入
# loss=self.criteon(logits,y)
return logits
我们取名字叫logits,一般在经过softmax之前的数叫logits。pred和logits的区别在于pred是logits经过softmax操作。
使用loss,我们这是个分类问题,通常使用cross entropy loss。
softmax和loss的操作叫做CELoss
nn上面的类是大写的,F上面的类是小写,两者的区别是nn上面的类先要初始化一下,再在forward里面调用,F里面的类是直接运行的函数,我们可以直接代入数值使用。
total_correct+=torch.eq(pred,label).float().sum().item()
eq函数进行对比,[2 1 1 2 1]的转置与][2 0 0 1 2]的转置进行对比,答案是[1 0 0 0 1]的转置,float后再相加可得2。

完整代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn,optim
from lenet5 import Lenet5
def main():
batchsz = 32
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
x,label=iter(cifar_train).next()
print('x:',x.shape,'label:',label.shape)
device = torch.device('cuda')
model=Lenet5().to(device)
criteon=nn.CrossEntropyLoss().to(device)
optimizer=optim.Adam(model.parameters(),lr=1e-3)
print('model:',model)
for epoch in range(1000):
model.train()
for batchsz,(x,label) in enumerate(cifar_train):
#x:[b,3,32,32]
#[b]
x,label=x.to(device),label.to(device)
logits=model(x)
#logits:[b,10]需要给出概率。
#label:[b]
#loss:tensor scalar
loss=criteon(logits,label)
#backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
#
print(epoch,loss.item())
model.eval()
with torch.no_grad():
#test
total_correct=0
total_num=0
for x,label in cifar_test:
x, label = x.to(device), label.to(device)
#[b,10]
logits=model(x)
#取logits最大的点作为pred:[b]
pred=logits.argmax(dim=1)
#[b] vs [b]-> scalar tensor
total_correct+=torch.eq(pred,label).float().sum().item()
total_num+=x.size(0)
acc = total_correct/total_num
print('epoch,acc:',epoch ,acc)
if __name__ == '__main__':
main()
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
def __init__(self):
super(Lenet5, self).__init__()
self.conv_unit=nn.Sequential(
# x:[b,3,32,32]->[b,6,]
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
#
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
#
)
# flatten
# fc unit
self.fc_unit=nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
# x:[b,3,32,32]
tmp=torch.randn(2,3,32,32)
out=self.conv_unit(tmp)
# x:[b,16,5,5]
print('conv_out:',out.shape)
# use cross entroy loss
# use mean square error:self.criteon=nn.MSELoss()
self.criteon=nn.CrossEntropyLoss()
def forward(self,x):
batchsz=x.size(0)
# x:[b,3,32,32]->[b,16,5,5]
x=self.conv_unit(x)
# [b,16,5,5]->[b,16*5*5]
x=x.view(batchsz,16*5*5)
# [b,16*5*5]->[b,10]
logits=self.fc_unit(x)
# pred=F.softmax(logits,dim=1)
#y是我们的输出,需要另外引入
# loss=self.criteon(logits,y)
return logits
def main():
net=Lenet5()
tmp = torch.randn(2, 3, 32, 32)
out = net(tmp)
print('lenet_out:', out.shape)
if __name__ == '__main__':
main()
边栏推荐
- Detailed explanation of JSP request object function
- uni-app project summary
- Mean Consistency Tracking of Time-Varying Reference Inputs for Multi-Agent Systems with Communication Delays
- Business test how to avoid missing?
- js中内存泄漏的几种情况
- DFS详解
- JS中对事件代理的理解及其应用场景
- JS中localStorage和sessionStorage
- 信息物理系统状态估计与传感器攻击检测
- CVPR 2022 | SharpContour:一种基于轮廓变形 实现高效准确实例分割的边缘细化方法
猜你喜欢

poker question

C language character and string function summary (2)

【HCIP】BGP小型实验(联邦,优化)

Redis 相关问题

Web开发

Business test how to avoid missing?

unity2D横版游戏教程5-UI

Short video SEO search operation customer acquisition system function introduction

08-SDRAM: Summary
![[HCIP] BGP Small Experiment (Federation, Optimization)](/img/a2/0967200c69cff3b683dc0af6f314c8.png)
[HCIP] BGP Small Experiment (Federation, Optimization)
随机推荐
bgp 聚合 反射器 联邦实验
IO stream basics
CRS management and maintenance
字符串分割函数strtok练习
C语言函数详解(1)【库函数与自定义函数】
146. LRU cache
GIF making - very simple one-click animation tool
Industrial control network intrusion detection based on automatic optimization of hyperparameters
Short video seo search optimization main content
Business test how to avoid missing?
els strip deformation
C语言实现扫雷游戏
After an incomplete recovery, the control file has been created or restored, the database must be opened with RESETLOGS, interpreting RESETLOGS.
【CodeTON Round 2 (Div. 1 + Div. 2, Rated, Prizes!)(A~D)】
信息物理系统状态估计与传感器攻击检测
鲲鹏编译调试插件实战
基于注意力机制的多特征融合人脸活体检测
ICML 2022 | GraphFM:通过特征Momentum提升大规模GNN的训练
els block deformation judgment.
JSP request对象功能详解说明