当前位置:网站首页>torch.nn.Embedding()详解
torch.nn.Embedding()详解
2022-07-29 05:21:00 【Quinn-ntmy】
PyTorch中的Embedding Layer
一、语法格式
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, sparse=False, _weight=None)
1、参数说明
(1)num_embeddings(int):语料库字典大小;
(2)embedding_dim(int):每个嵌入向量的大小;
(3)padding_idx(int, optional):输出遇到此下标时用零填充;
(4)max_norm(float, optional):重新归一化词嵌入,使它们的范数小于提供的值;
(5)norm_type(float, optional):对应max_norm选项计算p范数时的p,默认值为2;
(上面的4、5两个参数基本不用,通常使用kaiming和xavier初始化参数)
(6)scale_grad_by_freq(boolean, optional):将通过小批量中单词频率的倒数来缩放梯度,默认为False。注意!这里的词频指的是自动获取当前小批量中的词频,而非整个词典;
(7) sparse(bool, optional):如果为True,则与权重矩阵相关的梯度转变为稀疏张量。
稀疏张量指反向传播时只更新当前使用词的权重矩阵,以加快更新速度。但是,即使设置 sparse=True ,权重矩阵也未必稀疏更新,原因如下:
- 与优化器相关,使用SGD、Adam等优化器时包含momentum项,导致不相关词的Embedding依然会叠加动量,无法稀疏更新;
- 使用weight_decay,即正则项计入损失值。
基本上通常需要设置的参数是前三个
2、变量说明Embedding.weight为 可学习参数 ,形状为 (num_embeddings, embedding_dim) ,初始化为标准正态分布 (N(0, 10)) 。
输入:input(*),数据类型LongTensor,一般为[mini-batch, nums of index]。
输出:output( * , embedding_dim),其中 * 是输入的形状。
二、实例
import torch
import torch.nn as nn
# 1、定义查找表的形状为10*3
embedding = nn.Embedding(10, 3)
# 2、查看Embedding初始化权重信息
embedding.weight
print(embedding.weight)
# 3、定义输入
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
# 4、将输入中的每个词转换为词嵌入
a = embedding(input)
print(a)
输出结果:
Parameter containing:
tensor([[-1.7372, -0.7281, -1.9509],
[-1.1080, 0.7775, -0.7351],
[ 0.9606, 2.3034, 1.1976],
[-0.6429, 2.1996, -0.0045],
[-0.6949, -1.9427, -0.3486],
[-2.4980, -0.7219, 1.0658],
[-1.4095, 1.7520, 0.7215],
[-0.2162, 0.7108, 0.9062],
[-2.3733, 0.1184, -0.9335],
[-0.0870, 0.1308, -0.6418]], requires_grad=True)
tensor([[[ 0.2644, 0.4962, -2.5476],
[ 1.3521, -0.2055, 0.9044],
[-0.3781, 0.0259, -1.7972],
[-1.0164, -0.5694, -1.0062]],
[[-0.3781, 0.0259, -1.7972],
[-1.6988, -1.1996, -1.7316],
[ 1.3521, -0.2055, 0.9044],
[-1.1474, 0.9734, -0.2874]]], grad_fn=<EmbeddingBackward0>)
Process finished with exit code 0
requires_grad=True,所以weight是可学习的。
三、初始化
Enbedding Layer是如何初始化权重矩阵(即查找表)的??
观察nn.Embedding对应的源码:
class Embedding(Module):
............
if _weight is None:
self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.reset_parameters()
else:
............
def reset_parameters(self) -> None:
init.normal_(self.weight)
............
更新weight时主要使用了实例方法self.reset_parameters(),而这个实例方法又调用了初始化(init)模块中的normal_方法。
题外话
对于CNN中的参数:
-可学习的参数:卷积层和全连接层的权重、bias、BatchNorm的 [公式] 等。
-不可学习的参数(超参数):学习率、batch_size、weight_decay、模型的深度宽度分辨率等。
边栏推荐
- File permissions of day02 operation
- Android studio login registration - source code (connect to MySQL database)
- Windos下安装pyspider报错:Please specify --curl-dir=/path/to/built/libcurl解决办法
- Flutter正在被悄悄放弃?浅析Flutter的未来
- mysql在查询字符串类型的时候带单引号和不带的区别和原因
- [network design] convnext:a convnet for the 2020s
- 【Transformer】ATS: Adaptive Token Sampling For Efficient Vision Transformers
- 通过简单的脚本在Linux环境实现Mysql数据库的定时备份(Mysqldump命令备份)
- Process management of day02 operation
- 【数据库】数据库课程设计一一疫苗接种数据库
猜你喜欢

Exploration of flutter drawing skills: draw arrows together (skill development)

MySql统计函数COUNT详解

mysql在查询字符串类型的时候带单引号和不带的区别和原因

【Attention】Visual Attention Network

简单聊聊 PendingIntent 与 Intent 的区别

Spring, summer, autumn and winter with Miss Zhang (2)

day02 作业之文件权限

Semaphore (semaphore) for learning notes of concurrent programming

GA-RPN:引导锚点的建议区域网络

性能优化之趣谈线程池:线程开的越多就越好吗?
随机推荐
Personal learning website
The difference between asyncawait and promise
ANR优化:导致 OOM 崩溃及相对应的解决方案
[pycharm] pycharm remote connection server
ASM插桩:学完ASM Tree api,再也不用怕hook了
Rsync+inotyfy realize real-time synchronization of single data monitoring
【综述】图像分类网络
File permissions of day02 operation
第2周学习:卷积神经网络基础
mysql在查询字符串类型的时候带单引号和不带的区别和原因
Reporting service 2016 custom authentication
【卷积核设计】Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
Spring, summer, autumn and winter with Miss Zhang (4)
Simple optimization of interesting apps for deep learning (suitable for novices)
【pycharm】pycharm远程连接服务器
Android Studio 实现登录注册-源代码 (连接MySql数据库)
yum本地源制作
MarkDown简明语法手册
30 knowledge points that must be mastered in quantitative development [what is individual data]?
并发编程学习笔记 之 工具类Semaphore(信号量)