《PyTorch深度学习实践》 Lecture 10: Convolutional Neural Networks (CNN)
2022-08-05 05:40:00 【falldeep】
Video series by 刘二 on Bilibili:
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
CNN model
Convolution operation

Kernel (filter) computation
The yellow block in the lecture figure is the filter, a kernel of shape n × 3 × 3, where n matches the number of input channels; to produce m output channels, m such kernels are needed.


import torch

in_channel, out_channel = 5, 10  # number of input channels and output channels (feature maps)
width, height = 100, 100         # spatial size of the input
kernel_size = 3                  # kernel size (3 × 3)
batch_size = 1

input = torch.randn(batch_size, in_channel, width, height)
# Conv2d expects input of shape (N, C, H, W): batch, channels, height, width
conv_layer = torch.nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size)
# arguments: in_channels, out_channels, kernel_size
output = conv_layer(input)

print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)

# Output:
# torch.Size([1, 5, 100, 100])   batch size, channels, spatial size
# torch.Size([1, 10, 98, 98])    no padding, stride 1: 100 - 3 + 1 = 98
# torch.Size([10, 5, 3, 3])      10 kernels, each with 5 channels, kernel size 3 × 3
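For reference, the spatial size of a Conv2d output (ignoring dilation) follows the standard formula from the PyTorch documentation:

H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1

With the default padding = 0 and stride = 1 this gives 100 - 3 + 1 = 98, matching the output shape above; the padding and stride settings below change exactly these two terms.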
padding
To keep the output the same size as the input, pad the border of the input with zeros before convolving (for a 3 × 3 kernel, padding = 1).
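A minimal sketch of the effect of zero padding (the 5 × 5 toy input and the bias=False setting are illustrative choices, not from the lecture):

import torch

x = torch.randn(1, 1, 5, 5)   # (N, C, H, W)
conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
print(conv(x).shape)          # torch.Size([1, 1, 5, 5]) -- spatial size preserved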
stride
The kernel moves more than one position per step; with stride = 2 every other position is skipped, so the output is roughly halved in each spatial dimension.
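A similar sketch for stride (shapes chosen only for illustration):

import torch

conv = torch.nn.Conv2d(1, 1, kernel_size=3, stride=2, bias=False)
print(conv(torch.randn(1, 1, 100, 100)).shape)
# torch.Size([1, 1, 49, 49]): floor((100 - 3) / 2) + 1 = 49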

Max pooling layer (MaxPool2d): a 2 × 2 max pool keeps the largest value in each 2 × 2 window, halving the width and height; it has no learnable parameters and does not change the number of channels.
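A quick sketch of MaxPool2d (toy tensor, just to show the shape change; stride defaults to the kernel size):

import torch

pool = torch.nn.MaxPool2d(2)
x = torch.randn(1, 10, 24, 24)
print(pool(x).shape)          # torch.Size([1, 10, 12, 12])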

Overall network
The full network used in the homework below: conv1 (1 → 10 channels, 5 × 5 kernel) → 2 × 2 max pool → ReLU → conv2 (10 → 20 channels, 5 × 5 kernel) → 2 × 2 max pool → ReLU → flatten to a 320-dimensional vector (20 × 4 × 4) → fully connected layer (320 → 10).
Homework: handwritten digit recognition on MNIST
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt

BATCH_SIZE = 64

# normalize with the mean and std of the MNIST training set
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_set = datasets.MNIST(download=False, root='mnist', train=True, transform=transform)
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = datasets.MNIST(download=False, root='mnist', train=False, transform=transform)
test_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)   # 28x28 -> 24x24
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)  # 12x12 -> 8x8
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320, 10)                   # 20 * 4 * 4 = 320

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = x.view(batch_size, -1)   # flatten to (batch_size, 320)
        x = self.fc(x)               # no softmax here: CrossEntropyLoss expects raw logits
        return x


model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train():
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        running_loss += loss.item()   # .item() so the computation graph is not kept around
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 300 == 299:            # record the average loss every 300 mini-batches
            loss_lst.append(running_loss / 300)
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():             # no gradients needed for evaluation
        for data in test_loader:
            inputs, labels = data
            y_pred = model(inputs)
            _, predicted = torch.max(y_pred, dim=-1)
            correct += (labels == predicted).sum().item()
            total += labels.size(0)
    acc_lst.append(correct / total * 100)


if __name__ == '__main__':
    loss_lst = []
    acc_lst = []
    for epoch in range(10):
        train()
        test()

    num_lst = [i for i in range(len(loss_lst))]
    plt.plot(num_lst, loss_lst)
    plt.xlabel("i")
    plt.ylabel("loss")
    plt.show()

    num_lst = [i for i in range(len(acc_lst))]
    plt.plot(num_lst, acc_lst)
    plt.xlabel("epoch")
    plt.ylabel("acc")
    plt.show()
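Optionally, the same model can be trained on a GPU. A minimal sketch, assuming a CUDA device is available; this is an extension, not part of the homework code above:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
# inside train() and test(), move each mini-batch to the same device before the forward pass:
# inputs, labels = inputs.to(device), labels.to(device)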