torch.nn.Embedding() Explained
2022-07-29 05:21:00 【Quinn-ntmy】
The Embedding Layer in PyTorch
I. Syntax
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, sparse=False, _weight=None)
1. Parameters
(1) num_embeddings (int): size of the vocabulary, i.e., the number of rows in the lookup table;
(2) embedding_dim (int): dimension of each embedding vector;
(3) padding_idx (int, optional): if given, the entry at this index is initialized to the zero vector and its gradient is not updated, so it can serve as the padding token;
(4) max_norm (float, optional): if given, each embedding vector whose norm exceeds max_norm is renormalized to have norm max_norm;
(5) norm_type (float, optional): the p of the p-norm used for the max_norm option; defaults to 2.
(Parameters 4 and 5 are rarely used; in practice the weights are usually initialized with Kaiming or Xavier schemes instead.)
(6) scale_grad_by_freq (boolean, optional): if True, gradients are scaled by the inverse frequency of the words in the mini-batch; defaults to False. Note that the frequency is computed over the current mini-batch, not over the whole vocabulary;
(7) sparse (bool, optional): if True, the gradient with respect to the weight matrix becomes a sparse tensor.
A sparse gradient means that backpropagation only updates the rows of the weight matrix for the words actually used, which speeds up the update. However, even with sparse=True the update may not actually be sparse, for two reasons:
- it depends on the optimizer: optimizers with a momentum term (e.g., SGD with momentum, or Adam) accumulate momentum for every row, so the embeddings of unused words still get updated;
- using weight_decay adds a regularization term to the loss that touches all weights.
In most cases only the first three parameters need to be set; a short sketch follows below.
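As a minimal sketch of the three common parameters plus padding_idx (the vocabulary size, dimension, and index values below are arbitrary examples):
import torch
import torch.nn as nn

# Vocabulary of 10 tokens, 3-dimensional vectors; index 0 is the pad token.
emb = nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)

# Row 0 of the lookup table is initialized to zeros and receives no gradient.
print(emb.weight[0])  # tensor([0., 0., 0.], grad_fn=<SelectBackward0>)

# Padded positions in a batch are therefore mapped to the zero vector.
batch = torch.LongTensor([[5, 2, 0, 0]])  # the two trailing 0s are padding
print(emb(batch))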
2. The variable Embedding.weight
A learnable parameter of shape (num_embeddings, embedding_dim), initialized from the standard normal distribution N(0, 1).
Input: input (*), of type LongTensor, typically of shape [mini-batch, number of indices].
Output: output (*, embedding_dim), where * is the shape of the input.
II. Example
import torch
import torch.nn as nn

# 1. Define a lookup table of shape 10 x 3
embedding = nn.Embedding(10, 3)

# 2. Inspect the randomly initialized weights
print(embedding.weight)

# 3. Define the input: a batch of 2 sequences with 4 token indices each
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

# 4. Look up the embedding vector for every index in the input
a = embedding(input)
print(a)
Output:
Parameter containing:
tensor([[-1.7372, -0.7281, -1.9509],
        [-1.1080,  0.7775, -0.7351],
        [ 0.9606,  2.3034,  1.1976],
        [-0.6429,  2.1996, -0.0045],
        [-0.6949, -1.9427, -0.3486],
        [-2.4980, -0.7219,  1.0658],
        [-1.4095,  1.7520,  0.7215],
        [-0.2162,  0.7108,  0.9062],
        [-2.3733,  0.1184, -0.9335],
        [-0.0870,  0.1308, -0.6418]], requires_grad=True)
tensor([[[-1.1080,  0.7775, -0.7351],
         [ 0.9606,  2.3034,  1.1976],
         [-0.6949, -1.9427, -0.3486],
         [-2.4980, -0.7219,  1.0658]],

        [[-0.6949, -1.9427, -0.3486],
         [-0.6429,  2.1996, -0.0045],
         [ 0.9606,  2.3034,  1.1976],
         [-0.0870,  0.1308, -0.6418]]], grad_fn=<EmbeddingBackward0>)
Since requires_grad=True, the weight is learnable.
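Conversely, if you load pretrained word vectors and want the lookup table to stay fixed, you can freeze it; a minimal sketch, where the pretrained tensor is just a random placeholder:
import torch
import torch.nn as nn

# Stand-in for real pretrained vectors: 10 tokens, 3 dimensions.
pretrained = torch.randn(10, 3)

# freeze=True sets weight.requires_grad to False, so the table is not trained.
embedding = nn.Embedding.from_pretrained(pretrained, freeze=True)
print(embedding.weight.requires_grad)  # False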
III. Initialization
How does the Embedding layer initialize its weight matrix (i.e., the lookup table)?
Looking at the relevant part of the nn.Embedding source code:
class Embedding(Module):
    ...
    if _weight is None:
        self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
        self.reset_parameters()
    else:
        ...

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        ...
When the layer is constructed, the weight is initialized by the instance method self.reset_parameters(), which in turn calls the normal_ function from the init module (torch.nn.init); with its default arguments, normal_ fills the tensor from N(0, 1).
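If you prefer the Kaiming or Xavier schemes mentioned above, you can simply overwrite the default initialization after constructing the layer; a minimal sketch:
import torch.nn as nn

embedding = nn.Embedding(10, 3)

# Replace the default N(0, 1) initialization with Xavier uniform, in place.
nn.init.xavier_uniform_(embedding.weight)

# Or use Kaiming initialization instead:
# nn.init.kaiming_normal_(embedding.weight)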
Aside
Regarding the parameters of a CNN:
- learnable parameters: the weights and biases of convolutional and fully connected layers, BatchNorm's γ and β, etc.
- non-learnable parameters (hyperparameters): the learning rate, batch_size, weight_decay, model depth/width/resolution, etc.
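A quick way to check which tensors a model actually learns is to iterate over named_parameters(); a minimal sketch with a toy model (the layer sizes are arbitrary):
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Prints the conv weight/bias and BatchNorm's gamma (weight) and beta (bias);
# hyperparameters such as the learning rate never appear here.
for name, p in model.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)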