当前位置:网站首页>pytorch学习记录(五):卷积神经网络的实现
PyTorch 学习记录(五):卷积神经网络的实现
2022-07-30 13:25:00 【狸狸Arina】
1. 数据集的加载与训练/测试流程
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from lenet5 import Lenet5
from resnet import ResNet18
def main():
    """Train ResNet18 on CIFAR-10 and print the test accuracy after every epoch.

    CIFAR-10 is downloaded into the ``cifar`` directory when it is not already
    present. Training runs on the GPU when one is available and falls back to
    the CPU otherwise.
    """
    batchsz = 128

    # Both splits use exactly the same preprocessing, so build the pipeline once.
    # NOTE(review): the mean/std values are the widely used ImageNet channel
    # statistics, not statistics computed on CIFAR-10; kept as in the original.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    cifar_train = datasets.CIFAR10('cifar', True, transform=transform,
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transform,
                                  download=True)
    # Accuracy does not depend on sample order, so evaluation is not shuffled.
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # Peek at one batch to confirm tensor shapes. The built-in next() is used
    # because DataLoader iterators no longer expose a Python-2 style .next().
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Swap in Lenet5().to(device) here to train the smaller network instead.
    model = ResNet18().to(device)
    # CrossEntropyLoss consumes raw logits together with integer class labels.
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):
        # ---- training pass ----
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            # logits: [b, 10]; loss: scalar tensor
            logits = model(x)
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Reports the loss of the last mini-batch of the epoch only.
        print(epoch, 'loss:', loss.item())

        # ---- evaluation pass ----
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                # logits: [b, 10] => pred: [b]
                logits = model(x)
                pred = logits.argmax(dim=1)

                # [b] vs [b] => number of correct predictions in this batch
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
2. LeNet 实现(lenet5.py)
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
    """LeNet-style CNN mapping ``[b, 3, 32, 32]`` images to 10 class logits.

    The network is two conv + max-pool stages followed by a two-layer
    fully connected head. ``forward`` returns raw logits; no softmax is
    applied inside the module.
    """

    def __init__(self):
        super(Lenet5, self).__init__()

        # Constructing the model has no side effects: no debug forward pass
        # and no printing happen here.
        self.conv_unit = nn.Sequential(
            # [b, 3, 32, 32] => [b, 16, 28, 28]
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),
            # [b, 16, 28, 28] => [b, 16, 14, 14]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # [b, 16, 14, 14] => [b, 32, 10, 10]
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),
            # [b, 32, 10, 10] => [b, 32, 5, 5]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

        # Fully connected head applied to the flattened conv features.
        self.fc_unit = nn.Sequential(
            nn.Linear(32 * 5 * 5, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        """Compute class logits.

        :param x: image batch of shape [b, 3, 32, 32]
        :return: logits of shape [b, 10]
        """
        batchsz = x.size(0)
        # [b, 3, 32, 32] => [b, 32, 5, 5]
        x = self.conv_unit(x)
        # [b, 32, 5, 5] => [b, 32*5*5]
        x = x.view(batchsz, 32 * 5 * 5)
        # [b, 32*5*5] => [b, 10]
        logits = self.fc_unit(x)
        return logits
def main():
    """Smoke-test Lenet5 by pushing one random batch through it."""
    model = Lenet5()
    # Two fake RGB images of size 32x32.
    fake_batch = torch.randn(2, 3, 32, 32)
    logits = model(fake_batch)
    print('lenet out:', logits.shape)


if __name__ == '__main__':
    main()
3. ResNet 实现(resnet.py)
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    The first convolution may change the channel count and, through
    ``stride``, the spatial size. The shortcut is an identity only when the
    main branch leaves the shape unchanged; otherwise it is a strided 1x1
    conv + batch-norm projection so both branches can be added element-wise.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first 3x3 convolution (and of the
            projection shortcut, when one is needed)
        """
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Empty Sequential acts as the identity shortcut.
        self.extra = nn.Sequential()
        # A projection is required whenever the main branch changes the
        # tensor shape, i.e. on a channel change OR a stride other than 1.
        # Checking channels alone left a strided same-channel block with an
        # identity shortcut, so the residual add either failed or silently
        # broadcast a larger input over a 1x1 output.
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: feature map of shape [b, ch_in, h, w]
        :return: feature map of shape [b, ch_out, h', w']
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut branch, then element-wise add and the final activation.
        out = self.extra(x) + out
        out = F.relu(out)
        return out
class ResNet18(nn.Module):
    """Residual classifier: conv stem, four ResBlk stages, 10-way linear head."""

    def __init__(self):
        super().__init__()

        # Stem: 3 -> 64 channels using a 3x3 kernel with stride 3 and no padding.
        stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0)
        stem_norm = nn.BatchNorm2d(64)
        self.conv1 = nn.Sequential(stem_conv, stem_norm)

        # Four residual stages, each built with stride 2.
        # channels: 64 -> 128
        self.blk1 = ResBlk(64, 128, stride=2)
        # channels: 128 -> 256
        self.blk2 = ResBlk(128, 256, stride=2)
        # channels: 256 -> 512
        self.blk3 = ResBlk(256, 512, stride=2)
        # channels: 512 -> 512
        self.blk4 = ResBlk(512, 512, stride=2)

        # Classifier over the 512 globally pooled features.
        self.outlayer = nn.Linear(512 * 1 * 1, 10)

    def forward(self, x):
        """
        :param x: image batch with 3 channels, [b, 3, h, w]
        :return: class logits of shape [b, 10]
        """
        feat = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h', w']
        for stage in (self.blk1, self.blk2, self.blk3, self.blk4):
            feat = stage(feat)

        # Global average pooling: [b, 512, h', w'] => [b, 512, 1, 1]
        pooled = F.adaptive_avg_pool2d(feat, [1, 1])
        # [b, 512, 1, 1] => [b, 512]
        flat = pooled.view(pooled.size(0), -1)
        return self.outlayer(flat)
def main():
    """Smoke-test ResBlk and ResNet18 with random inputs and print output shapes."""
    # A single residual block on a 64-channel feature map.
    res_block = ResBlk(64, 128, stride=4)
    feature_map = torch.randn(2, 64, 32, 32)
    print('block:', res_block(feature_map).shape)

    # The full network on two fake RGB 32x32 images.
    images = torch.randn(2, 3, 32, 32)
    net = ResNet18()
    logits = net(images)
    print('resnet:', logits.shape)


if __name__ == '__main__':
    main()
边栏推荐
- 如何判断自己是否适合IT行业?方法很简单
- ENVI Image Processing (6): NDVI and Vegetation Index
- MQTT网关读取西门子PLC数据传输到阿里云平台案例教程
- js男女身高体重关系图
- R语言向前或者向后移动时间序列数据(自定义滞后或者超前的期数):使用dplyr包中的lag函数将时间序列数据向后移动一天(设置参数n为负值)
- 二手手机销量突破3亿部,与降价的iPhone夹击国产手机
- 缓存一致性
- [PostgreSQL] - Storage structure and cache shared_buffers
- How to solve the problem that the page does not display the channel configuration after the EasyNVR is updated to (V5.3.0)?
- datax enables hana support and dolphinscheduler enables datax tasks
猜你喜欢

js人均寿命和GDP散点图统计样式

jsArray array copy method performance test 2207300823

缓存一致性

How to display an Excel table in the body of an email?

ARC117E Zero-Sum Ranges 2

Study Notes - Becoming a Data Analyst in Seven Weeks "Week 2: Business": Business Analysis Metrics

当下,产业园区发展面临的十大问题

一本通循环结构的程序设计题解(2)

cpu / CS 和 IP

Eleven BUUCTF questions (06)
随机推荐
The way of programmers' cultivation: do one's own responsibilities, be clear in reality - lead to the highest realm of pragmatism
shell script flow control statement
CF780G Andryusha and Nervous Barriers
戴墨镜的卡通太阳SVG动画js特效
There is no one of the strongest kings in the surveillance world!
Markdown 1 - 图文音视频等
EasyNVS云管理平台功能重构:支持新增用户、修改信息等
R语言筛选时间序列数据的子集(subset time series data)、使用window函数筛选连续日期时间范围内的数据(start参数和end参数分别指定起始和结束时间)
TaskDispatcher source code parsing
Learning notes - 7 weeks as data analyst "in the first week: data analysis of thinking"
05 | 后台登录:基于账号密码的登录方式(下)
shell 编程规范与变量
DeFi 巨头进军 NFT 领域 用户怎么看?
434. 字符串中的单词数
当下,产业园区发展面临的十大问题
CF603E Pastoral Oddities
程序员修炼之道:务以己任,实则明心——通向务实的最高境界
C语言学习练习题:汉诺塔(函数与递归)
R语言向前或者向后移动时间序列数据(自定义滞后或者超前的期数):使用dplyr包中的lag函数将时间序列数据向后移动一天(设置参数n为负值)
【Advanced Mathematics】【7】Double Integral