Word vector training based on NNLM
2022-06-12 06:06:00 【Singing under the hedge】
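Training word vectors with NNLM
1. Corpus
The corpus is three short sentences, "i like dog", "i love coffee" and "i hate milk". In each sentence the first two words are the context and the last word is the prediction target.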
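As a minimal sketch of how such a corpus becomes training data: each sentence is split into a two-word context and a one-word target, and words are mapped to integer ids. The names `vocab`, `word_to_id` and `make_pairs` are illustrative and not part of the original script, and sorting the vocabulary is an assumption made here only so that the ids are reproducible.

```python
sentences = ["i like dog", "i love coffee", "i hate milk"]

# Assumption: sort the vocabulary so that word ids are the same on every run.
vocab = sorted(set(" ".join(sentences).split()))
word_to_id = {w: i for i, w in enumerate(vocab)}


def make_pairs(sentences, word_to_id):
    """Split each sentence into (context ids, target id)."""
    inputs, targets = [], []
    for s in sentences:
        words = s.split()
        inputs.append([word_to_id[w] for w in words[:-1]])  # all but the last word
        targets.append(word_to_id[words[-1]])               # the last word
    return inputs, targets


inputs, targets = make_pairs(sentences, word_to_id)
for ctx, tgt in zip(inputs, targets):
    print([vocab[i] for i in ctx], "->", vocab[tgt])
```

With a sorted vocabulary the mapping is stable, which makes it easier to compare runs.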

2. Complete code
import torch
import torch.nn as nn
import torch.optim as optim

sentence = ["i like dog", "i love coffee", "i hate milk"]

# Build the vocabulary: split on whitespace, then use set() to remove duplicates.
# set() ordering can differ between runs, so the ids shown below are one possible outcome.
word_list = " ".join(sentence).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}    # word -> id
# e.g. {'i': 0, 'like': 1, 'dog': 2, 'coffee': 3, 'hate': 4, 'milk': 5, 'love': 6}
number_dict = {i: w for i, w in enumerate(word_list)}  # id -> word
# e.g. {0: 'i', 1: 'like', 2: 'dog', 3: 'coffee', 4: 'hate', 5: 'milk', 6: 'love'}

n_class = len(word_dict)  # vocabulary size, 7
n_step = 2                # number of context words
n_hidden = 2              # hidden layer size
m = 2                     # embedding dimension


def make_batch(sentence):
    """Use all words but the last as input and the last word as target."""
    input_batch = []
    target_batch = []
    for sen in sentence:
        word = sen.split()
        input_ids = [word_dict[n] for n in word[:-1]]
        target = word_dict[word[-1]]
        input_batch.append(input_ids)
        target_batch.append(target)
    return input_batch, target_batch


class NNLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.C = nn.Embedding(n_class, m)                         # word vector table
        self.H = nn.Parameter(torch.randn(n_step * m, n_hidden))  # input -> hidden
        self.W = nn.Parameter(torch.randn(n_step * m, n_class))   # input -> output (direct)
        self.d = nn.Parameter(torch.randn(n_hidden))              # hidden bias
        self.U = nn.Parameter(torch.randn(n_hidden, n_class))     # hidden -> output
        self.b = nn.Parameter(torch.randn(n_class))               # output bias

    def forward(self, X):
        X = self.C(X)               # [batch, n_step, m]
        X = X.view(-1, n_step * m)  # [batch, n_step * m]
        tanh = torch.tanh(self.d + torch.mm(X, self.H))                 # [batch, n_hidden]
        output = self.b + torch.mm(X, self.W) + torch.mm(tanh, self.U)  # [batch, n_class]
        return output


model = NNLM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# torch.autograd.Variable is no longer needed; plain tensors work directly.
input_batch, target_batch = make_batch(sentence)
input_batch = torch.LongTensor(input_batch)
target_batch = torch.LongTensor(target_batch)

for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch', '%04d' % (epoch + 1), 'cost=', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()

# Predict the last word of each sentence from its first two words.
with torch.no_grad():
    predict = model(input_batch).argmax(dim=1)
print([sen.split()[:2] for sen in sentence], '->', [number_dict[n.item()] for n in predict])
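The forward pass in the code above computes output = b + xW + tanh(d + xH)U, where x is the concatenation of the context word vectors looked up in C. A small sketch of the shape bookkeeping may help when changing the hyperparameters. It uses plain Python only, and the helper `numel` and the dictionaries `shapes` and `param_counts` are illustrative names that do not appear in the script.

```python
# Hyperparameters copied from the script: vocabulary size, context length,
# hidden size and embedding dimension.
n_class, n_step, n_hidden, m = 7, 2, 2, 2

# Shape of every learnable tensor defined in the NNLM class.
shapes = {
    "C": (n_class, m),            # embedding table, one row per word
    "H": (n_step * m, n_hidden),  # concatenated context -> hidden
    "W": (n_step * m, n_class),   # concatenated context -> output (direct path)
    "d": (n_hidden,),             # hidden bias
    "U": (n_hidden, n_class),     # hidden -> output
    "b": (n_class,),              # output bias
}


def numel(shape):
    """Number of scalar entries in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


param_counts = {name: numel(shape) for name, shape in shapes.items()}
total_params = sum(param_counts.values())
print(param_counts)
print("total parameters:", total_params)
```

With these settings the direct input-to-output matrix W is the largest tensor, and only the rows of C are the word vectors themselves.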
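Experimental result
The script prints the loss every 1000 epochs, then the predicted next word for each two-word context.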
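The script only prints next-word predictions, but the word vectors themselves are the rows of the embedding table `C`. The following is a hedged sketch of how they could be read out. To stay self-contained it uses a fresh, untrained `nn.Embedding` as a stand-in for the trained `model.C`, and a sorted vocabulary, which is an assumption and not what the original script does. The names `vocab`, `word_to_id`, `table` and `word_vectors` are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

sentences = ["i like dog", "i love coffee", "i hate milk"]
vocab = sorted(set(" ".join(sentences).split()))  # assumption: sorted for stable ids
word_to_id = {w: i for i, w in enumerate(vocab)}

m = 2  # embedding dimension, as in the script
# Stand-in for the trained `model.C`; in the real script, use model.C instead.
C = nn.Embedding(len(vocab), m)

# Copy the embedding table without tracking gradients: shape [vocab size, m].
with torch.no_grad():
    table = C.weight.clone()

# One m-dimensional vector per word.
word_vectors = {w: table[i].tolist() for w, i in word_to_id.items()}
for w, vec in word_vectors.items():
    print(w, vec)

# Cosine similarity between two word vectors, as one way to compare them.
sim = F.cosine_similarity(table[word_to_id["like"]], table[word_to_id["love"]], dim=0)
print("cos(like, love) =", sim.item())
```

With only three training sentences and two dimensions, the resulting vectors are a toy illustration of the mechanism and should not be expected to capture meaningful similarity.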
