torch.nn.Embedding() Details
2022-07-29 06:11:00 【Quinn-ntmy】
The Embedding Layer in PyTorch
I. Syntax
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, sparse=False, _weight=None)
1. Parameter descriptions
(1) num_embeddings (int): the size of the dictionary (vocabulary) of embeddings;
(2) embedding_dim (int): the size of each embedding vector;
(3) padding_idx (int, optional): the embedding vector at this index is initialized to all zeros and does not receive gradient updates during training, so it can serve as the padding token (see the sketch after this list);
(4) max_norm (float, optional): each looked-up embedding vector whose norm exceeds this value is renormalized so that its norm equals max_norm;
(5) norm_type (float, optional): the p of the p-norm used for the max_norm option; default 2;
(Parameters 4 and 5 are rarely used; in practice the weights are usually initialized with Kaiming or Xavier initialization.)
(6) scale_grad_by_freq (boolean, optional): scales gradients by the inverse of the word frequency in the mini-batch; default False. Note: the frequency is computed automatically from the current mini-batch, not from the whole dictionary;
(7) sparse (bool, optional): if True, the gradient with respect to the weight matrix will be a sparse tensor.
A sparse gradient means that during back-propagation only the rows of the weight matrix for the words actually used are updated, which speeds up the update. However, even with sparse=True the weight matrix is not necessarily updated sparsely, for two reasons:
- Optimizer choice: optimizers such as SGD or Adam with a momentum term keep adding momentum to the embeddings of words that do not appear in the batch, so the update cannot be sparse;
- weight_decay, i.e. a regularization term added to the loss.
In practice, usually only the first three parameters need to be set.
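To make padding_idx and max_norm concrete, here is a minimal sketch (the sizes and indices are illustrative assumptions, not from the original post):
import torch
import torch.nn as nn
# padding_idx=0: row 0 of the lookup table is all zeros and receives no gradient
emb = nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)
print(emb.weight[0])  # tensor([0., 0., 0.], ...)
# max_norm=1.0: every looked-up vector is renormalized so its 2-norm is at most 1.0
emb_norm = nn.Embedding(10, 3, max_norm=1.0)
out = emb_norm(torch.LongTensor([1, 2, 3]))
print(out.norm(dim=1))  # each value <= 1.0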
2. The Embedding.weight variable
A learnable parameter of shape (num_embeddings, embedding_dim), initialized from the standard normal distribution N(0, 1).
Input: input (*), of type LongTensor, usually of shape [mini-batch, number of indices].
Output: output (*, embedding_dim), where * is the input shape.
II. Example
import torch
import torch.nn as nn
# 1. Define a lookup table of shape 10 x 3
embedding = nn.Embedding(10, 3)
# 2. Inspect the initial weights of the Embedding layer
print(embedding.weight)
# 3. Define the input
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
# 4. Convert every index in the input into its word embedding
a = embedding(input)
print(a)
Output:
Parameter containing:
tensor([[-1.7372, -0.7281, -1.9509],
[-1.1080, 0.7775, -0.7351],
[ 0.9606, 2.3034, 1.1976],
[-0.6429, 2.1996, -0.0045],
[-0.6949, -1.9427, -0.3486],
[-2.4980, -0.7219, 1.0658],
[-1.4095, 1.7520, 0.7215],
[-0.2162, 0.7108, 0.9062],
[-2.3733, 0.1184, -0.9335],
[-0.0870, 0.1308, -0.6418]], requires_grad=True)
tensor([[[ 0.2644, 0.4962, -2.5476],
[ 1.3521, -0.2055, 0.9044],
[-0.3781, 0.0259, -1.7972],
[-1.0164, -0.5694, -1.0062]],
[[-0.3781, 0.0259, -1.7972],
[-1.6988, -1.1996, -1.7316],
[ 1.3521, -0.2055, 0.9044],
[-1.1474, 0.9734, -0.2874]]], grad_fn=<EmbeddingBackward0>)
Process finished with exit code 0
Since requires_grad=True, the weight is learnable. Note also that the input of shape (2, 4) produces an output of shape (2, 4, 3), i.e. (*, embedding_dim) as described above.
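Continuing the example, it is easy to verify that embedding(input) is just a table lookup into embedding.weight (a small check, assuming the variables from the code above are still in scope):
# Each output vector is a row of the lookup table selected by the input index
assert torch.equal(a[0, 0], embedding.weight[1])  # input[0][0] == 1
assert torch.equal(a[1, 3], embedding.weight[9])  # input[1][3] == 9
print(a.shape)  # torch.Size([2, 4, 3]) -> (*, embedding_dim)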
III. Initialization
How does the Embedding layer initialize its weight matrix (the lookup table)?
Looking at the relevant part of the nn.Embedding source code:
class Embedding(Module):
    ............
    if _weight is None:
        self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
        self.reset_parameters()
    else:
        ............

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        ............
The weight is initialized by the instance method self.reset_parameters(), which in turn calls the normal_ method of the init module (torch.nn.init).
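If the default N(0, 1) initialization is not what you want, the weight can simply be re-initialized after the layer is built; a minimal sketch (the Xavier call and the from_pretrained variant are common options, not something prescribed by the source code above):
import torch
import torch.nn as nn
embedding = nn.Embedding(10, 3)
# Option 1: overwrite the default normal_ initialization, e.g. with Xavier
nn.init.xavier_normal_(embedding.weight)
# Option 2: build the layer directly from an existing (e.g. pre-trained) weight matrix
pretrained = torch.randn(10, 3)  # stand-in for real pre-trained vectors
embedding2 = nn.Embedding.from_pretrained(pretrained, freeze=False)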
Digression
Regarding the parameters in a CNN:
- Learnable parameters: the weights of convolutional and fully connected layers, biases, the γ and β of BatchNorm, etc. (see the sketch below).
- Non-learnable parameters (hyperparameters): learning rate, batch_size, weight_decay, and the depth, width and resolution of the model.
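A quick way to see which parameters of a model are learnable is to iterate over named_parameters(); the layers below are only an illustrative assumption:
import torch.nn as nn
layers = nn.ModuleDict({
    "conv": nn.Conv2d(3, 16, kernel_size=3),
    "bn": nn.BatchNorm2d(16),
    "fc": nn.Linear(16, 10),
})
for name, p in layers.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# Prints conv.weight, conv.bias, bn.weight (gamma), bn.bias (beta), fc.weight, fc.bias;
# hyperparameters such as the learning rate or batch_size never appear here.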