torch.nn.Embedding() Details
2022-07-29 06:11:00 【Quinn-ntmy】
The Embedding Layer in PyTorch
I. Syntax
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
                   max_norm=None, norm_type=2.0,
                   scale_grad_by_freq=False, sparse=False, _weight=None)
1. Parameters
(1) num_embeddings (int): size of the dictionary (vocabulary) of embeddings;
(2) embedding_dim (int): the size of each embedding vector;
(3) padding_idx (int, optional): if given, the embedding vector at this index is filled with zeros and receives no gradient updates (a short sketch follows this parameter list);
(4) max_norm (float, optional): re-normalizes the word embeddings so that their norm stays below the given value;
(5) norm_type (float, optional): the p of the p-norm computed for the max_norm option; default 2;
(Parameters 4 and 5 are rarely used; in practice, Kaiming or Xavier initialization is usually applied to the parameters instead.)
(6) scale_grad_by_freq (boolean, optional): scales the gradient by the inverse of the word frequency in the mini-batch; default False. Note that the frequency here is computed automatically from the current mini-batch, not from the whole dictionary;
(7) sparse (bool, optional): if True, the gradient with respect to the weight matrix will be a sparse tensor.
A sparse gradient means that during back-propagation only the rows of the weight matrix belonging to the words actually used are updated, which speeds up the update. However, even with sparse=True the update is not necessarily sparse (see the sketch after this list), because:
- Optimizer: optimizers such as SGD or Adam that include a momentum term still add momentum to the embeddings of unused words, so a sparse update is impossible;
- weight_decay: the regularization term is added to the loss, so every row of the weight matrix receives a gradient.
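A minimal sketch of what sparse=True changes, assuming plain SGD without momentum (one of the optimizers that accepts sparse gradients); the sizes and indices are arbitrary toy values:

import torch
import torch.nn as nn

emb = nn.Embedding(1000, 16, sparse=True)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)   # no momentum, no weight_decay

idx = torch.LongTensor([2, 7, 2])
loss = emb(idx).sum()
loss.backward()

print(emb.weight.grad.is_sparse)   # True: only rows 2 and 7 carry a gradient
opt.step()                         # only those rows are updated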
In practice, usually only the first three parameters need to be set.
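As a quick illustration of the third parameter, here is a minimal padding_idx sketch (the sizes 5 and 3 are arbitrary toy values):

import torch
import torch.nn as nn

emb = nn.Embedding(5, 3, padding_idx=0)
print(emb.weight[0])        # the row at padding_idx is initialized to all zeros

x = torch.LongTensor([0, 2, 0])
emb(x).sum().backward()
print(emb.weight.grad[0])   # the gradient for the padding row is also zero, so it is never updated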
2. Variables
Embedding.weight is a learnable parameter of shape (num_embeddings, embedding_dim), initialized from the standard normal distribution N(0, 1).
Input: input (*), of type LongTensor, usually of shape [mini-batch, nums of index].
Output: output (*, embedding_dim), where * is the input shape.
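A quick shape check (hypothetical index tensors; torch.randint returns LongTensor indices by default):

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)
x = torch.randint(0, 10, (2, 4))         # indices of shape (2, 4)
print(emb(x).shape)                      # torch.Size([2, 4, 3])
x3 = torch.randint(0, 10, (2, 4, 6))     # any index shape works
print(emb(x3).shape)                     # torch.Size([2, 4, 6, 3]) == (*, embedding_dim)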
II. Example
import torch
import torch.nn as nn

# 1. Define a lookup table of shape 10 x 3
embedding = nn.Embedding(10, 3)

# 2. Inspect the initialized weights of the Embedding layer
print(embedding.weight)

# 3. Define the input
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

# 4. Map every index in the input to its word embedding
a = embedding(input)
print(a)
Output:
Parameter containing:
tensor([[-1.7372, -0.7281, -1.9509],
        [-1.1080,  0.7775, -0.7351],
        [ 0.9606,  2.3034,  1.1976],
        [-0.6429,  2.1996, -0.0045],
        [-0.6949, -1.9427, -0.3486],
        [-2.4980, -0.7219,  1.0658],
        [-1.4095,  1.7520,  0.7215],
        [-0.2162,  0.7108,  0.9062],
        [-2.3733,  0.1184, -0.9335],
        [-0.0870,  0.1308, -0.6418]], requires_grad=True)
tensor([[[ 0.2644,  0.4962, -2.5476],
         [ 1.3521, -0.2055,  0.9044],
         [-0.3781,  0.0259, -1.7972],
         [-1.0164, -0.5694, -1.0062]],

        [[-0.3781,  0.0259, -1.7972],
         [-1.6988, -1.1996, -1.7316],
         [ 1.3521, -0.2055,  0.9044],
         [-1.1474,  0.9734, -0.2874]]], grad_fn=<EmbeddingBackward0>)
Process finished with exit code 0
Since requires_grad=True, the weight is learnable.
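Because the lookup table is a regular learnable parameter, an optimizer updates it like any other weight. A minimal, hypothetical training step (plain SGD and an arbitrary scalar loss, chosen only for illustration):

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 3)
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
before = embedding.weight[1].clone()

loss = embedding(input).pow(2).mean()   # any scalar loss built from the embeddings
loss.backward()
optimizer.step()

print(before)
print(embedding.weight[1])              # row 1 has changed; unused rows (e.g. row 0) have not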
III. Initialization
How does the Embedding layer initialize its weight matrix (the lookup table)?
Looking at the corresponding source code of nn.Embedding:
class Embedding(Module):
    ............
        if _weight is None:
            self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
            self.reset_parameters()
        else:
            ............

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        ............
The weight is initialized by the instance method self.reset_parameters(), which calls the normal_ method of the init module (torch.nn.init).
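If you want something other than the default N(0, 1) initialization (for example the Kaiming/Xavier schemes mentioned above, or a pre-trained lookup table), you can overwrite the weight after construction. A minimal sketch; the pre-trained tensor here is just a random stand-in:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)

# Option 1: re-initialize in place, e.g. Xavier instead of the default normal_
nn.init.xavier_uniform_(emb.weight)

# Option 2: build the layer from a pre-trained lookup table
pretrained = torch.randn(10, 3)                               # stand-in for real pre-trained vectors
emb = nn.Embedding.from_pretrained(pretrained, freeze=False)  # freeze=True would keep it fixed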
Aside
On the parameters in a CNN:
- Learnable parameters: the weights and biases of convolutional and fully connected layers, the γ and β of BatchNorm, etc.
- Non-learnable parameters (hyperparameters): learning rate, batch_size, weight_decay, and the depth, width and resolution of the model.
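A quick way to tell the two apart: anything returned by named_parameters() is learnable and updated by the optimizer, while hyperparameters are passed by hand to the optimizer or DataLoader. A hypothetical toy model (the layer sizes are arbitrary):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),   # conv weight + bias
    nn.BatchNorm2d(8),                # BatchNorm gamma (weight) + beta (bias)
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 2),        # fully connected weight + bias
)

for name, p in model.named_parameters():
    print(name, tuple(p.shape))

# Hyperparameters live outside the model, e.g. in the optimizer:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)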