当前位置:网站首页>Pytorch实现ResNet
Pytorch实现ResNet
2022-07-31 05:16:00 【王大队长】
目标:用pytorch实现下图所示的网络

代码:
import torch
from torch import nn
import torch.nn.functional as F
class ResBlock(nn.Module): #残差块的实现也是继承nn.module后实现一个类,同样的要实现__init__()方法和forward方法
def __init__(self, n_chans):
super().__init__()
self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(n_chans)
torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') #参数初始化
torch.nn.init.constant_(self.batch_norm.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm.bias)
def forward(self,x):
out = self.conv(x)
out = self.batch_norm(out)
out = F.relu(out)
return out + x
class NetResDepp(nn.Module):
def __init__(self, n_chans1=32, num_blocks=100):
super().__init__()
self.n_chans1 = n_chans1
self.num_blocks = num_blocks
self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential(*(num_blocks * [*ResBlock(n_chans=n_chans1)])) # 注意这里的100个Resblock是通过先对ResBlock解包放到列表里,再用100乘这个列表就实现了将列表复制100倍,再解包就实现了100个ResBlock
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32,2)
def forward(self, x):
out = F.relu(self.conv(x))
out = F.max_pool2d(out, 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = self.fc1(out)
out = self.fc2(out)
return out
参考资料:
pytorch深度学习实战(伊莱史蒂文斯)
边栏推荐
- npm WARN config global `--global`, `--local` are deprecated. Use `--location solution
- [Cloud native] Open source data analysis SPL easily copes with T+0
- Notes on creating a new virtual machine in Hyper-V
- 2021 Mianjing - Embrace Change
- 朴素贝叶斯文本分类(代码实现)
- Build vulhub vulnerability shooting range on kali
- 微信小程序源码获取与反编译方式
- This in js points to the prototype object
- Common JVM interview questions and answers
- 浏览器查找js绑定或者监听的事件
猜你喜欢

js中的对象与函数的理解

为什么bash中的read要配合while才能读取/dev/stdin的内容

The server time zone value ‘й‘ is unrecognized or represents more than one time zone

this points to the problem

Artifact SSMwar exploded Error deploying artifact.See server log for details

Common JVM interview questions and answers

禅道安装及使用教程

JS写一段代码,判断一个字符串中出现次数最多的字符串,并统计出现的次数JS

Why does read in bash need to cooperate with while to read the contents of /dev/stdin

Hyper-V新建虚拟机注意事项
随机推荐
What is an EVM Compatible Chain?
cocoscreator 显示刘海内容
Gradle sync failed: Uninitialized object exists on backward branch 142
场效应管 | N-mos内部结构详解
SSH automatic reconnection script
纯shell实现文本替换
Hyper-V新建虚拟机注意事项
quick-3.5 ActionTimeline的setLastFrameCallFunc调用会崩溃问题
理解js运算符
js中的全局作用域与函数作用域
unicloud 发布后小程序提示连接本地调试服务失败,请检查客户端是否和主机在同一局域网下
Navicat从本地文件中导入sql文件
Eternal blue bug reappears
JS写一段代码,判断一个字符串中出现次数最多的字符串,并统计出现的次数JS
通信原理——纠错编码 | 汉明码(海明码)手算详解
ERROR Error: No module factory availabl at Object.PROJECT_CONFIG_JSON_NOT_VALID_OR_NOT_EXIST ‘Error
Android software security and reverse analysis reading notes
MySQL面试题大全(陆续更新)
[Cloud Native] What should I do if SQL (and stored procedures) run too slowly?
微信小程序源码获取与反编译方式