当前位置:网站首页>《PyTorch深度学习实践》第十课(卷积神经网络CNN)
《PyTorch深度学习实践》第十课(卷积神经网络CNN)
2022-08-05 05:40:00 【falldeep】
b站刘二视频,地址:
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
CNN模型
卷积运算

卷积核运算
黄色方块为fliter(卷积核n * 3* 3),要想输出通道数为m,需要m个卷积核


import torch
in_channel, out_channel = 5, 10 #输入通道数,输出通道数(图层数)
width, height = 100, 100 #输入一张图层的大小
kernel_size = 3 #卷积核的大小(3 * 3)
batch_size = 1
input = torch.randn(batch_size, in_channel, width, height)
# B N W H
conv_layer = torch.nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size)
# N M (3 * 3)
output = conv_layer(input)
print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)
#输出结果
# torch.Size([1, 5, 100, 100])
# batch大小 通道数 一个图层的大小
# torch.Size([1, 10, 98, 98])
# torch.Size([10, 5, 3, 3])
#10个卷积核 每个卷积核有5个通道 卷积核大小为3 * 3
padding
保持输出图像大小不变,进行零填充
stride
跳一格扫描

maxpooling最大池化层

网络整体


作业,手写MNIST识别
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
BATCH_SIZE = 64
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_set = datasets.MNIST(download=False, root='mnist', train=True, transform=transforms)
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = datasets.MNIST(download=False, root='mnist', train=False, transform=transforms)
test_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
self.pooling = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(320, 10)
def forward(self, x):
batch_size = x.size(0)
x = F.relu(self.pooling(self.conv1(x)))
x = F.relu(self.pooling(self.conv2(x)))
x = x.view(batch_size, -1)
x = self.fc(x)
return x
modle = Net()
criteration = torch.nn.CrossEntropyLoss()
optimizor = torch.optim.SGD(modle.parameters(), lr=0.01, momentum=0.5)
def train():
sum = 0
for i, data in enumerate(train_loader, 0):
inputs, lables = data
y_pred = modle(inputs)
loss = criteration(y_pred, lables)
sum += loss
optimizor.zero_grad()
loss.backward()
optimizor.step()
if(i % 300 == 299):
sum /= 300
loss_lst.append(sum)
sum = 0
def test():
correct = 0
totall = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
inputs, lables = data
y_pred = modle(inputs)
_, predicted = torch.max(y_pred, dim=-1)
correct += (lables == predicted).sum().item()
totall += lables.size(0)
acc_lst.append(correct / totall * 100)
if __name__ == '__main__':
loss_lst = []
acc_lst = []
for epoch in range(10):
train()
test()
num_lst = [i for i in range(len(loss_lst))]
plt.plot(num_lst, loss_lst)
plt.xlabel("i")
plt.ylabel("loss")
plt.show()
num_lst = [i for i in range(len(acc_lst))]
plt.plot(num_lst, acc_lst)
plt.xlabel("epoch")
plt.ylabel("acc")
plt.show()
边栏推荐
猜你喜欢

Successful indie developers deal with failure & imposters

Transformer详细解读与预测实例记录

深夜小酌,50道经典SQL题,真香~

微信小程序仿input组件、虚拟键盘

The future of cloud gaming

NACOS Configuration Center Settings Profile

Quick Start to Drools Rule Engine (1)

Met with the browser page

Configuration of routers and static routes

Collision, character controller, Cloth components (cloth), joints in the Unity physics engine
随机推荐
指针常量与常量指针 巧记
在小程序中关于js数字精度丢失的解决办法
Drools规则引擎快速入门(一)
系统基础-学习笔记(一些命令记录)
七夕!专属于程序员的浪漫表白
config.js related configuration summary
reduce()方法的学习和整理
System basics - study notes (some command records)
docker部署完mysql无法连接
淘宝宝贝页面制作
Email management Filter emails
In-depth analysis if according to data authority @datascope (annotation + AOP + dynamic sql splicing) [step by step, with analysis process]
Matplotlib绘图笔记
【FAQ】CCAPI兼容EOS相机列表(2022年8月 更新)
DevOps流程demo(实操记录)
Pytorch distributed parallel processing
Nacos集群的搭建过程详解
八大排序之快速排序
Transformer interprets and predicts instance records in detail
scikit-image image processing notes
