当前位置:网站首页>pytorch学习记录(五):卷积神经网络的实现
pytorch学习记录(五):卷积神经网络的实现
2022-07-30 13:25:00 【狸狸Arina】
1. 数据集的加载与模型训练
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from lenet5 import Lenet5
from resnet import ResNet18
def main():
    """Train ResNet18 on CIFAR-10 and report test accuracy after every epoch.

    Downloads CIFAR-10 into the ``cifar`` directory if it is not already
    there, trains with Adam (lr=1e-3) and cross-entropy loss for 1000
    epochs, and prints the loss of the last training batch plus the test-set
    accuracy once per epoch. Runs on CUDA when available, otherwise on CPU.
    """
    batchsz = 128

    # Train and test splits use identical preprocessing, so build it once.
    # The mean/std values are the widely used ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    cifar_train = datasets.CIFAR10('cifar', True, transform=transform, download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transform, download=True)
    # Accuracy does not depend on sample order, so the test set is not shuffled.
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # Peek at one batch to show the tensor shapes. Iterators implement
    # __next__, not a .next() method, so the builtin next() must be used.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Swap in Lenet5() here to train the smaller network instead.
    model = ResNet18().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):
        # ---- training pass ----
        model.train()
        for x, label in cifar_train:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            # logits: [b, 10]; loss: scalar tensor
            logits = model(x)
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last mini-batch only, not an epoch average.
        print(epoch, 'loss:', loss.item())

        # ---- evaluation pass ----
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)

                # [b, 10] -> predicted class index per sample, [b]
                pred = model(x).argmax(dim=1)

                total_correct += torch.eq(pred, label).sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
2. LeNet 实现
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
    """LeNet-style convolutional classifier for CIFAR-10 sized images.

    Expects input of shape [b, 3, 32, 32] and returns unnormalised class
    scores (logits) of shape [b, 10]; pair it with ``nn.CrossEntropyLoss``,
    which applies the softmax itself.
    """

    def __init__(self):
        super(Lenet5, self).__init__()

        # Two conv + max-pool stages. For a 32x32 input the spatial size goes
        # 32 -> 28 -> 14 -> 10 -> 5, so the output is [b, 32, 5, 5].
        # NOTE(review): there is no non-linearity between the convolutions;
        # the layer layout is kept as-is so existing state_dict keys still load.
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

        # Fully connected head applied to the flattened feature map.
        self.fc_unit = nn.Sequential(
            nn.Linear(32 * 5 * 5, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        """Compute class logits.

        :param x: image batch of shape [b, 3, 32, 32]
        :return: logits of shape [b, 10]
        """
        batchsz = x.size(0)
        # [b, 3, 32, 32] => [b, 32, 5, 5]
        x = self.conv_unit(x)
        # [b, 32, 5, 5] => [b, 32*5*5]
        x = x.view(batchsz, 32 * 5 * 5)
        # [b, 32*5*5] => [b, 10]
        logits = self.fc_unit(x)
        return logits
def main():
    """Smoke-test Lenet5 on a random [2, 3, 32, 32] batch and print the output shape."""
    model = Lenet5()
    batch = torch.randn(2, 3, 32, 32)
    logits = model(batch)
    print('lenet out:', logits.shape)


if __name__ == '__main__':
    main()
3. ResNet 实现
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    :param ch_in: number of input channels.
    :param ch_out: number of output channels.
    :param stride: stride of the first convolution; values greater than 1
        downsample the spatial dimensions.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        # Stride is supported on the first conv so the block can downsample.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut branch. An empty Sequential acts as the identity, which is
        # only valid when the main path changes neither the channel count nor
        # the spatial size. Otherwise project with a strided 1x1 conv so both
        # branches have the same shape; without this, an equal-channel block
        # with stride > 1 would fail at the add or silently broadcast.
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Apply the residual block.

        :param x: feature map of shape [b, ch_in, h, w]
        :return: feature map of shape [b, ch_out, h', w']
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Element-wise add of the (possibly projected) shortcut and main path.
        out = self.extra(x) + out
        out = F.relu(out)
        return out
class ResNet18(nn.Module):
    """Small ResNet-style classifier: conv stem, four ResBlk stages, linear head.

    Returns 10 logits per input image.
    """

    def __init__(self):
        super().__init__()

        # Stem: 3 -> 64 channels, 3x3 kernel with stride 3 and no padding.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64),
        )

        # Four residual stages, each built with stride 2.
        # Channels: 64 -> 128 -> 256 -> 512 -> 512.
        self.blk1 = ResBlk(64, 128, stride=2)
        self.blk2 = ResBlk(128, 256, stride=2)
        self.blk3 = ResBlk(256, 512, stride=2)
        self.blk4 = ResBlk(512, 512, stride=2)

        # Classifier over the globally pooled 512-dim feature vector.
        self.outlayer = nn.Linear(512 * 1 * 1, 10)

    def forward(self, x):
        """Compute class logits.

        :param x: image batch; the stem expects 3 input channels
        :return: logits of shape [b, 10]
        """
        feats = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h', w']
        for stage in (self.blk1, self.blk2, self.blk3, self.blk4):
            feats = stage(feats)

        # Global average pool: [b, 512, h', w'] => [b, 512, 1, 1]
        pooled = F.adaptive_avg_pool2d(feats, [1, 1])
        # [b, 512, 1, 1] => [b, 512]
        flat = pooled.view(pooled.size(0), -1)
        return self.outlayer(flat)
