当前位置:网站首页>23.卷积神经网络实战-ResNet
23.卷积神经网络实战-ResNet
2022-08-02 00:14:00 【派大星的最爱海绵宝宝】
实例

当ch_in与,ch_out不等时,通过代码使得[b,ch_in,h,w] -> [b,ch_out,h,w],把,ch_in变成,ch_out。
forward中x与out不等时,在x前加一个extra()。
我们4个block中h和w是变化的,只是在此处表达的时候没有变。
我们进行一个小测试
blk=ResBlk(64,128)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)
我们的channel越来越大,我们的长和宽保持不变,最终导致我们的参数量越来越大。
我们需要长和宽减半,我们需要在参数部分添加stride,stride为1时,输入和输出非常接近,当为2时,有可能输出为输入的一半。
blk=ResBlk(64,128,stride=2)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)

blk=ResBlk(64,128,stride=4)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)

如果是match,就不会报错。
进行人为的调试:
print('after conv:', x.shape)
x=self.outlay(x)

修改参数:
self.conv1=nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
nn.BatchNorm2d(64)
)
# followed 4 blocks
#[b,64,h,w]->[b,128,h,w]
self.blk1=ResBlk(64,128,stride=2)
# [b,128,h,w]->[b,2556,h,w]
self.blk2=ResBlk(128,256,stride=2)
# [b,256,h,w]->[b,512,h,w]
self.blk3=ResBlk(256,512,stride=2)
# [b,512,h,w]->[b,1024,h,w]
self.blk4=ResBlk(512,512,stride=2)
self.outlay=nn.Linear(512*1*1,10)

小结
整体是先对数据做一个预处理,然后进行4个block,每一个block都由2个卷积和一个短接层组成,处理过程中数据的channel会慢慢增加,但是长和宽会减少,得到(512,512),再把这个(512)打平后送入全连接层,做一个分类的任务。这就是ResNet的一个基本结构。
代码
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
''' resnet block '''
def __init__(self,ch_in,ch_out,stride=1):
''' :param ch_in: :param ch_out: '''
super(ResBlk, self).__init__()
self.con1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn1=nn.BatchNorm2d(ch_out)
self.con2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn2=nn.BatchNorm2d(ch_out)
self.extra=nn.Sequential()
if ch_out != ch_in:
self.extra=nn.Sequential(
nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
nn.BatchNorm2d(ch_out)
)
def forward(self,x):
''' :param x:[b,ch,h,w] :return: '''
out=F.relu(self.bn1(self.con1(x)))
out=self.bn2(self.con2(out))
# short cut
# extra model:[b,ch_in,h,w] with [b,ch_out,h,w]
out=self.extra(x)+out
return out
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.conv1=nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
nn.BatchNorm2d(64)
)
# followed 4 blocks
#[b,64,h,w]->[b,128,h,w]
self.blk1=ResBlk(64,128,stride=2)
# [b,128,h,w]->[b,2556,h,w]
self.blk2=ResBlk(128,256,stride=2)
# [b,256,h,w]->[b,512,h,w]
self.blk3=ResBlk(256,512,stride=2)
# [b,512,h,w]->[b,1024,h,w]
self.blk4=ResBlk(512,512,stride=2)
self.outlay=nn.Linear(512*1*1,10)
def forward(self,x):
''' :param x: :return: '''
x=F.relu(self.conv1(x))
# [b,64,h,w]->[b,1024,h,w]
x=self.blk1(x)
x=self.blk2(x)
x=self.blk3(x)
x=self.blk4(x)
# print('after conv:', x.shape)
# x=self.outlay(x)
x=F.adaptive_avg_pool2d(x,[1,1])
x=x.view(x.size(0),-1)
x=self.outlay(x)
return x
def main():
blk=ResBlk(64,128,stride=4)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print('block:',out.shape)
x=torch.randn(2,3,32,32)
model=ResNet18()
out=model(x)
print('resnet:',out.shape)
if __name__ == '__main__':
main()
边栏推荐
- JSP out.println()方法具有什么功能呢?
- Redis-消息发布订阅
- uni-app项目总结
- Industrial control network intrusion detection based on automatic optimization of hyperparameters
- Are test points the same as test cases?
- String splitting function strtok exercise
- 基于相关性变量筛选偏最小二乘回归的多维相关时间序列建模方法
- bgp 聚合 反射器 联邦实验
- 实现删除-一个字符串中的指定字母,如:字符串“abcd”,删除其中的”a”字母,剩余”bcd”,也可以传递多个需要删除的字符,传递”ab”也可以做到删除”ab”,剩余”cd”。
- [Headline] Written test questions - minimum stack
猜你喜欢

632. Minimum interval

【HCIP】BGP小型实验(联邦,优化)

C语言实现扫雷游戏

链上治理为何如此重要,波卡Gov 2.0又会如何引领链上治理的发展?

测试用例:四步测试设计法

【CodeTON Round 2 (Div. 1 + Div. 2, Rated, Prizes!)(A~D)】

Identify memory functions memset, memcmp, memmove, and memcpy

MLX90640 红外热成像仪测温模块开发笔记(完整版)

Short video SEO optimization tutorial Self-media SEO optimization skills and methods

C language character and string function summary (2)
随机推荐
工业信息物理系统攻击检测增强模型
JSP out. The write () method has what function?
Task execution control in Ansible
Constructor, this keyword, method overloading, local variables and member variables
磁盘与文件系统管理
抖音数据接口API-获取用户主页信息-监控直播开启
信息物理系统状态估计与传感器攻击检测
Active Disturbance Rejection Control of Substation Inspection Robot Based on Data Drive
08-SDRAM: Summary
unity2D横版游戏教程5-UI
MYSQL(基本篇)——一篇文章带你走进MYSQL的奇妙世界
C language character and string function summary (2)
Kunpeng compile and debug plug-in actual combat
Web开发
els block boundary deformation processing
请教一下本网站左下角的动漫人物是怎么做的?
go笔记——锁
Cyber-Physical System State Estimation and Sensor Attack Detection
Business test how to avoid missing?
扑克牌问题