当前位置:网站首页>Pytorch实现ResNet
Pytorch实现ResNet
2022-07-31 05:16:00 【王大队长】
目标:用pytorch实现下图所示的网络
代码:
import torch
from torch import nn
import torch.nn.functional as F
class ResBlock(nn.Module): #残差块的实现也是继承nn.module后实现一个类,同样的要实现__init__()方法和forward方法
def __init__(self, n_chans):
super().__init__()
self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(n_chans)
torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') #参数初始化
torch.nn.init.constant_(self.batch_norm.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm.bias)
def forward(self,x):
out = self.conv(x)
out = self.batch_norm(out)
out = F.relu(out)
return out + x
class NetResDepp(nn.Module):
def __init__(self, n_chans1=32, num_blocks=100):
super().__init__()
self.n_chans1 = n_chans1
self.num_blocks = num_blocks
self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential(*(num_blocks * [*ResBlock(n_chans=n_chans1)])) # 注意这里的100个Resblock是通过先对ResBlock解包放到列表里,再用100乘这个列表就实现了将列表复制100倍,再解包就实现了100个ResBlock
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32,2)
def forward(self, x):
out = F.relu(self.conv(x))
out = F.max_pool2d(out, 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = self.fc1(out)
out = self.fc2(out)
return out
参考资料:
pytorch深度学习实战(伊莱史蒂文斯)
边栏推荐
- powershell统计文件夹大小
- [Cloud native] Ribbon is no longer used at the bottom layer of OpenFeign starting from the 2020.0.X version
- unicloud cloud development record
- quick-3.5 ActionTimeline的setLastFrameCallFunc调用会崩溃问题
- VTK:Could not locate vtkTextRenderer object.
- Understanding of objects and functions in js
- NFTs: The Heart of Digital Ownership
- unicloud 发布后小程序提示连接本地调试服务失败,请检查客户端是否和主机在同一局域网下
- JS写一段代码,判断一个字符串中出现次数最多的字符串,并统计出现的次数JS
- [swagger close] The production environment closes the swagger method
猜你喜欢
Build DVWA with phpstudy
MySql to create data tables
使用 OpenCV 提取图像的 HOG、SURF 及 LBP 特征 (含代码)
npm WARN config global `--global`, `--local` are deprecated. Use `--location solution
Chinese garbled solution in UTF-8 environment in Powershell
Global scope and function scope in js
Powershell中UTF-8环境中文乱码解决办法
[Cloud Native] What should I do if SQL (and stored procedures) run too slowly?
为什么bash中的read要配合while才能读取/dev/stdin的内容
QT VS中双击ui文件无法打开的问题
随机推荐
quick-3.6源码修改纪录
网页截图与反向代理
Sqlite column A data is copied to column B
Attribute Changer的几种形态
数据库 | SQL增删改查基础语法
Understanding of objects and functions in js
Error: Cannot find module 'D:\Application\nodejs\node_modules\npm\bin\npm-cli.js'
MySQL错误-this is incompatible with sql_mode=only_full_group_by完美解决方案
Chinese garbled solution in UTF-8 environment in Powershell
DeFi Token in the project management
js中的break与continue退出
数据库 | SQL查询进阶语法
[Cloud Native] What should I do if SQL (and stored procedures) run too slowly?
sql 外键约束【表关系绑定】
禅道安装及使用教程
Xiaomi mobile phone SMS location service activation failed
powershell统计文件夹大小
flutter 混合开发 module 依赖
UiBot has an open Microsoft Edge browser and cannot perform the installation
自定dialog 布局没有居中解决方案