当前位置:网站首页>Pytorch实现ResNet
Pytorch实现ResNet
2022-07-31 05:16:00 【王大队长】
目标:用pytorch实现下图所示的网络

代码:
import torch
from torch import nn
import torch.nn.functional as F
class ResBlock(nn.Module): #残差块的实现也是继承nn.module后实现一个类,同样的要实现__init__()方法和forward方法
def __init__(self, n_chans):
super().__init__()
self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(n_chans)
torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') #参数初始化
torch.nn.init.constant_(self.batch_norm.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm.bias)
def forward(self,x):
out = self.conv(x)
out = self.batch_norm(out)
out = F.relu(out)
return out + x
class NetResDepp(nn.Module):
def __init__(self, n_chans1=32, num_blocks=100):
super().__init__()
self.n_chans1 = n_chans1
self.num_blocks = num_blocks
self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential(*(num_blocks * [*ResBlock(n_chans=n_chans1)])) # 注意这里的100个Resblock是通过先对ResBlock解包放到列表里,再用100乘这个列表就实现了将列表复制100倍,再解包就实现了100个ResBlock
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32,2)
def forward(self, x):
out = F.relu(self.conv(x))
out = F.max_pool2d(out, 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = self.fc1(out)
out = self.fc2(out)
return out
参考资料:
pytorch深度学习实战(伊莱史蒂文斯)
边栏推荐
- MySql to create data tables
- softmax函数详解
- [Cloud native] Simple introduction and use of microservice Nacos
- ERROR Error: No module factory availabl at Object.PROJECT_CONFIG_JSON_NOT_VALID_OR_NOT_EXIST ‘Error
- 2021 Mianjing - Embrace Change
- Podspec verification dependency error problem pod lib lint , need to specify the source
- this指向问题
- 微信小程序启动优化
- Markdown 帮助文档
- Gradle sync failed: Uninitialized object exists on backward branch 142
猜你喜欢

Nmap的下载与安装

Sqlite column A data is copied to column B

朴素贝叶斯文本分类(代码实现)

Artifact SSMwar exploded Error deploying artifact.See server log for details
![[Cloud native] Open source data analysis SPL easily copes with T+0](/img/89/4a96358956782ef9dacf0b700b54c3.png)
[Cloud native] Open source data analysis SPL easily copes with T+0

WeChat applet source code acquisition and decompilation method

VS2017 connects to MYSQL

微信小程序源码获取与反编译方式

场效应管 | N-mos内部结构详解

What is an EVM Compatible Chain?
随机推荐
数据库 | SQL增删改查基础语法
quick lua加密
cocos2d-x 实现跨平台的目录遍历
腾讯云轻量服务器删除所有防火墙规则
quick-3.5 ActionTimeline的setLastFrameCallFunc调用会崩溃问题
mysql common commands
This in js points to the prototype object
sqlite 查看表结构 android.database.sqlite.SQLiteException: table splitTable has no column named
js中的break与continue退出
Judgment of database in SQL injection
Linux modify MySQL database password
cocos2d-x-3.2图片灰化效果
MYSQL事务与锁问题处理
Attribute Changer的几种形态
Filter out egrep itself when using ps | egrep
Error: Cannot find module 'D:\Application\nodejs\node_modules\npm\bin\npm-cli.js'
Why does read in bash need to cooperate with while to read the contents of /dev/stdin
2021美赛C题M奖思路
使用 OpenCV 提取图像的 HOG、SURF 及 LBP 特征 (含代码)
动态规划(一)| 斐波那契数列和归递