Task6: Sentiment Analysis with Transformers
2022-07-03 13:14:00 【Levi Bebe】
# Data preparation
import torch
import random
import numpy as np
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# Load the pre-trained BERT tokenizer
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
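# (Optional sanity check, not part of the original flow) see how the tokenizer
# splits a sentence into WordPiece tokens and maps them to ids; the sample
# sentence is arbitrary and only for illustration
tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?')
print(tokens)
print(tokenizer.convert_tokens_to_ids(tokens))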
# Special token indices and the maximum input length of the pre-trained model
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']
# Tokenize and truncate, leaving room for the [CLS] and [SEP] tokens
def tokenize_and_cut(sentence):
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_input_length-2]
    return tokens
from torchtext.legacy import data
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = tokenize_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)
LABEL = data.LabelField(dtype = torch.float)
# Load data
from torchtext.legacy import datasets
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
train_data, valid_data = train_data.split(random_state = random.seed(SEED))
# Build the label vocabulary (TEXT reuses the tokenizer's vocab, so only LABEL needs this)
LABEL.build_vocab(train_data)
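# (Optional sanity check) confirm the split sizes and the label-to-index mapping
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')
print(LABEL.vocab.stoi)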
# Set up the batch size, device, and iterators
BATCH_SIZE = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)
# Build the model
from transformers import BertModel
bert = BertModel.from_pretrained('bert-base-uncased')
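# (Optional) the wrapper below reads its embedding dimension from the model
# config; for bert-base-uncased this is 768
print(bert.config.hidden_size)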
import torch.nn as nn
class BERTGRUSentiment(nn.Module):
    def __init__(self, bert, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.bert = bert
        embedding_dim = bert.config.to_dict()['hidden_size']
        self.rnn = nn.GRU(embedding_dim,
                          hidden_dim,
                          num_layers = n_layers,
                          bidirectional = bidirectional,
                          batch_first = True,
                          dropout = 0 if n_layers < 2 else dropout)
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        # text = [batch size, sent len]
        with torch.no_grad():
            embedded = self.bert(text)[0]
        # embedded = [batch size, sent len, emb dim]
        _, hidden = self.rnn(embedded)
        # hidden = [n layers * n directions, batch size, hid dim]
        if self.rnn.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
        # hidden = [batch size, hid dim]
        output = self.out(hidden)
        # output = [batch size, out dim]
        return output
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25
model = BERTGRUSentiment(bert,
                         HIDDEN_DIM,
                         OUTPUT_DIM,
                         N_LAYERS,
                         BIDIRECTIONAL,
                         DROPOUT)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
# Freeze all BERT parameters so only the GRU and the output layer are trained
for name, param in model.named_parameters():
    if name.startswith('bert'):
        param.requires_grad = False
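# (Optional check) list which parameters remain trainable - only the GRU and
# the output layer should appear
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)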
print(f'The model has {count_parameters(model):,} trainable parameters')
# Training models
import torch.optim as optim
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()
model = model.to(device)
criterion = criterion.to(device)
def binary_accuracy(preds, y):
    """Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8"""
    # round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() # convert into float for division
    acc = correct.sum() / len(correct)
    return acc
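# A minimal sketch of binary_accuracy on made-up logits: sigmoid(2.0) rounds
# to 1 (matches the label), sigmoid(-1.0) rounds to 0 (does not), so this
# yields 0.5
print(binary_accuracy(torch.tensor([2.0, -1.0]), torch.tensor([1.0, 1.0])))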
def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    for batch in iterator:
        optimizer.zero_grad()
        predictions = model(batch.text).squeeze(1)
        loss = criterion(predictions, batch.label)
        acc = binary_accuracy(predictions, batch.label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            predictions = model(batch.text).squeeze(1)
            loss = criterion(predictions, batch.label)
            acc = binary_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
# Training
N_EPOCHS = 5
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
# Load the parameters that achieved the best validation loss and evaluate on the test set
model.load_state_dict(torch.load('tut6-model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
def predict_sentiment(model, tokenizer, sentence):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_input_length-2]
    indexed = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokens) + [eos_token_idx]
    tensor = torch.LongTensor(indexed).to(device)
    tensor = tensor.unsqueeze(0)
    prediction = torch.sigmoid(model(tensor))
    return prediction.item()
predict_sentiment(model, tokenizer, "This film is terrible")
predict_sentiment(model, tokenizer, "This film is great")