def main():
    """Smoke-test ResBlk and ResNet18 on random tensors and print the output shapes."""
    block = ResBlk(64, 128, stride=4)
    feature_map = torch.randn(2, 64, 32, 32)
    block_out = block(feature_map)
    print('block:', block_out.shape)

    images = torch.randn(2, 3, 32, 32)
    net = ResNet18()
    logits = net(images)
    print('resnet:', logits.shape)


if __name__ == '__main__':
    main()
边栏推荐
- 缓存一致性
- 二手手机销量突破3亿部,与降价的iPhone夹击国产手机
- 元宇宙的六大支撑技术
- Study Notes - Becoming a Data Analyst in Seven Weeks "Week 2: Business": Business Analysis Metrics
- 重保特辑|拦截99%恶意流量,揭秘WAF攻防演练最佳实践
- 【微信小程序】一文带你搞懂小程序的页面配置和网络数据请求
- 每天学一点Scala之 伴生类和伴生对象
- 外包干了七年,废了。。。
- R语言ggplot2可视化:使用ggpubr包的ggmaplot函数可视化MA图(MA-plot)、设置label.select参数自定义在图中显示标签的基因类型(自定义显示的标签列表)
- shell脚本流程控制语句
猜你喜欢

学习笔记——七周成为数据分析师《第二周:业务》:业务分析指标

Study Notes - Becoming a Data Analyst in Seven Weeks "Week 2: Business": Business Analysis Metrics

svg波浪动画js特效代码

如何把Excel表格显示到邮件正文里?

leetcode207.课程表(判断有向图是否有环)

Apache Log4j2漏洞

PyQt5快速开发与实战 8.6 设置样式

What is the level of Ali P7?

戴墨镜的卡通太阳SVG动画js特效

jsArray array copy method performance test 2207300823
随机推荐
正确处理页面控制器woopagecontroller.php,当提交表单时是否跳转正确的页面
在 Scala 中读取整个文件
电池包托盘有进水风险,存在安全隐患,紧急召回52928辆唐DM
434. 字符串中的单词数
EasyNVS cloud management platform function reconstruction: support for adding users, modifying information, etc.
Composer安装方式
“12306” 的架构到底有多牛逼
自从外包干了四年,基本废了...
How to solve the problem that the page does not display the channel configuration after the EasyNVR is updated to (V5.3.0)?
Raja Koduri澄清Arc GPU跳票传闻 AXG年底前新推四条产品线
jsArray array copy method performance test 2207300823
学习笔记——七周成为数据分析师《第一周:数据分析思维》
R语言使用aov函数进行单因素协方差分析(One-way ANCOVA)、使用effects包中的effect函数来计算调整后的分组均值(calculate adjusted means)
UPC2022暑期个人训练赛第19场(B,P)
重保特辑|拦截99%恶意流量,揭秘WAF攻防演练最佳实践
R语言ggstatsplot包grouped_ggwithinstats函数可视化分组小提琴图、并添加假设检验结果(包含样本数、统计量、效应大小及其置信区间、显著性、组间两两比较、贝叶斯假设)
IDEA 重复代码快速重构(抽取重复代码快捷键)
shell 编程规范与变量
第十五天笔记
dolphinscheduler simple task definition and complex cross-node parameter transfer