当前位置:网站首页>Pytorch实现ResNet
Pytorch实现ResNet
2022-07-31 05:16:00 【王大队长】
目标:用pytorch实现下图所示的网络

代码:
import torch
from torch import nn
import torch.nn.functional as F
class ResBlock(nn.Module): #残差块的实现也是继承nn.module后实现一个类,同样的要实现__init__()方法和forward方法
def __init__(self, n_chans):
super().__init__()
self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(n_chans)
torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') #参数初始化
torch.nn.init.constant_(self.batch_norm.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm.bias)
def forward(self,x):
out = self.conv(x)
out = self.batch_norm(out)
out = F.relu(out)
return out + x
class NetResDepp(nn.Module):
def __init__(self, n_chans1=32, num_blocks=100):
super().__init__()
self.n_chans1 = n_chans1
self.num_blocks = num_blocks
self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential(*(num_blocks * [*ResBlock(n_chans=n_chans1)])) # 注意这里的100个Resblock是通过先对ResBlock解包放到列表里,再用100乘这个列表就实现了将列表复制100倍,再解包就实现了100个ResBlock
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32,2)
def forward(self, x):
out = F.relu(self.conv(x))
out = F.max_pool2d(out, 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = self.fc1(out)
out = self.fc2(out)
return out
参考资料:
pytorch深度学习实战(伊莱史蒂文斯)
边栏推荐
猜你喜欢

MySQL compressed package installation, fool teaching

Principle analysis of famous website msdn.itellyou.cn

多元线性回归方程原理及其推导

CMOS管原理,及其在推挽电路中的应用

js中的全局作用域与函数作用域

Access database query

How MySQL - depots table?A look at will understand

this points to the problem

安装Multisim出现 No software will be installed or removed解决方法

DeFi Token in the project management
随机推荐
cocoscreator 显示刘海内容
Markdown help documentation
Sqlite A列数据复制到B列
cocos create EditBox 输入文字被刘海屏遮挡修改
著名网站msdn.itellyou.cn原理分析
Principle analysis of famous website msdn.itellyou.cn
通信原理——纠错编码 | 汉明码(海明码)手算详解
Navicat从本地文件中导入sql文件
计网 Packet Tracer仿真 | 简单易懂集线器和交换机对比(理论+仿真)
softmax函数详解
VTK环境配置
DeFi Token in the project management
VS2017连接MYSQL
How MySQL - depots table?A look at will understand
[Ubuntu20.04 installs MySQL and MySQL-workbench visualization tool]
jenkins +miniprogram-ci upload WeChat applet with one click
Powershell中UTF-8环境中文乱码解决办法
Gradle sync failed: Uninitialized object exists on backward branch 142
对js的数组的理解
理解js运算符