22. Convolutional Neural Networks in Practice: LeNet-5
2022-08-02 00:14:00 【派大星的最爱海绵宝宝】
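Introduction to the CIFAR10 dataset
CIFAR10 has 10 classes with 6,000 images each (60,000 in total): 50,000 for training and 10,000 for testing.
Example
Load the dataset from the datasets package and transform it with the transforms package: Resize sets the image size to 32x32, and ToTensor converts each image to a tensor, because PyTorch operates on tensors.
On its own, cifar_train yields one image at a time. To load a batch at a time, wrap it in a DataLoader; simply overwriting the same variable is fine.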
# batchsz is the batch size (32 in the complete code below)
# The second positional argument is `train`: True = training split
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

# False = test split
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
Call iter on the DataLoader to get an iterator, then call next on that iterator to get one batch.
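To see what one batch looks like without downloading CIFAR10, here is a minimal sketch that uses a synthetic stand-in dataset. The `TensorDataset` of random tensors and the names `fake_images`, `fake_labels`, `fake_dataset` and `loader` are assumptions for illustration only; just the shapes mirror CIFAR10.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR10: 100 random "images" of shape [3, 32, 32]
# with integer labels in [0, 10). This is not real data.
fake_images = torch.randn(100, 3, 32, 32)
fake_labels = torch.randint(0, 10, (100,))
fake_dataset = TensorDataset(fake_images, fake_labels)

batchsz = 32
loader = DataLoader(fake_dataset, batch_size=batchsz, shuffle=True)

# iter() creates the iterator; the built-in next() pulls one batch from it.
x, label = next(iter(loader))
print('x:', x.shape, 'label:', label.shape)
```

The built-in `next()` is used here rather than an iterator method named `.next()`, since the built-in works with any Python 3 iterator.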
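Next, create a new class, Lenet5, one of the simplest convolutional neural networks.
The first layer is a convolutional layer. Its input channel count is that of the images; CIFAR images are color, so it is 3.
The second layer is subsampling, i.e. a pooling layer.
Then comes another convolutional layer. Pooling does not change the number of channels, so its input is still 6 channels.
The convolutional output is 4-dimensional, so it has to be flattened before the fully connected layers. The code here does not use a built-in Flatten layer (nn.Flatten only exists in later PyTorch releases), and nn.Sequential only accepts module instances, so the network is split into two units with the flatten step between them.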
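self.conv_unit = nn.Sequential(
    # x: [b, 3, 32, 32] -> [b, 6, 28, 28]
    nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
    # [b, 6, 28, 28] -> [b, 6, 14, 14]
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    # [b, 6, 14, 14] -> [b, 16, 10, 10]
    nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
    # [b, 16, 10, 10] -> [b, 16, 5, 5]
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
)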
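The spatial sizes through the convolutional unit can also be worked out by hand. The sketch below uses the standard output-size formula floor((n + 2p - k) / s) + 1 for a layer without dilation; the helper name `conv_out_size` is hypothetical and not part of PyTorch.

```python
def conv_out_size(n, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel_size) // stride + 1

# Same kernel sizes and strides as the four layers in conv_unit.
layers = [
    ('conv1', 5, 1),
    ('pool1', 2, 2),
    ('conv2', 5, 1),
    ('pool2', 2, 2),
]

size = 32  # CIFAR10 images are 32x32
sizes = []
for name, kernel_size, stride in layers:
    size = conv_out_size(size, kernel_size, stride)
    sizes.append(size)
    print(name, '->', size)

# 16 output channels times the final 5x5 feature map
flat_features = 16 * size * size
print('flattened features:', flat_features)
```

The final value, 16 x 5 x 5 = 400, is the input size the first fully connected layer has to accept.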
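Flatten
After flattening, the architecture diagram shows a fully connected layer with 120 units; fully connected layers are nn.Linear. The usual activation choices are sigmoid and ReLU. Sigmoid suffers from vanishing gradients, so we choose ReLU.
To work out the input and output sizes, we feed a sample tensor tmp through the first unit.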
self.fc_unit = nn.Sequential(
    nn.Linear(16*5*5, 120),
    nn.ReLU(),
    nn.Linear(120, 84),
    nn.ReLU(),
    nn.Linear(84, 10)
)

# x: [b, 3, 32, 32]
tmp = torch.randn(2, 3, 32, 32)
out = self.conv_unit(tmp)
# out: [b, 16, 5, 5]
print('conv_out:', out.shape)
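Newer PyTorch releases do ship an `nn.Flatten` layer, so the two units can also be written as a single `nn.Sequential`. The following is a sketch of that alternative, not the article's code; the name `lenet5_seq` is made up for illustration.

```python
import torch
from torch import nn

# Sketch only: the same layers as conv_unit and fc_unit, joined by nn.Flatten.
lenet5_seq = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    nn.Flatten(),                  # [b, 16, 5, 5] -> [b, 400]
    nn.Linear(16 * 5 * 5, 120),
    nn.ReLU(),
    nn.Linear(120, 84),
    nn.ReLU(),
    nn.Linear(84, 10),
)

tmp = torch.randn(2, 3, 32, 32)    # a dummy batch of two images
out = lenet5_seq(tmp)
print('out:', out.shape)
```

By default `nn.Flatten` keeps dimension 0 (the batch) and flattens the rest, which matches what `x.view(batchsz, 16*5*5)` does by hand.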
Every network needs a forward method for the forward pass. No backward method is needed, because autograd derives it automatically.
def forward(self, x):
    batchsz = x.size(0)
    # x: [b, 3, 32, 32] -> [b, 16, 5, 5]
    x = self.conv_unit(x)
    # [b, 16, 5, 5] -> [b, 16*5*5]
    x = x.view(batchsz, 16*5*5)
    # [b, 16*5*5] -> [b, 10]
    logits = self.fc_unit(x)
    # pred = F.softmax(logits, dim=1)
    # y is the target label and would have to be passed in separately
    # loss = self.criteon(logits, y)
    return logits
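We name the output logits: the values before softmax are conventionally called logits. The difference between pred and logits is that pred is logits after the softmax operation.
For the loss: this is a classification problem, so cross-entropy loss is the usual choice.
nn.CrossEntropyLoss (CELoss for short) bundles the softmax and the loss computation, so the network passes it raw logits and does not apply softmax itself.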
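The relationship between logits, softmax and cross-entropy can be checked numerically. This is a small sketch with made-up values: `logits` and `y` are random or hand-picked stand-ins, not outputs of the trained network.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)        # hypothetical network outputs, [b, 10]
y = torch.tensor([3, 0, 9, 1])     # hypothetical class labels, [b]

# pred is logits after softmax: each row becomes a probability distribution.
pred = F.softmax(logits, dim=1)
print('row sums:', pred.sum(dim=1))

# Cross-entropy on raw logits equals log-softmax followed by NLL loss.
loss_direct = F.cross_entropy(logits, y)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), y)
print(loss_direct.item(), loss_manual.item())
```

Because the softmax step is already inside the loss, applying `F.softmax` before `nn.CrossEntropyLoss` would normalize twice.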
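Names under nn are classes and are capitalized; names under F (torch.nn.functional) are lowercase functions. An nn class must be instantiated first and then called inside forward, whereas an F function runs directly on the values passed to it.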
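A minimal side-by-side sketch of the two styles, using ReLU as the example; the input tensor `x` is an arbitrary illustration.

```python
import torch
from torch import nn
from torch.nn import functional as F

x = torch.tensor([-1.0, 0.0, 2.0])

# Class style: create an instance first (usually in __init__) ...
relu_layer = nn.ReLU()
# ... then call the instance (usually inside forward).
out_module = relu_layer(x)

# Function style: call it directly on the tensor.
out_func = F.relu(x)

print(out_module, out_func)
```

Both calls compute the same values; the class form is what `nn.Sequential` needs, since it only holds module instances.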
total_correct+=torch.eq(pred,label).float().sum().item()
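The eq function compares element-wise. Comparing the column vector [2 1 1 2 1] with the column vector [2 0 0 1 2] gives [1 0 0 0 0], since only the first position matches; converting to float and summing gives 1.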
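The same counting step on a small hand-made batch, as a sketch: the `pred` and `label` values below are invented for illustration and are not taken from the model.

```python
import torch

pred = torch.tensor([3, 1, 0, 2, 1])    # hypothetical predicted classes, [b]
label = torch.tensor([3, 0, 0, 1, 1])   # hypothetical true classes, [b]

# Element-wise comparison gives a boolean tensor of shape [b].
matches = torch.eq(pred, label)

# Convert to float, sum, and pull the Python number out with item().
total_correct = matches.float().sum().item()
acc = total_correct / pred.size(0)
print(matches, total_correct, acc)
```

Positions 0, 2 and 4 match, so 3 of the 5 predictions are counted as correct.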
Complete code
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

from lenet5 import Lenet5


def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Use the built-in next() to fetch one batch
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to the CPU when no GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print('model:', model)

    for epoch in range(1000):

        model.train()
        # batchidx is the batch index (renamed so it does not shadow batchsz)
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10], raw scores; CrossEntropyLoss applies softmax internally
            # label: [b]
            # loss: scalar tensor
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss of the last batch in this epoch
        print(epoch, loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # take the index of the largest logit as pred: [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] -> scalar tensor
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num
            print('epoch,acc:', epoch, acc)


if __name__ == '__main__':
    main()
# lenet5.py (imported by the training script above via `from lenet5 import Lenet5`)
import torch
from torch import nn
from torch.nn import functional as F


class Lenet5(nn.Module):

    def __init__(self):
        super(Lenet5, self).__init__()

        self.conv_unit = nn.Sequential(
            # x: [b, 3, 32, 32] -> [b, 6, 28, 28]
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            # [b, 6, 28, 28] -> [b, 6, 14, 14]
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # [b, 6, 14, 14] -> [b, 16, 10, 10]
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            # [b, 16, 10, 10] -> [b, 16, 5, 5]
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # flatten (done in forward with view)
        # fc unit
        self.fc_unit = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

        # Sanity check of the conv output shape with a dummy batch
        # x: [b, 3, 32, 32]
        tmp = torch.randn(2, 3, 32, 32)
        out = self.conv_unit(tmp)
        # out: [b, 16, 5, 5]
        print('conv_out:', out.shape)

        # use cross entropy loss
        # to use mean squared error instead: self.criteon = nn.MSELoss()
        self.criteon = nn.CrossEntropyLoss()

    def forward(self, x):
        batchsz = x.size(0)
        # x: [b, 3, 32, 32] -> [b, 16, 5, 5]
        x = self.conv_unit(x)
        # [b, 16, 5, 5] -> [b, 16*5*5]
        x = x.view(batchsz, 16*5*5)
        # [b, 16*5*5] -> [b, 10]
        logits = self.fc_unit(x)
        # pred = F.softmax(logits, dim=1)
        # y is the target label and would have to be passed in separately
        # loss = self.criteon(logits, y)
        return logits


def main():
    net = Lenet5()
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    print('lenet_out:', out.shape)


if __name__ == '__main__':
    main()