当前位置:网站首页>模型定义#pytorch学习
模型定义#pytorch学习
2022-07-26 17:56:00 【CHENYUZ-hub】
1. PyTorch模型定义的方式: nn.Module(Sequential,ModuleList,ModuleDict)
Module类是torch.nn模块里提供的一个模型构造类 (nn.Module),是所有神经⽹网络模块的基类,我们可以继承它来定义我们想要的模型;PyTorch模型定义应包括两个主要部分:各个部分的初始化(
__init__);数据流向定义(forward)
基于nn.Module,我们可以通过Sequential,ModuleList和ModuleDict三种方式定义PyTorch模型。
1.1 Sequential: nn.Sequential()——直接排列/使用OrderedDict
当模型的前向计算为简单串联各个层的计算时, Sequential 类可以通过更加简单的方式定义模型。它可以接收一个子模块的有序字典(OrderedDict) 或者一系列子模块作为参数来逐一添加 Module 的实例,模型的前向计算就是将这些实例按添加的顺序逐个计算。我们结合Sequential和定义方式加以理解:
class MySequential(nn.Module):
from collections import OrderedDict
def __init__(self, *args):
super(MySequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict
for key, module in args[0].items():
self.add_module(key, module)
# add_module方法会将module添加进self._modules(一个OrderedDict)
else: # 传入的是一些Module
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
# self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成
for module in self._modules.values():
input = module(input)
return input使用Sequential来定义模型,只需要将模型的层按序排列起来即可,根据层名的不同,排列的时候有两种方式:
1. 直接排列
import torch.nn as nn
net = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
print(net)
'''
Sequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
'''2. 使用OrderedDict
import collections
import torch.nn as nn
net2 = nn.Sequential(collections.OrderedDict([
('fc1', nn.Linear(784, 256)),
('relu1', nn.ReLU()),
('fc2', nn.Linear(256, 10))
]))
print(net2)
'''
Sequential(
(fc1): Linear(in_features=784, out_features=256, bias=True)
(relu1): ReLU()
(fc2): Linear(in_features=256, out_features=10, bias=True)
)
'''可以看到,使用Sequential定义模型的好处在于简单、易读,同时使用Sequential定义的模型不需要再写forward,因为顺序已经定义好了。但使用Sequential也会使得模型定义丧失灵活性,比如需要在模型中间加入一个外部输入时就不适合用Sequential的方式实现。使用时需根据实际需求加以选择。
1.2 ModuleList:nn.ModuleList() 存储不同模块
ModuleList 接收一个子模块(或层,需属于nn.Module类)的列表作为输入,然后也可以类似List那样进行append和extend操作。同时,子模块或层的权重也会自动添加到网络中来。
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1]) # 类似List的索引访问
print(net)
'''
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
'''要特别注意的是,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起。ModuleList中元素的先后顺序并不代表其在网络中的真实位置顺序,需要经过forward函数指定各个层的先后顺序后才算完成了模型的定义。具体实现时用for循环即可完成:
class model(nn.Module):
def __init__(self, ...):
super().__init__()
self.modulelist = ...
...
def forward(self, x):
for layer in self.modulelist:
x = layer(x)
return x1.3 ModuleDict:nn.ModuleDict() 便于添加名称
ModuleDict和ModuleList的作用类似,只是ModuleDict能够更方便地为神经网络的层添加名称。
net = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
'''
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
(act): ReLU()
(linear): Linear(in_features=784, out_features=256, bias=True)
(output): Linear(in_features=256, out_features=10, bias=True)
)
'''1.4 三种方法的比较与适用场景
Sequential适用于快速验证结果,因为已经明确了要用哪些层,直接写一下就好了,不需要同时写__init__和forward;
ModuleList和ModuleDict在某个完全相同的层需要重复出现多次时,非常方便实现,可以”一行顶多行“;
当我们需要之前层的信息的时候,比如 ResNets 中的残差计算,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList/ModuleDict 比较方便。
2. 利用模型块快速搭建复杂网络
上一节中我们介绍了怎样定义PyTorch的模型,其中给出的示例都是用torch.nn中的层来完成的。这种定义方式易于理解,在实际场景下不一定利于使用。当模型的深度非常大时候,使用Sequential定义模型结构需要向其中添加几百行代码,使用起来不甚方便。
对于大部分模型结构(比如ResNet、DenseNet等),我们仔细观察就会发现,虽然模型有很多层, 但是其中有很多重复出现的结构。考虑到每一层有其输入和输出,若干层串联成的”模块“也有其输入和输出,如果我们能将这些重复出现的层定义为一个”模块“,每次只需要向网络中添加对应的模块来构建模型,这样将会极大便利模型构建的过程。
3. PyTorch修改模型
4. PyTorch模型保存与读取
Reference
边栏推荐
- The pit of mpc5744p reports an error, RTOS cannot be started, and there is a clock source problem
- Summary of some problems encountered in developing WinForm (continuous updating)
- 任正非首度揭秘:华为差点100亿美元“卖身”摩托罗拉背后的故事!
- Concentrate, heart to heart! The Chinese funded mobile phone Enterprises Association (CMA) of India is officially operational!
- MySQL exercises elementary 45 questions (Unified table)
- NFT数字藏品开发:数字藏品助力企业发展
- The first ABAP ALV reporter construction process
- CTO will teach you: how to take over his project when a technician suddenly leaves
- rancher部署kubernetes集群
- PyQt5快速开发与实战 3.5 菜单栏与工具栏
猜你喜欢

Offer set (1)

Module 8 job message data MySQL table design

Linked list - the first common node of two linked lists

Redis persistent rdb/aof

flex布局

Visual VM positioning oom, fullgc usage

SSM整-整合配置

Operations research 69 | explanation of classic examples of dynamic planning

项目中@RequestMapping的作用以及如何使用

模板进阶(跑路人笔记)
随机推荐
Excellent JSON processing tool
Redis核心原理
Database expansion can also be so smooth, MySQL 100 billion level data production environment expansion practice
SSM integration configuration
骚操作:巧用MySQL主从复制延迟拯救误删数据
微软默默给 curl 捐赠一万美元,半年后才通知
ALV screen input option learning
实用工具网站推荐
455. Distribute cookies [double pointer ++i, ++j]
Data security knowledge system
Automated test tool playwright (quick start)
自动化测试工具-Playwright(快速上手)
Concentrate, heart to heart! The Chinese funded mobile phone Enterprises Association (CMA) of India is officially operational!
Flask 封装七牛云
mpc5744p烧录到98%无法继续下载程序
Shader code of parallax map in OpenGL
Redis persistent rdb/aof
Understand in depth why not use system.out.println()
JS刷题计划——数组
Daorayaki | product principles of non-financial decentralized application