当前位置:网站首页>Task5: multi type emotion analysis
Task5: multi type emotion analysis
2022-07-03 13:14:00 【Levi Bebe】
In this tutorial, we perform classification on a dataset that has 6 classes.
You can run the code below in a Jupyter notebook.
import torch
from torchtext.legacy import data
from torchtext.legacy import datasets
import random
# Fix the seeds so that the data split and weight initialization are repeatable.
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Questions are tokenized with spaCy; LABEL holds one class label per example.
TEXT = data.Field(tokenize = 'spacy', tokenizer_language = 'en_core_web_sm')
LABEL = data.LabelField()

# fine_grained=False requests the coarse label set (the 6 classes this
# article classifies) rather than the fine-grained labels.
train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)

# random.seed() returns None, so passing its result as random_state only
# worked through the side effect of seeding the global generator. Seed first,
# then hand the resulting generator state to split() explicitly.
random.seed(SEED)
train_data, valid_data = train_data.split(random_state = random.getstate())

# Building a vocabulary
# The vocabulary is built from the training split only, capped at
# MAX_VOCAB_SIZE tokens and paired with 100-dimensional GloVe vectors; words
# missing from GloVe get normally distributed vectors via unk_init.
MAX_VOCAB_SIZE = 25_000
TEXT.build_vocab(train_data,
                 max_size = MAX_VOCAB_SIZE,
                 vectors = "glove.6B.100d",
                 unk_init = torch.Tensor.normal_)
LABEL.build_vocab(train_data)

# Build iterators
# Batches are created directly on the GPU when one is available.
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)
# Model building
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """Convolutional text classifier with several filter widths.

    Each filter width ``fs`` in ``filter_sizes`` gets its own ``Conv2d`` that
    spans ``fs`` consecutive tokens across the full embedding dimension. The
    per-filter activations are max-pooled over the sentence, concatenated,
    passed through dropout and projected to ``output_dim`` class scores.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        n_filters: number of output channels per filter width.
        filter_sizes: sequence of filter widths (in tokens).
        output_dim: number of classes.
        dropout: dropout probability applied to the pooled features.
        pad_idx: vocabulary index of the padding token.
    """

    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim,
                 dropout, pad_idx):
        super().__init__()
        # padding_idx excludes the pad row from gradient updates, so padding
        # tokens never contribute a learned signal.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels = 1,
                      out_channels = n_filters,
                      kernel_size = (fs, embedding_dim))
            for fs in filter_sizes
        ])
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)
        # Plain attributes (not parameters): used in forward() to pad batches
        # that are shorter than the widest filter.
        self.pad_idx = pad_idx
        self.min_len = max(filter_sizes, default = 0)

    def forward(self, text):
        """Return class scores of shape [batch size, output_dim].

        Args:
            text: LongTensor of token indices, shape [sent len, batch size].
        """
        #text = [sent len, batch size]
        text = text.permute(1, 0)
        #text = [batch size, sent len]
        shortfall = self.min_len - text.shape[1]
        if shortfall > 0:
            # A convolution cannot slide over a sentence shorter than its
            # kernel, so extend short batches with padding tokens.
            fill = 0 if self.pad_idx is None else self.pad_idx
            padding = text.new_full((text.shape[0], shortfall), fill)
            text = torch.cat([text, padding], dim = 1)
        embedded = self.embedding(text)
        #embedded = [batch size, sent len, emb dim]
        embedded = embedded.unsqueeze(1)
        #embedded = [batch size, 1, sent len, emb dim]
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
        #conv_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        #pooled_n = [batch size, n_filters]
        cat = self.dropout(torch.cat(pooled, dim = 1))
        #cat = [batch size, n_filters * len(filter_sizes)]
        return self.fc(cat)
# Model parameter settings
# Hyperparameters. The input and output sizes follow the vocabularies built
# earlier; EMBEDDING_DIM has to match the 100-d GloVe vectors loaded into TEXT.
INPUT_DIM = len(TEXT.vocab)
OUTPUT_DIM = len(LABEL.vocab)
EMBEDDING_DIM = 100
N_FILTERS = 100
FILTER_SIZES = [2,3,4]
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]

model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)

# Load the pre-trained embeddings into the model's embedding layer.
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

# Zero the rows of the unknown and padding tokens so they start with no signal.
for special_idx in (UNK_IDX, PAD_IDX):
    model.embedding.weight.data[special_idx] = torch.zeros(EMBEDDING_DIM)

