PyTorch models
2022-07-27 05:13:00 【Mr_health】
Suppose we have a model made of conv + bn + relu:
import torch.nn as nn

class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        # For odd kernel sizes and stride=1 this padding keeps the spatial size unchanged
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
Now look at what the following code prints:
model = ConvBNReLU(3, 32)  # conv + bn + relu
i = 0
for key, weight in model.named_parameters():  # yields (name, parameter) pairs
    print(key, weight.shape)
    # weight.requires_grad = False  # uncomment to freeze this parameter
    i += 1
print('count', i)
'''
Output:
0.weight torch.Size([32, 3, 3, 3])
1.weight torch.Size([32])
1.bias torch.Size([32])
count 3
'''
state = model.state_dict()
j = 0
for key in state:
    print(key, state[key].shape)
    j += 1
print('count', j)
'''
Output:
0.weight torch.Size([32, 3, 3, 3])
1.weight torch.Size([32])
1.bias torch.Size([32])
1.running_mean torch.Size([32])
1.running_var torch.Size([32])
1.num_batches_tracked torch.Size([])
count 6
'''
k = 0
for content in model.parameters():  # same tensors as named_parameters(), without the names
    print(content.shape)
    k += 1
print('count', k)
'''
Output:
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([32])
count 3
'''
As we can see:
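- model.named_parameters(): yields (name, parameter) pairs, much like the items of a dict, where the name is the parameter's key. It covers only the learnable parameters, that is, the ones updated through backpropagation.
- model.state_dict(): holds the full state of the network: the learnable parameters plus the buffers that are not trained by backpropagation. Here the buffers are the BN layer's running statistics: 1.running_mean torch.Size([32]), 1.running_var torch.Size([32]) and 1.num_batches_tracked torch.Size([]). They are updated during the forward pass in training mode and receive no gradients.
- model.parameters(): holds the parameters that are trained by backpropagation, and is what you usually pass in when defining the optimizer. It contains the same tensors as model.named_parameters(), just without the names. The BN running mean and variance are absent because they are never updated by gradients.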
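The points above can be sketched in a short example: freeze one layer through named_parameters(), then hand only the still-trainable tensors from parameters() to an optimizer. This is a minimal sketch rather than the author's code. The model is rebuilt inline with nn.Sequential so the block is self-contained, and the choice of frozen layer, the SGD optimizer and the learning rate are all illustrative assumptions.

```python
import torch
import torch.nn as nn

# Same conv + bn + relu stack as in the article, written inline to stay self-contained.
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, 1, 1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU6(inplace=True),
)

# Freeze the convolution weight (key "0.weight"); which layer to freeze is an arbitrary choice here.
for name, param in model.named_parameters():
    if name.startswith("0."):
        param.requires_grad = False

# Pass only the parameters that still require gradients to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)  # lr is an example value, not a recommendation

num_trainable = len(trainable)               # BN weight and BN bias remain trainable
num_state_entries = len(model.state_dict())  # freezing does not remove anything from state_dict
print(num_trainable, num_state_entries)
```

Freezing only changes requires_grad; the frozen tensor is still saved and loaded through state_dict().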
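One more note on the entries above:
1.running_mean torch.Size([32]), 1.running_var torch.Size([32]) and 1.num_batches_tracked torch.Size([]) are the BN layer's running mean, running variance and batch counter. They are updated in the forward pass only and are not trained.
1.weight torch.Size([32]) and 1.bias torch.Size([32]) are the BN layer's scale and bias parameters. These are trained.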
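The split between trained parameters and forward-only statistics can be inspected directly with named_buffers(), a standard nn.Module method that the article does not use. The sketch below rebuilds the model inline and feeds it a random batch of hypothetical shape (4, 3, 8, 8) to show that the running statistics change during a forward pass in training mode, even with gradients disabled.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 32, 3, 1, 1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU6(inplace=True),
)

param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
state_keys = list(model.state_dict().keys())

print(param_names)   # learnable tensors
print(buffer_names)  # BN running statistics
# For this model, every state_dict key is either a parameter or a buffer.
keys_match = sorted(state_keys) == sorted(param_names + buffer_names)

# The buffers are updated by the forward pass in training mode, not by the optimizer.
bn = model[1]
mean_before = bn.running_mean.clone()
model.train()
with torch.no_grad():  # no gradients needed; the running statistics still update
    model(torch.randn(4, 3, 8, 8))  # input shape is an arbitrary example
batches_seen = int(bn.num_batches_tracked)
mean_changed = not torch.equal(mean_before, bn.running_mean)
print(keys_match, batches_seen, mean_changed)
```

Calling model.eval() stops these updates, and the BN layer then normalizes with the stored running statistics.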