当前位置:网站首页>Pytorch实现ResNet
Pytorch实现ResNet
2022-07-31 05:16:00 【王大队长】
目标:用pytorch实现下图所示的网络

代码:
import torch
from torch import nn
import torch.nn.functional as F
class ResBlock(nn.Module): #残差块的实现也是继承nn.module后实现一个类,同样的要实现__init__()方法和forward方法
def __init__(self, n_chans):
super().__init__()
self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(n_chans)
torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') #参数初始化
torch.nn.init.constant_(self.batch_norm.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm.bias)
def forward(self,x):
out = self.conv(x)
out = self.batch_norm(out)
out = F.relu(out)
return out + x
class NetResDepp(nn.Module):
def __init__(self, n_chans1=32, num_blocks=100):
super().__init__()
self.n_chans1 = n_chans1
self.num_blocks = num_blocks
self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential(*(num_blocks * [*ResBlock(n_chans=n_chans1)])) # 注意这里的100个Resblock是通过先对ResBlock解包放到列表里,再用100乘这个列表就实现了将列表复制100倍,再解包就实现了100个ResBlock
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32,2)
def forward(self, x):
out = F.relu(self.conv(x))
out = F.max_pool2d(out, 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = self.fc1(out)
out = self.fc2(out)
return out
参考资料:
pytorch深度学习实战(伊莱史蒂文斯)
边栏推荐
- cocos create EditBox 输入文字被刘海屏遮挡修改
- Android软件安全与逆向分析阅读笔记
- NFTs: The Heart of Digital Ownership
- Understanding of objects and functions in js
- Flow control statement in js
- 5 methods of MySQL paging query
- After unicloud is released, the applet prompts that the connection to the local debugging service failed. Please check whether the client and the host are under the same local area network.
- break and continue exit in js
- MySql to create data tables
- Podspec automatic upgrade script
猜你喜欢

MySQL高级语句(一)

浏览器查找js绑定或者监听的事件
![[Cloud native] Simple introduction and use of microservice Nacos](/img/06/b0594208d5b0cbf3ae8edd80ec12c4.png)
[Cloud native] Simple introduction and use of microservice Nacos

JS写一段代码,判断一个字符串中出现次数最多的字符串,并统计出现的次数JS

Principle analysis of famous website msdn.itellyou.cn

js中的对象与函数的理解

this points to the problem

For penetration testing methods where the output point is a timestamp (take Oracle database as an example)

Xiaomi mobile phone SMS location service activation failed

Understanding of js arrays
随机推荐
js中的对象与函数的理解
Understanding of objects and functions in js
Android software security and reverse analysis reading notes
js中流程控制语句
Why does read in bash need to cooperate with while to read the contents of /dev/stdin
网页截图与反向代理
ERROR Error: No module factory availabl at Object.PROJECT_CONFIG_JSON_NOT_VALID_OR_NOT_EXIST ‘Error
SSH自动重连脚本
quick-3.5 ActionTimeline的setLastFrameCallFunc调用会崩溃问题
understand js operators
Markdown help documentation
Judgment of database in SQL injection
configure:error no SDL library found
2021年京东数据分析工程师秋招笔试编程题
Markdown 帮助文档
flutter arr dependencies
cocos2d-x-3.2创建项目方法
flutter arr 依赖
VS2017连接MYSQL
MYSQL事务与锁问题处理