# Optimizer and loss; the model and the criterion are moved to the chosen device.
import torch.optim as optim
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
model = model.to(device)
criterion = criterion.to(device)
# Calculation accuracy
def categorical_accuracy(preds, y):
    """Return the fraction of correct predictions in a batch.

    If 8 of 10 predictions are right this returns 0.8, not 8.

    Args:
        preds: class scores, one row per example; the predicted class is the
            argmax along dimension 1.
        y: true class indices, one per example.

    Returns:
        A scalar float tensor holding correct / batch size.
    """
    top_pred = preds.argmax(1, keepdim = True)
    # y is reshaped to match top_pred so the comparison is element-wise.
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    return correct.float() / y.shape[0]
# Training
def train(model, iterator, optimizer, criterion):
    """Run one training pass over ``iterator`` and update ``model``.

    Args:
        model: module called as ``model(batch.text)``.
        iterator: sized iterable of batches exposing ``.text`` and ``.label``.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(predictions, batch.label)``.

    Returns:
        ``(mean loss, mean accuracy)``, each averaged over the number of
        batches (every batch counts equally, whatever its size).
    """
    epoch_loss = 0
    epoch_acc = 0
    model.train()  # enable dropout
    for batch in iterator:
        optimizer.zero_grad()  # drop gradients left over from the previous batch
        predictions = model(batch.text)
        loss = criterion(predictions, batch.label)
        acc = categorical_accuracy(predictions, batch.label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
# evaluation
def evaluate(model, iterator, criterion):
    """Compute loss and accuracy over ``iterator`` without updating ``model``.

    Args:
        model: module called as ``model(batch.text)``.
        iterator: sized iterable of batches exposing ``.text`` and ``.label``.
        criterion: loss called as ``criterion(predictions, batch.label)``.

    Returns:
        ``(mean loss, mean accuracy)``, each averaged over the number of
        batches (every batch counts equally, whatever its size).
    """
    epoch_loss = 0
    epoch_acc = 0
    model.eval()  # disable dropout
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in iterator:
            predictions = model(batch.text)
            loss = criterion(predictions, batch.label)
            acc = categorical_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
# Time statistics
import time
def epoch_time(start_time, end_time):
    """Split an elapsed interval into whole minutes and whole seconds.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as ints; fractional seconds are truncated.
    """
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
# Training models
N_EPOCHS = 5
# Lowest validation loss seen so far; decides when to checkpoint.
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    # Save the weights only when the validation loss improves, so the file
    # always holds the best model seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut5-model.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
# test model
# Restore the checkpoint with the best validation loss. map_location places
# the tensors on the current device, so a checkpoint written on a GPU also
# loads on a CPU-only machine.
model.load_state_dict(torch.load('tut5-model.pt', map_location = device))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
边栏推荐
- 01 three solutions to knapsack problem (greedy dynamic programming branch gauge)
- [Database Principle and Application Tutorial (4th Edition | wechat Edition) Chen Zhibo] [Chapter 7 exercises]
- Differences and connections between final and static
- 2022-01-27 redis cluster technology research
- Fabric. JS three methods of changing pictures (including changing pictures in the group and caching)
- [comprehensive question] [Database Principle]
- Idea full text search shortcut ctr+shift+f failure problem
- Logback 日志框架
- 我的创作纪念日:五周年
- Two solutions of leetcode101 symmetric binary tree (recursion and iteration)
猜你喜欢
Finite State Machine FSM
The 35 required questions in MySQL interview are illustrated, which is too easy to understand
解决 System has not been booted with systemd as init system (PID 1). Can‘t operate.
[review questions of database principles]
Huffman coding experiment report
【数据库原理及应用教程(第4版|微课版)陈志泊】【第六章习题】
Flink SQL knows why (16): dlink, a powerful tool for developing enterprises with Flink SQL
Brief introduction to mvcc
STM32 and motor development (from MCU to architecture design)
Sword finger offer 12 Path in matrix
随机推荐
我的创作纪念日:五周年
阿南的疑惑
My creation anniversary: the fifth anniversary
【习题六】【数据库原理】
Some thoughts on business
【判断题】【简答题】【数据库原理】
regular expression
2022-01-27 redis cluster brain crack problem analysis
Flink SQL knows why (12): is it difficult to join streams? (top)
How to get user location in wechat applet?
人身变声器的原理
已解决TypeError: Argument ‘parser‘ has incorrect type (expected lxml.etree._BaseParser, got type)
【习题五】【数据库原理】
Node. Js: use of express + MySQL
Mysqlbetween implementation selects the data range between two values
Flink SQL knows why (13): is it difficult to join streams? (next)
2022-01-27 use liquibase to manage MySQL execution version
剑指 Offer 14- I. 剪绳子
Flink SQL knows why (16): dlink, a powerful tool for developing enterprises with Flink SQL
Flink SQL knows why (XV): changed the source code and realized a batch lookup join (with source code attached)