当前位置:网站首页>《PyTorch深度学习实践》第十课(卷积神经网络CNN)
《PyTorch深度学习实践》第十课(卷积神经网络CNN)
2022-08-05 05:40:00 【falldeep】
b站刘二视频,地址:
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
CNN模型
卷积运算
卷积核运算
黄色方块为fliter(卷积核n * 3* 3),要想输出通道数为m,需要m个卷积核
import torch
in_channel, out_channel = 5, 10 #输入通道数,输出通道数(图层数)
width, height = 100, 100 #输入一张图层的大小
kernel_size = 3 #卷积核的大小(3 * 3)
batch_size = 1
input = torch.randn(batch_size, in_channel, width, height)
# B N W H
conv_layer = torch.nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size)
# N M (3 * 3)
output = conv_layer(input)
print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)
#输出结果
# torch.Size([1, 5, 100, 100])
# batch大小 通道数 一个图层的大小
# torch.Size([1, 10, 98, 98])
# torch.Size([10, 5, 3, 3])
#10个卷积核 每个卷积核有5个通道 卷积核大小为3 * 3
padding
保持输出图像大小不变,进行零填充
stride
跳一格扫描
maxpooling最大池化层
网络整体
作业,手写MNIST识别
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
BATCH_SIZE = 64
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_set = datasets.MNIST(download=False, root='mnist', train=True, transform=transforms)
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = datasets.MNIST(download=False, root='mnist', train=False, transform=transforms)
test_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
self.pooling = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(320, 10)
def forward(self, x):
batch_size = x.size(0)
x = F.relu(self.pooling(self.conv1(x)))
x = F.relu(self.pooling(self.conv2(x)))
x = x.view(batch_size, -1)
x = self.fc(x)
return x
modle = Net()
criteration = torch.nn.CrossEntropyLoss()
optimizor = torch.optim.SGD(modle.parameters(), lr=0.01, momentum=0.5)
def train():
sum = 0
for i, data in enumerate(train_loader, 0):
inputs, lables = data
y_pred = modle(inputs)
loss = criteration(y_pred, lables)
sum += loss
optimizor.zero_grad()
loss.backward()
optimizor.step()
if(i % 300 == 299):
sum /= 300
loss_lst.append(sum)
sum = 0
def test():
correct = 0
totall = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
inputs, lables = data
y_pred = modle(inputs)
_, predicted = torch.max(y_pred, dim=-1)
correct += (lables == predicted).sum().item()
totall += lables.size(0)
acc_lst.append(correct / totall * 100)
if __name__ == '__main__':
loss_lst = []
acc_lst = []
for epoch in range(10):
train()
test()
num_lst = [i for i in range(len(loss_lst))]
plt.plot(num_lst, loss_lst)
plt.xlabel("i")
plt.ylabel("loss")
plt.show()
num_lst = [i for i in range(len(acc_lst))]
plt.plot(num_lst, acc_lst)
plt.xlabel("epoch")
plt.ylabel("acc")
plt.show()
边栏推荐
- 文件内音频的时长统计并生成csv文件
- Error correction notes for the book Image Processing, Analysis and Machine Vision
- 在小程序中关于js数字精度丢失的解决办法
- Some basic method records of commonly used languages in LeetCode
- Detailed explanation of the construction process of Nacos cluster
- lingo入门——河北省第三届研究生建模竞赛B题
- 边缘盒子+时序数据库,美的数字化平台 iBUILDING 背后的技术选型
- selenium learning
- 更改小程序原生radio的颜色及大小
- 盒子模型小练习
猜你喜欢
随机推荐
Q 2020, the latest senior interview Laya soul, do you know?
The cocos interview answers you are looking for are all here!
超简单的白鹭egret项目添加图片详细教程
D39_Eulerian Angles and Quaternions
Redis的使用
格式化代码缩进的小技巧
Nacos集群搭建
Come, come, let you understand how Cocos Creator reads and writes JSON files
js判断文字是否超过区域
【8】Docker中部署Redis
边缘盒子+时序数据库,美的数字化平台 iBUILDING 背后的技术选型
网络协议基础-学习笔记
MySQL的主从模式搭建
D39_Vector
Writing OpenCV in VSCode
uniapp打包次数限制怎么办?只需两步就能解决
LeetCode练习及自己理解记录(1)
lingo入门——河北省第三届研究生建模竞赛B题
D46_Force applied to rigid body
LeetCode practice and self-comprehension record (1)