当前位置:网站首页>23.卷积神经网络实战-ResNet
23.卷积神经网络实战-ResNet
2022-08-02 00:14:00 【派大星的最爱海绵宝宝】
实例

当ch_in与,ch_out不等时,通过代码使得[b,ch_in,h,w] -> [b,ch_out,h,w],把,ch_in变成,ch_out。
forward中x与out不等时,在x前加一个extra()。
我们4个block中h和w是变化的,只是在此处表达的时候没有变。
我们进行一个小测试
blk=ResBlk(64,128)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)
我们的channel越来越大,我们的长和宽保持不变,最终导致我们的参数量越来越大。
我们需要长和宽减半,我们需要在参数部分添加stride,stride为1时,输入和输出非常接近,当为2时,有可能输出为输入的一半。
blk=ResBlk(64,128,stride=2)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)

blk=ResBlk(64,128,stride=4)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)

如果是match,就不会报错。
进行人为的调试:
print('after conv:', x.shape)
x=self.outlay(x)

修改参数:
self.conv1=nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
nn.BatchNorm2d(64)
)
# followed 4 blocks
#[b,64,h,w]->[b,128,h,w]
self.blk1=ResBlk(64,128,stride=2)
# [b,128,h,w]->[b,2556,h,w]
self.blk2=ResBlk(128,256,stride=2)
# [b,256,h,w]->[b,512,h,w]
self.blk3=ResBlk(256,512,stride=2)
# [b,512,h,w]->[b,1024,h,w]
self.blk4=ResBlk(512,512,stride=2)
self.outlay=nn.Linear(512*1*1,10)

小结
整体是先对数据做一个预处理,然后进行4个block,每一个block都由2个卷积和一个短接层组成,处理过程中数据的channel会慢慢增加,但是长和宽会减少,得到(512,512),再把这个(512)打平后送入全连接层,做一个分类的任务。这就是ResNet的一个基本结构。
代码
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
''' resnet block '''
def __init__(self,ch_in,ch_out,stride=1):
''' :param ch_in: :param ch_out: '''
super(ResBlk, self).__init__()
self.con1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn1=nn.BatchNorm2d(ch_out)
self.con2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn2=nn.BatchNorm2d(ch_out)
self.extra=nn.Sequential()
if ch_out != ch_in:
self.extra=nn.Sequential(
nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
nn.BatchNorm2d(ch_out)
)
def forward(self,x):
''' :param x:[b,ch,h,w] :return: '''
out=F.relu(self.bn1(self.con1(x)))
out=self.bn2(self.con2(out))
# short cut
# extra model:[b,ch_in,h,w] with [b,ch_out,h,w]
out=self.extra(x)+out
return out
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.conv1=nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
nn.BatchNorm2d(64)
)
# followed 4 blocks
#[b,64,h,w]->[b,128,h,w]
self.blk1=ResBlk(64,128,stride=2)
# [b,128,h,w]->[b,2556,h,w]
self.blk2=ResBlk(128,256,stride=2)
# [b,256,h,w]->[b,512,h,w]
self.blk3=ResBlk(256,512,stride=2)
# [b,512,h,w]->[b,1024,h,w]
self.blk4=ResBlk(512,512,stride=2)
self.outlay=nn.Linear(512*1*1,10)
def forward(self,x):
''' :param x: :return: '''
x=F.relu(self.conv1(x))
# [b,64,h,w]->[b,1024,h,w]
x=self.blk1(x)
x=self.blk2(x)
x=self.blk3(x)
x=self.blk4(x)
# print('after conv:', x.shape)
# x=self.outlay(x)
x=F.adaptive_avg_pool2d(x,[1,1])
x=x.view(x.size(0),-1)
x=self.outlay(x)
return x
def main():
blk=ResBlk(64,128,stride=4)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print('block:',out.shape)
x=torch.randn(2,3,32,32)
model=ResNet18()
out=model(x)
print('resnet:',out.shape)
if __name__ == '__main__':
main()
边栏推荐
- 鲲鹏编译调试插件实战
- How does JSP use the page command to make the JSP file support Chinese encoding?
- Simpson's paradox
- Constructor, this keyword, method overloading, local variables and member variables
- Business test how to avoid missing?
- 【CodeTON Round 2 (Div. 1 + Div. 2, Rated, Prizes!)(A~D)】
- Redis 相关问题
- 这 4 款电脑记事本软件,得试试
- IO stream basics
- JS中清空数组的方法
猜你喜欢

Business test how to avoid missing?

冒泡排序函数封装

Don't know about SynchronousQueue?So ArrayBlockingQueue and LinkedBlockingQueue don't and don't know?

Unknown CMake command “add_action_files“

Unknown CMake command "add_action_files"

IP Core: FIFO

Trie详解
![[21-Day Learning Challenge] A small summary of sequential search and binary search](/img/81/7339a33de3b9e3aec0474a15825a53.png)
[21-Day Learning Challenge] A small summary of sequential search and binary search

What is Low-Code?What scenarios is low code suitable for?

NodeJs, all kinds of path
随机推荐
【HCIP】BGP小型实验(联邦,优化)
Interview high-frequency test questions solution - stack push and pop sequence, effective parentheses, reverse Polish expression evaluation
BGP综合实验 建立对等体、路由反射器、联邦、路由宣告及聚合
Angr(十二)——官方文档(Part3)
2022/08/01 Study Notes (day21) Generics and Enums
After an incomplete recovery, the control file has been created or restored, the database must be opened with RESETLOGS, interpreting RESETLOGS.
Pytorch seq2seq 模型架构实现英译法任务
Mean Consistency Tracking of Time-Varying Reference Inputs for Multi-Agent Systems with Communication Delays
攻防世界-web-Training-WWW-Robots
Short video SEO optimization tutorial Self-media SEO optimization skills and methods
ES6对箭头函数的理解
MLX90640 红外热成像仪测温模块开发笔记(完整版)
bgp aggregation reflector federation experiment
Multi-feature fusion face detection based on attention mechanism
go笔记之——goroutine
Constructor, this keyword, method overloading, local variables and member variables
unity2D横版游戏教程5-UI
ICML 2022 | GraphFM:通过特征Momentum提升大规模GNN的训练
不要用jOOQ串联字符串
Unknown CMake command "add_action_files"