(Advanced PyTorch, Part 2) Word Embedding and Position Embedding
2022-06-22 14:28:00 【likeGhee】
word embedding
An embedding maps high-dimensional, discrete tokens to low-dimensional, dense vectors.
Assume the task is translating English into German. We first need to build an English source sequence (source sentence) and a German target sequence (target sentence), denoted src_seq and tgt_seq.
How do we build these sequences? Anyone who has touched NLP will find this familiar: each token in a sequence is represented by its index in a vocabulary dict.
We also fix the sequence lengths, stored as src_len and tgt_len.
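For intuition, a toy vocabulary might look like the sketch below (this dict and sentence are purely hypothetical; the code in this post simply draws random indices instead):
# %%
# Hypothetical toy vocabulary: index 0 is reserved for padding
src_vocab = {"<pad>": 0, "i": 1, "love": 2, "deep": 3, "learning": 4}
# "i love deep learning" -> [1, 2, 3, 4]
example_ids = [src_vocab[w] for w in "i love deep learning".split()]
print(example_ids)  # [1, 2, 3, 4]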
# %%
import numpy
import torch as T
import torch.nn as nn
import torch.nn.functional as F
# %%
# Assume there are two sentences
batch_size = 2
# Each sentence has a random length in [2, 5)
src_len = T.randint(2, 5, (batch_size, ))
tgt_len = T.randint(2, 5, (batch_size, ))
# For easier study, hard-code the lengths
src_len = T.Tensor([2, 4]).to(T.int32)
tgt_len = T.Tensor([4, 3]).to(T.int32)
print(src_len)
print(tgt_len)
The output is tensor([2, 4]) and tensor([4, 3]), meaning the two src sentences have lengths 2 and 4 and the two tgt sentences have lengths 4 and 3, for a total of two sentences each.
Next we build the sequences. Assume the largest index in both the src and tgt dicts is 8, i.e., each vocabulary contains at most 8 words. We randomly generate each sequence and put it in a list. To make all sentences the same length we also need padding, using the pad function from torch.nn.functional; the sequences are then turned into a tensor of shape [batch_size, max_len] with unsqueeze and cat, ready to be used as a batch input.
# %%
# Vocabulary sizes
max_source_word_num = 8
max_target_word_num = 8
# Maximum sequence lengths
max_source_seq_len = 5
max_target_seq_len = 5
# Generate the sequences
src_seq = [T.randint(1, max_source_word_num, (L,)) for L in src_len]
# Padding
src_seq = list(map(lambda x: F.pad(x, (0, max_source_seq_len - len(x))), src_seq))
# Add a leading dimension so we can concatenate
src_seq = list(map(lambda x: T.unsqueeze(x, 0), src_seq))
# Concatenate into a batch
src_seq = T.cat(src_seq, 0)
print(src_seq)
tgt_seq = [F.pad(T.randint(1, max_target_word_num, (L,)), (0, max_target_seq_len-L)) for L in tgt_len]
tgt_seq = list(map(lambda x: T.unsqueeze(x, 0), tgt_seq))
tgt_seq = T.cat(tgt_seq, 0)
print(tgt_seq)
The output shows the two padded sequence tensors.
With the input ready, the middle step is the embedding itself, using PyTorch's nn.Embedding API.
The first argument, num_embeddings, is the number of words; we usually take the maximum vocabulary size + 1 so that the padding index 0 is counted as well.
The second argument, embedding_dim, is the dimensionality of the word vectors, typically 512; for convenience we use 8.
# %%
model_dim = 8
# Build the embedding tables
src_embedding_table = nn.Embedding(max_source_word_num + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_target_word_num + 1, model_dim)
print(src_embedding_table.weight.size())
# Test a forward pass
src_embedding = src_embedding_table(src_seq)
print(src_embedding.size())
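A side note that the snippet above does not use: nn.Embedding also accepts a padding_idx argument, which keeps the embedding of the padding index (0 here) as a zero vector that never receives gradient updates. A minimal sketch, assuming the same vocabulary sizes as above:
# %%
# Sketch only: reserve index 0 as the padding embedding
src_embedding_table_pad = nn.Embedding(max_source_word_num + 1, model_dim, padding_idx=0)
print(src_embedding_table_pad.weight[0])       # all zeros for the padding index
print(src_embedding_table_pad(src_seq).size()) # same [batch_size, max_len, model_dim] shape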

position embedding
"Attention Is All You Need" gives the expression for PE (position embedding). The basic idea is to turn a word's position within the sentence into a vector and add it to the WE (word embedding).
First, PE is a 2-D matrix of shape [max_len, dim]. The maximum length can match max_source_seq_len; here we set max_position_len = 5.
The PE matrix can be seen as combining two matrices by broadcasting (element-wise, not a matrix product): a pos column vector on the left and an i row vector on the right, with sin applied to the even columns and cos to the odd columns.
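For reference, the sinusoidal expressions from the paper, where pos is the position, i indexes the dimension pairs, and d_model is the embedding dimension:
PE_{(pos, 2i)} = \sin\bigl(pos / 10000^{2i/d_{model}}\bigr)
PE_{(pos, 2i+1)} = \cos\bigl(pos / 10000^{2i/d_{model}}\bigr)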
# %%
max_position_len = 5
pos_matrix = T.arange(max_position_len).reshape((-1, 1))
print(pos_matrix)
# Step of 2 because even and odd columns are handled separately
i_matrix = T.pow(10000, T.arange(0, model_dim, 2).reshape([1, -1]) / model_dim)
print(i_matrix)
# Build the embedding table
pe_embedding_table = T.zeros([max_position_len, model_dim])
# Even columns: 0::2 means start at index 0 and take every 2nd element to the end
pe_embedding_table[:, 0::2] = T.sin(pos_matrix / i_matrix)
# Odd columns
pe_embedding_table[:, 1::2] = T.cos(pos_matrix / i_matrix)
print(pe_embedding_table)
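As a quick sanity check (my own addition, not part of the original walkthrough): row 0 should alternate sin(0) = 0 and cos(0) = 1, and the first column should simply be sin(pos) because its denominator is 10000^0 = 1:
# %%
# Sanity-check the table against the formula
print(pe_embedding_table[0])  # expected pattern: [0, 1, 0, 1, ...]
print(T.allclose(pe_embedding_table[:, 0], T.sin(T.arange(max_position_len).float())))  # expected True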
Now build an nn.Embedding module and replace its weight with the table:
# %%
# Create the PE embedding by overwriting an nn.Embedding module's weight
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight.size())
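An equivalent shortcut, shown here only as an alternative to the weight replacement above, is nn.Embedding.from_pretrained, which wraps an existing tensor and freezes it by default:
# %%
# Alternative: build the frozen PE embedding in one call (freeze=True is the default)
pe_embedding_alt = nn.Embedding.from_pretrained(pe_embedding_table, freeze=True)
print(pe_embedding_alt.weight.requires_grad)  # False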
To build the input we need to pass in position indices, which is naturally an arange operation; finally we compute the PE.
# %%
# Build the position indices
src_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in src_len] , 0)
print(src_pos)
tgt_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in tgt_len] , 0)
# Forward pass to compute the src PE
src_pe_embedding = pe_embedding(src_pos)
print(src_pe_embedding.size())
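As described at the start of this section, the PE is added to the WE to form the actual model input. A minimal sketch of that final step (both tensors already share the shape [batch_size, max_position_len, model_dim]):
# %%
# Word embedding + position embedding -> input fed to the model
src_input = src_embedding + src_pe_embedding
print(src_input.size())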
Full code
# %%
import numpy
import torch as T
import torch.nn as nn
import torch.nn.functional as F
# %%
# Assume there are two sentences
batch_size = 2
# Each sentence has a random length in [2, 5)
src_len = T.randint(2, 5, (batch_size, ))
tgt_len = T.randint(2, 5, (batch_size, ))
print(src_len)
print(tgt_len)
# For easier study, hard-code the lengths
src_len = T.Tensor([2, 4]).to(T.int32)
tgt_len = T.Tensor([4, 3]).to(T.int32)
print(src_len)
print(tgt_len)
# %%
# Vocabulary sizes
max_source_word_num = 8
max_target_word_num = 8
# Maximum sequence lengths
max_source_seq_len = 5
max_target_seq_len = 5
# Generate the sequences
src_seq = [T.randint(1, max_source_word_num, (L,)) for L in src_len]
# Padding
src_seq = list(map(lambda x: F.pad(x, (0, max_source_seq_len - len(x))), src_seq))
# Add a leading dimension so we can concatenate
src_seq = list(map(lambda x: T.unsqueeze(x, 0), src_seq))
# Concatenate into a batch
src_seq = T.cat(src_seq, 0)
print(src_seq)
tgt_seq = [F.pad(T.randint(1, max_target_word_num, (L,)), (0, max_target_seq_len-L)) for L in tgt_len]
tgt_seq = list(map(lambda x: T.unsqueeze(x, 0), tgt_seq))
tgt_seq = T.cat(tgt_seq, 0)
print(tgt_seq)
# %%
model_dim = 8
src_embedding_table = nn.Embedding(max_source_word_num + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_target_word_num + 1, model_dim)
print(src_embedding_table.weight.size())
# Test a forward pass
src_embedding = src_embedding_table(src_seq)
print(src_embedding.size())
# %%
max_position_len = 5
pos_matrix = T.arange(max_position_len).reshape((-1, 1))
print(pos_matrix)
# Step of 2 because even and odd columns are handled separately
i_matrix = T.pow(10000, T.arange(0, model_dim, 2).reshape([1, -1]) / model_dim)
print(i_matrix)
# Build the embedding table
pe_embedding_table = T.zeros([max_position_len, model_dim])
# Even columns: 0::2 means start at index 0 and take every 2nd element to the end
pe_embedding_table[:, 0::2] = T.sin(pos_matrix / i_matrix)
# Odd columns
pe_embedding_table[:, 1::2] = T.cos(pos_matrix / i_matrix)
print(pe_embedding_table)
# %%
# Create the PE embedding by overwriting an nn.Embedding module's weight
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight.size())
# %%
# Build the position indices
src_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in src_len] , 0)
print(src_pos)
tgt_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in tgt_len] , 0)
# Forward pass to compute the src PE
src_pe_embedding = pe_embedding(src_pos)
print(src_pe_embedding.size())