[NLP] TextCNN Text Classification
2022-06-30 23:11:00 [myaijarvis]
TextCNN
Must see: [Reference: Paper notes: Convolutional Neural Networks for Sentence Classification, sentence classification with CNN (Xiaoqian loves to write code, cnblogs)]
[Reference: Applications of convolution in NLP, using TextCNN as an example (Bilibili)]
[Reference: Paper reading: Convolutional Neural Networks for Sentence Classification (Nanyoufu, CSDN blog)]
[Reference: TextCNN Tianchi tutorial (Bilibili)] Very good (explains the PyTorch code)
Matching code: [Reference: Datawhale zero-basics NLP competition, Task 5: text classification based on deep learning, 2-2 TextCNN (Tianchi Lab, a real-time online data-analysis collaboration tool with free compute)]
The code can also be found here: [Reference: Tianchi Lab, a real-time online data-analysis collaboration tool with free compute]
Paper 1
Reference paper 1: "Convolutional Neural Networks for Sentence Classification" (2014).
In the paper's figure, the red outline marks a window of size 2: two words are taken at a time for feature extraction. The yellow outline marks a window of size 3: three words at a time. In other words, the "window" means "how many words are used in one step", which on the diagram shows up as "how many rows the filter covers at a time", as in the sketch below.
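To make the window idea concrete, here is a minimal illustrative sketch (all sizes are made up, not from the paper): a kernel of height 2 covers two word vectors at a time and slides down the sentence, while a kernel of height 3 covers three.

import torch
import torch.nn as nn

# Toy input: one sentence of 5 words, each embedded as a 4-dim vector.
x = torch.randn(1, 1, 5, 4)  # [batch, channel, sequence_length, embedding_size]

# Window size 2: the kernel spans 2 words (and the full embedding width),
# so it slides over 5 - 2 + 1 = 4 positions.
conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 4))
print(conv2(x).shape)  # torch.Size([1, 1, 4, 1])

# Window size 3: 5 - 3 + 1 = 3 positions.
conv3 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4))
print(conv3(x).shape)  # torch.Size([1, 1, 3, 1])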
Paper 2
Reference paper: "A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification" (2016).
This is the TextCNN reference and practitioner's guide (the version commonly used today).

In its figure there are 2 convolution kernels with window size 2 (yellow), 2 with window size 3 (green), and 2 with window size 4 (red); the sketch below traces the same pipeline.
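A hedged sketch of that figure's architecture (the sentence length 7 and embedding size 5 are just illustrative): each branch convolves with its own window size, 1-max-pools to a single value per filter, and the six pooled values are concatenated into the feature vector fed to the classifier.

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, emb = 7, 5                        # illustrative sizes
x = torch.randn(1, 1, seq_len, emb)        # [batch, channel(=1), seq_len, emb]

branches = nn.ModuleList([
    nn.Conv2d(1, 2, kernel_size=(k, emb))  # 2 filters per window size
    for k in (2, 3, 4)
])

pooled = []
for conv in branches:
    h = F.relu(conv(x))                    # [1, 2, seq_len - k + 1, 1]
    p = F.max_pool2d(h, (h.size(2), 1))    # [1, 2, 1, 1]  (1-max pooling)
    pooled.append(p.view(1, -1))           # [1, 2]

features = torch.cat(pooled, dim=1)        # [1, 6] -> goes to the softmax layer
print(features.shape)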
PyTorch Implementation (Simplified Version)
[Reference: A PyTorch implementation of TextCNN (Bilibili)]
Companion article (very detailed): https://wmathor.com/index.php/archives/1445/

import torch
import numpy as np
import torch.optim as optim
import torch.utils.data as Data
import torch.nn.functional as F

dtype = torch.FloatTensor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3-word sentences (sequence_length = 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

# TextCNN parameters
embedding_size = 2  # each word is represented by a 2-dimensional vector
sequence_length = len(sentences[0].split())  # 3; all sentences are assumed to have the same length (three words)
num_classes = len(set(labels))  # 2
batch_size = 3

word_list = " ".join(sentences).split()  # all words in the sentences, with duplicates
vocab = list(set(word_list))  # the vocabulary: every distinct word
word2idx = {w: i for i, w in enumerate(vocab)}  # word -> index
vocab_size = len(vocab)

def make_data(sentences, labels):
    inputs = []
    for sentence in sentences:
        inputs.append([word2idx[n] for n in sentence.split()])  # turn each sentence into its index sequence
    targets = []
    for out in labels:
        targets.append(out)
    return inputs, targets

input_batch, target_batch = make_data(sentences, labels)
input_batch, target_batch = torch.LongTensor(input_batch), torch.LongTensor(target_batch)
dataset = Data.TensorDataset(input_batch, target_batch)
loader = Data.DataLoader(dataset, batch_size, True)
input_batch
tensor([[ 4, 11, 14],
[ 5, 9, 10],
[15, 0, 12],
[ 4, 6, 14],
[ 1, 2, 13],
[ 7, 3, 8]])
target_batch
tensor([1, 1, 1, 0, 0, 0])
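As a side note (my addition, not in the original post): with batch_size = 3 and shuffle=True, the DataLoader yields two shuffled mini-batches per epoch, which can be checked directly:

for batch_x, batch_y in loader:
    print(batch_x.shape, batch_y.shape)  # torch.Size([3, 3]) torch.Size([3])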
from torch import nn

class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.W = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        out_channels = 3
        self.conv = nn.Sequential(
            # conv: [input_channel(=1), output_channel, kernel_size=(filter_height, filter_width=embedding_size), stride=1]
            # out_channels=3, so 3 convolution kernels are convolved with the input
            # filter_height=2 here, so there is only one window size
            nn.Conv2d(in_channels=1, out_channels=out_channels, kernel_size=(2, embedding_size)),
            # output: [batch_size, out_channels=3, 2, 1]; each of the 3 kernels produces a 2x1 feature map per sentence
            nn.ReLU(),
            # pool: (filter_height, filter_width)
            nn.MaxPool2d(kernel_size=(2, 1)),  # max-pool each 2x1 feature map down to 1x1
        )
        # fully connected layer: two-class output
        self.fc = nn.Linear(in_features=out_channels, out_features=num_classes)

    def forward(self, x):
        '''x: [batch_size, sequence_length]'''
        batch_size = x.shape[0]  # number of sentences
        # Each index, e.g. 4 in [[4, 11, 14], ...], is replaced by its word vector, e.g. [1, 2],
        # turning the matrix of indices into a cube of embeddings.
        embedding_x = self.W(x)  # [batch_size, sequence_length, embedding_size]
        # Insert a channel dimension of size 1 (single channel, like a grayscale image):
        # a traditional CNN expects input of shape [batch_size, in_channel, height, width].
        embedding_x = embedding_x.unsqueeze(1)  # [batch, channel(=1), sequence_length, embedding_size]
        conved = self.conv(embedding_x)  # [batch_size, output_channel, 1, 1]
        flatten = conved.view(batch_size, -1)  # [batch_size, output_channel*1*1]
        output = self.fc(flatten)
        return output
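Before training, it can help to sanity-check the shapes with one untrained forward pass. This quick snippet is purely illustrative (my addition, reusing vocab_size, batch_size, and sequence_length from above):

# Sanity check: one forward pass through an untrained copy of the model.
with torch.no_grad():
    demo = TextCNN()
    dummy = torch.randint(0, vocab_size, (batch_size, sequence_length))
    print(demo(dummy).shape)  # torch.Size([3, 2]) -> [batch_size, num_classes]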
model = TextCNN().to(device=device)
criterion = nn.CrossEntropyLoss().to(device=device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5000):
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device=device)
        batch_y = batch_y.to(device=device)
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Epoch: 1000 loss = 0.030200
Epoch: 1000 loss = 0.054546
Epoch: 2000 loss = 0.014919
Epoch: 2000 loss = 0.007824
Epoch: 3000 loss = 0.002666
Epoch: 3000 loss = 0.005158
Epoch: 4000 loss = 0.001931
Epoch: 4000 loss = 0.000988
Epoch: 5000 loss = 0.000379
Epoch: 5000 loss = 0.000743
# Test
test_text = 'i hate me'
tests = [[word2idx[n] for n in test_text.split()]]
test_batch = torch.LongTensor(tests).to(device)

# Predict
model = model.eval()
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text, "is Bad Mean...")
else:
    print(test_text, "is Good Mean!!")
i hate me is Bad Mean...
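A small variant of the test cell above (my addition, not from the original): wrapping inference in torch.no_grad() avoids tracking gradients, and argmax reads a bit more directly than .data.max():

model.eval()
with torch.no_grad():
    logits = model(test_batch)
pred_class = logits.argmax(dim=1).item()
print(test_text, "is Bad Mean..." if pred_class == 0 else "is Good Mean!!")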
PyTorch Implementation 2
[Reference: nlp-tutorial/TextCNN.py at master · graykode/nlp-tutorial]
# %%
# code by Tae Hwan Jung @graykode
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        self.W = nn.Embedding(vocab_size, embedding_size)
        self.Weight = nn.Linear(self.num_filters_total, num_classes, bias=False)
        self.Bias = nn.Parameter(torch.ones([num_classes]))
        self.filter_list = nn.ModuleList([
            nn.Conv2d(1, num_filters, kernel_size=(size, embedding_size))
            for size in filter_sizes])

    def forward(self, X):
        embedded_chars = self.W(X)  # [batch_size, sequence_length, embedding_size]
        # add a channel dimension of size 1
        embedded_chars = embedded_chars.unsqueeze(1)  # [batch, channel(=1), sequence_length, embedding_size]
        pooled_outputs = []
        for i, conv in enumerate(self.filter_list):
            # conv: [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
            h = F.relu(conv(embedded_chars))  # h: [batch_size(=6), output_channel(=3), output_height(=2), output_width(=1)]
            # mp: (filter_height, filter_width)
            mp = nn.MaxPool2d((sequence_length - filter_sizes[i] + 1, 1))  # (2, 1)
            # mp(h): [batch_size(=6), output_channel(=3), output_height(=1), output_width(=1)]
            # pooled: [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)
        h_pool = torch.cat(pooled_outputs, len(filter_sizes))  # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size(=6), output_height * output_width * (output_channel * 3)]
        model = self.Weight(h_pool_flat) + self.Bias  # [batch_size, num_classes]
        return model
if __name__ == '__main__':
    embedding_size = 2  # embedding size
    sequence_length = 3  # sequence length
    num_classes = 2  # number of classes
    # the paper uses 2, 3, 4
    filter_sizes = [2, 2, 2]  # n-gram windows; each kernel is [filter_size, embedding_size]
    num_filters = 3  # number of filters; 3 kernels turn the input into 3-channel data

    # 3-word sentences (sequence_length = 3)
    sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
    labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

    word_list = " ".join(sentences).split()
    word_list = list(set(word_list))
    word_dict = {w: i for i, w in enumerate(word_list)}
    vocab_size = len(word_dict)

    model = TextCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    inputs = torch.LongTensor([np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences])
    targets = torch.LongTensor([out for out in labels])  # to use torch's softmax loss

    # Training
    for epoch in range(5000):
        optimizer.zero_grad()
        output = model(inputs)
        # output: [batch_size, num_classes], target_batch: [batch_size] (LongTensor, not one-hot)
        loss = criterion(output, targets)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
        loss.backward()
        optimizer.step()

    # Test
    test_text = 'sorry hate you'
    tests = [np.asarray([word_dict[n] for n in test_text.split()])]
    test_batch = torch.LongTensor(tests)

    # Predict
    predict = model(test_batch).data.max(1, keepdim=True)[1]
    if predict[0][0] == 0:
        print(test_text, "is Bad Mean...")
    else:
        print(test_text, "is Good Mean!!")
PyTorch Implementation 3 (TBD)
[Reference: Handwriting AI: TextCNN text classification, reproduced line by line, with free Q&A from the uploader (Bilibili)]
[Reference: A-series-of-NLP / Text classification / TextCNN text classification, at main · shouxieai/A-series-of-NLP]