当前位置:网站首页>PyTorch中的模型构建
PyTorch中的模型构建
2022-07-29 05:21:00 【Quinn-ntmy】
一、构建模型的两个要素
- 构建子模块:在自己建立的模型(继承nn.Module)的
__init__()方法; - 拼接子模块:在模型的
forward()方法中。
二、nn.Module类
模型中的 nn.Module :我们所有的模型,所有的网络层都是继承与这个类的。torch.nn包括(1)nn.Parameter、(2)nn.functional、(3)nn.Module、(4)nn.init,这几个子模块协同工作。
1.nn.Parameter
张量子类,表示可学习参数,如weight、bias。
模型的参数是需要被优化器训练的,因此通常要设置参数为requires_grad = True的张量。同时,在一个模型中,往往有许多的参数,手动管理不容易。一般将参数用nn.Parameter表示,并且用nn.Module来管理其结构下的所有参数。
代码示例
如子模块Attention中的可学习参数:
if score_function == 'mlp':
self.weight = nn.Parameter(torch.Tensor(hidden_dim*2))
elif self.score_function == 'bi_linear':
self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
else: # dot_product / scaled_dot_product
self.register_parameter('weight', None)
self.reset_parameters()
实践当中,一般通过继承nn.Module来构建模块类,并将所有含有需要学习的参数的部分放在构造函数中。
class AEN_BERT(nn.Module):
def __init__(self, bert, opt):
super(AEN_BERT, self).__init__()
self.opt = opt
self.bert = bert
self.squeeze_embedding = SqueezeEmbedding()
self.dropout = nn.Dropout(opt.dropout)
self.attn_k = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
self.attn_q = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
self.ffn_c = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
self.ffn_t = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
self.attn_s1 = Attention(opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
self.dense = nn.Linear(opt.hidden_dim*3, opt.polarities_dim)
def forward(self, inputs):
context, target = inputs[0], inputs[1]
context_len = torch.sum(context != 0, dim=-1)
target_len = torch.sum(target != 0, dim=-1)
context = self.squeeze_embedding(context, context_len)
context, _ = self.bert(context, return_dict=False)
context = self.dropout(context)
target = self.squeeze_embedding(target, target_len)
target, _ = self.bert(target, return_dict=False)
target = self.dropout(target)
hc, _ = self.attn_k(context, context) # 内省上下文词建模
hc = self.ffn_c(hc) # 逐点卷积变换
ht, _ = self.attn_q(context, target) # 上下文感知目标词建模
ht = self.ffn_t(ht) # 逐点卷积变换
s1, _ = self.attn_s1(hc, ht) # 目标特定上下文表示
# 论文中的平均池化??average pooling 输出的最终表示
hc_mean = torch.div(torch.sum(hc, dim=1), context_len.unsqueeze(1).float())
ht_mean = torch.div(torch.sum(ht, dim=1), target_len.unsqueeze(1).float())
s1_mean = torch.div(torch.sum(s1, dim=1), context_len.unsqueeze(1).float())
# torch.div(a, b ):张量a和标量b做逐元素除法,或者两个可广播的张量a、b之间做逐元素除法
x = torch.cat((hc_mean, s1_mean, ht_mean), dim=-1) # concat 连接到一起
out = self.dense(x) # 使用 nn.Linear 全连接层
return out
可以看到构建了模块类AEN_BERT,其中包括子模块Attention,模型中含有需要学习的参数的部分就放在构造的函数(子模块)中。
2. nn.functionalnn.functional:函数的具体实现。如:
(1)激活函数系列(F.relu,F.sigmoid,F.tanh,F.softmax)
(2) 模型层系列(F.linear,F.conv2d,F.max_pool2d,F.dropout2d,F.embedding)
(3)损失函数系列(F.binary_cross_entropy,F.mse_loss,F.cross_entropy)
为了便于对参数进行管理,一般通过继承 nn.Module 转换为类的实现形式,并直接封装在 nn 模块下:
(1)激活函数变成(nn.Relu,nn.Sigmoid,nn.Tanh,nn.Softmax)
(2) 模型层(nn.Linear,nn.Conv2d,nn.Max_pool2d,nn.Dropout2d,nn.Embedding)
(3)损失函数(nn.BCELoss,nn.MSELoss,nn.CrossEntorpyLoss)
3. nn.Module
所有网络层基类,管理有关网络的属性。
在 nn.Module 中,有8个重要的属性,用于管理整个模型,它们都是以有序字典的形式存在着:
self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
self._backward_hooks: Dict[int, Callable] = OrderedDict()
self._forward_hooks: Dict[int, Callable] = OrderedDict()
self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
self._modules: Dict[str, Optional['Module']] = OrderedDict()
(1)_parameters:存储管理属于nn.Parameter类的属性,例如weight、bias这些参数;
(2)_modules:存储管理 nn.Module 类;
(3)_buffers:存储管理缓冲属性,如BN层中的running_mean,std等都会存在这里面;
(4)***_hooks:存储管理钩子函数(5个与hooks有关的字典)。nn.Module构建属性的机制:先是有一个大的Module继承nn.Module这个基类,比如上面的 AEN_BERT,然后这个大的Module里面又可以有很多的子模块,这些子模块同样也继承于nn.Module,在这些Module的__init__方法中,会先通过调用父类的初始化方法进行8个属性的初始化。
然后在构建每个子模块的时候,分为两步,第一步是初始化,然后被__setattr__方法通过判断 value 的类型将其保存到相应的属性字典里面去,然后再进行赋值给相应的成员。这样一个一个地构建子模块,最终整个大的Module构建完成。
总结:
- 一个Module可以包含多个子module;
- 一个Module相当于一个运算,必须实现
forward()函数; - 每个Module都有8个字典管理它的属性(最常用的是
_parameters,_modules)
一般情况下,我们都很少直接使用nn.Parameter来定义参数构建模型,而是通过拼装一些常用的模型层。这些模型层也是继承自nn.Module的对象,本身也包括参数,属于我们要定义的模块的子模块。
nn.Module 提供了一些方法可以管理这些子模块:
children()方法:返回生成器,包括模块下的所有子模块;
named_children()方法:返回一个生成器,包括模块下的所有子模块,以及它们的名字;
modules()方法:返回一个生成器,包括模块下的所有各个层级的模块,包括模块本身;
named_modules()方法:返回一个生成器,包括模块下的所有各个层级的模块以及它们的名字,包括模块本身。
其中children()方法和named_children()方法较多使用,modules()方法和named_modules()方法较少使用,其功能可以通过多个named_children()的嵌套使用实现。
4. nn.init:参数初始化方法。
边栏推荐
- 通过简单的脚本在Linux环境实现Mysql数据库的定时备份(Mysqldump命令备份)
- isAccessible()方法:使用反射技巧让你的性能提升数倍
- anaconda中移除旧环境、增加新环境、查看环境、安装库、清理缓存等操作命令
- Briefly talk about the difference between pendingintent and intent
- Show profiles of MySQL is used.
- 【go】defer的使用
- Spring, summer, autumn and winter with Miss Zhang (2)
- Thinkphp6 pipeline mode pipeline use
- Basic use of array -- traverse the circular array to find the maximum value, minimum value, maximum subscript and minimum subscript of the array
- Personal learning website
猜你喜欢

并发编程学习笔记 之 ReentrantLock实现原理的探究

【pycharm】pycharm远程连接服务器

File permissions of day02 operation

Thinkphp6 pipeline mode pipeline use

Spring, summer, autumn and winter with Miss Zhang (3)

DataX installation

简单聊聊 PendingIntent 与 Intent 的区别

备份谷歌或其他浏览器插件

Semaphore (semaphore) for learning notes of concurrent programming

【Transformer】SOFT: Softmax-free Transformer with Linear Complexity
随机推荐
Exploration of flutter drawing skills: draw arrows together (skill development)
Valuable blog and personal experience collection (continuous update)
Flink, the mainstream real-time stream processing computing framework, is the first experience.
简单聊聊 PendingIntent 与 Intent 的区别
GA-RPN:引导锚点的建议区域网络
anaconda中移除旧环境、增加新环境、查看环境、安装库、清理缓存等操作命令
第一周任务 深度学习和pytorch基础
Flutter正在被悄悄放弃?浅析Flutter的未来
Ribbon学习笔记一
并发编程学习笔记 之 原子操作类AtomicReference、AtomicStampedReference详解
【Transformer】ACMix:On the Integration of Self-Attention and Convolution
Huawei 2020 school recruitment written test programming questions read this article is enough (Part 2)
【语义分割】语义分割综述
Personal learning website
Isaccessible() method: use reflection techniques to improve your performance several times
【Transformer】AdaViT: Adaptive Vision Transformers for Efficient Image Recognition
Android Studio 实现登录注册-源代码 (连接MySql数据库)
asyncawait和promise的区别
虚假新闻检测论文阅读(三):Semi-supervised Content-based Detection of Misinformation via Tensor Embeddings
Realize the scheduled backup of MySQL database in Linux environment through simple script (mysqldump command backup)