Simple dialogue system -- text classification using transformer
2022-07-04 03:37:00 【Angry coke】
Introduction
We will show how to use a model from the Transformers library to solve the text classification tasks of the GLUE benchmark.
Task | Description | Type | Single/paired input | Labels | Metric |
---|---|---|---|---|---|
CoLA | Whether the input sentence is grammatically acceptable | Classification | Single sentence | {0:no,1:yes} | Matthews correlation coefficient |
SST-2 | Binary sentiment of a movie review | Classification | Single sentence | {0:negative,1:positive} | Accuracy |
MRPC | Whether the two sentences are semantically equivalent | Classification | Sentence pair | {0:no,1:yes} | Accuracy/F1 |
STS-B | Similarity of the two sentences on a 1-to-5 scale | Regression | Sentence pair | [1.0,5.0] | Pearson/Spearman correlation |
QQP | Whether two Quora questions are semantically equivalent | Classification | Sentence pair | {0:not similar,1:similar} | Accuracy/F1 |
MNLI | Whether the premise entails the hypothesis, contradicts it, or is neutral | Classification | Sentence pair | {0:entailment,1:neutral,2:contradiction} | Accuracy |
QNLI | Question/sentence pairs extracted from SQuAD reading comprehension: whether the sentence answers the question | Classification | Sentence pair | {0:entailment,1:not_entailment} | Accuracy |
RTE | Whether the first sentence entails the second | Classification | Sentence pair | {0:entailment,1:not_entailment} | Accuracy |
WNLI | Whether the pronoun in the second sentence is resolved correctly with respect to the first | Classification | Sentence pair | {0:no,1:yes} | Accuracy |
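CoLA is scored with the Matthews correlation coefficient (MCC) rather than accuracy because MCC uses all four cells of the confusion matrix: a classifier that always predicts "acceptable" can still reach a respectable accuracy, but its MCC is 0. The sketch below is the textbook binary formula in plain Python, for illustration only; it is not the implementation behind the GLUE metric, and the name mcc_binary is our own.

```python
import math


def mcc_binary(y_true, y_pred):
    """Binary Matthews correlation coefficient computed from the 2x2 confusion matrix."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # If any row or column of the confusion matrix is empty, report 0.0
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom


labels = [1, 1, 1, 0, 0, 1]
always_acceptable = [1, 1, 1, 1, 1, 1]
accuracy = sum(t == p for t, p in zip(labels, always_acceptable)) / len(labels)
print(round(accuracy, 3), mcc_binary(labels, always_acceptable))  # 0.667 0.0
print(mcc_binary(labels, labels))  # 1.0
```

Returning 0.0 when the denominator is zero follows the same convention scikit-learn uses for degenerate predictions.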
For each of the tasks above, we will show how to load the dataset with the Datasets library and how to fine-tune a pretrained model with the Trainer interface from Transformers.
Data loading
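from datasets import load_dataset, load_metric  # load_metric is deprecated in newer datasets releases; the evaluate library's evaluate.load replaces it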
task = 'cola'
model_checkpoint = 'distilbert-base-uncased'
batch_size = 16
Every task except mnli-mm can be loaded directly by its task name; mnli-mm uses the mnli dataset and is evaluated on its mismatched validation split.
actual_task = 'mnli' if task == 'mnli-mm' else task
dataset = load_dataset('glue', actual_task)  # load the dataset
metric = load_metric('glue', actual_task)  # load the evaluation metric associated with this task
dataset
DatasetDict({
train: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 8551
})
validation: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1043
})
test: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1063
})
})
dataset['train'][0]
{
'idx': 0,
'label': 1,
'sentence': "Our friends won't buy this analysis, let alone the next one we propose."}
Data preprocessing
We need to preprocess the data before feeding it to the model. The preprocessing tool is called a Tokenizer: it first splits the input into tokens, then converts each token into its token ID, and finally converts the result into the input format the model requires.
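These three steps can be illustrated with a toy word-level tokenizer. Everything here is hypothetical: TOY_VOCAB, toy_tokenize, and toy_encode are made up for illustration, and a real BERT-style tokenizer uses a WordPiece vocabulary of roughly 30,000 subword units instead of whitespace splitting.

```python
# Hypothetical toy vocabulary (a real checkpoint ships its own vocabulary file)
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
             "hello": 4, "world": 5, "!": 6}


def toy_tokenize(text):
    """Step 1: split text into tokens (lowercase, '!' becomes its own token)."""
    return text.lower().replace("!", " !").split()


def toy_encode(text):
    # Add the special tokens a BERT-style model expects around a single sentence
    tokens = ["[CLS]"] + toy_tokenize(text) + ["[SEP]"]
    # Step 2: map each token to its ID, falling back to [UNK] for unknown words
    input_ids = [TOY_VOCAB.get(tok, TOY_VOCAB["[UNK]"]) for tok in tokens]
    # Step 3: package as model inputs; every real (non-padding) token gets mask 1
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


print(toy_encode("Hello world!"))
# {'input_ids': [2, 4, 5, 6, 3], 'attention_mask': [1, 1, 1, 1, 1]}
```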
To do this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method, which ensures that:
- we get a tokenizer that corresponds exactly to the pretrained model;
- when we use the tokenizer for the specified model checkpoint, we also download the vocabulary the model needs (the token vocabulary).
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)  # use_fast=True selects the fast (Rust-backed) tokenizer; not every model provides one
A tokenizer can preprocess a single text or a pair of texts:
tokenizer('Hello,this one sentence!', 'And this sentence goes with it.')
{
'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
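For a sentence pair, the input_ids above follow the layout [CLS] sentence A [SEP] sentence B [SEP]; in the bert/distilbert uncased vocabulary, [CLS] has ID 101 and [SEP] has ID 102. The helper split_pair below (our own name, not a library function) recovers the two segments from that output. This DistilBERT tokenizer returns no token_type_ids, so only the [SEP] token marks the boundary between the sentences.

```python
CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] in the bert/distilbert uncased vocabulary


def split_pair(input_ids):
    """Split '[CLS] A [SEP] B [SEP]' into the token IDs of A and of B."""
    if input_ids[0] != CLS_ID or input_ids[-1] != SEP_ID:
        raise ValueError("expected the sequence to start with [CLS] and end with [SEP]")
    first_sep = input_ids.index(SEP_ID)
    return input_ids[1:first_sep], input_ids[first_sep + 1:-1]


# input_ids from the tokenizer output above
ids = [101, 7592, 1010, 2023, 2028, 6251, 999, 102,
       1998, 2023, 6251, 3632, 2007, 2009, 1012, 102]
first, second = split_pair(ids)
print(first)   # [7592, 1010, 2023, 2028, 6251, 999]
print(second)  # [1998, 2023, 6251, 3632, 2007, 2009, 1012]
```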
To preprocess our data, we need to know which input columns each task uses, so we define the following dict:
task_to_keys = {
    'cola': ('sentence', None),
    'mnli': ('premise', 'hypothesis'),
    'mnli-mm': ('premise', 'hypothesis'),
    'mrpc': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'qqp': ('question1', 'question2'),
    'rte': ('sentence1', 'sentence2'),
    'sst2': ('sentence', None),
    'stsb': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2'),
}
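Because the second key is None exactly for the single-sentence tasks, the dict doubles as a record of which tasks take one sentence and which take a pair. A quick sanity check against the table in the introduction (the variable names are our own):

```python
task_to_keys = {
    'cola': ('sentence', None),
    'mnli': ('premise', 'hypothesis'),
    'mnli-mm': ('premise', 'hypothesis'),
    'mrpc': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'qqp': ('question1', 'question2'),
    'rte': ('sentence1', 'sentence2'),
    'sst2': ('sentence', None),
    'stsb': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2'),
}

# Tasks whose second key is None take a single sentence; the rest take a pair
single_sentence = [t for t, (_, second) in task_to_keys.items() if second is None]
sentence_pair = [t for t, (_, second) in task_to_keys.items() if second is not None]
print(single_sentence)     # ['cola', 'sst2']
print(len(sentence_pair))  # 8
```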
Put the preprocessing code into a function:
sentence1_key, sentence2_key = task_to_keys[task]
def preprocess(examples):
    # For sentence-pair tasks, tokenize both sentences together; otherwise tokenize only the first one.
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)
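To see what preprocess receives and forwards, here is a sketch that swaps the real tokenizer for a hypothetical stub (fake_tokenizer, which only counts words) and wraps the function in a factory (make_preprocess, also our own name) so it can be tried on two tasks. With batched input, examples is a dict mapping each column name to a list of values, and the tokenizer is called once per batch.

```python
task_to_keys = {'cola': ('sentence', None), 'mrpc': ('sentence1', 'sentence2')}


def fake_tokenizer(texts, text_pairs=None, truncation=False):
    """Hypothetical stand-in for the tokenizer: 'encodes' each text as its word count.
    truncation is accepted only to match the call signature; it is ignored here."""
    if text_pairs is None:
        return {'input_ids': [[len(t.split())] for t in texts]}
    return {'input_ids': [[len(a.split()), len(b.split())] for a, b in zip(texts, text_pairs)]}


def make_preprocess(task, tokenizer):
    sentence1_key, sentence2_key = task_to_keys[task]

    def preprocess(examples):
        # examples is a batch: {column_name: [value_1, value_2, ...]}
        if sentence2_key is None:
            return tokenizer(examples[sentence1_key], truncation=True)
        return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)

    return preprocess


cola_batch = {'sentence': ['a b c', 'd e'], 'label': [1, 0], 'idx': [0, 1]}
mrpc_batch = {'sentence1': ['a b', 'c'], 'sentence2': ['d', 'e f g'], 'label': [1, 0], 'idx': [0, 1]}
print(make_preprocess('cola', fake_tokenizer)(cola_batch))  # {'input_ids': [[3], [2]]}
print(make_preprocess('mrpc', fake_tokenizer)(mrpc_batch))  # {'input_ids': [[2, 1], [1, 3]]}
```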
We can test it on the first five training examples:
preprocess(dataset['train'][:5])
{
'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
Next, we preprocess every sample in the dataset by using the map method to apply the preprocess function defined above to all samples. With batched=True, the function receives batches of examples rather than one example at a time.
encoded_dataset = dataset.map(preprocess, batched=True)
The results are cached, so they do not need to be recomputed next time. If you do not want to load from the cache, pass the load_from_cache_file=False argument.
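Conceptually, map(..., batched=True) slices the dataset into batches (1,000 rows by default), calls the function on each dict-of-lists batch, concatenates the returned columns, and keeps the existing columns alongside the new ones. That is why idx and sentence are still present, and get ignored, when training starts in the log later on. The plain-Python batched_map below is a simplified sketch of that behavior under those assumptions, not the library's implementation (no caching, multiprocessing, or Arrow storage).

```python
def batched_map(columns, fn, batch_size=1000):
    """Simplified sketch of Dataset.map(fn, batched=True) on a dict of equal-length lists."""
    n = len(next(iter(columns.values())))
    new_columns = {}
    for start in range(0, n, batch_size):
        # Slice every column to build one batch in dict-of-lists form
        batch = {k: v[start:start + batch_size] for k, v in columns.items()}
        for k, v in fn(batch).items():
            new_columns.setdefault(k, []).extend(v)  # concatenate outputs across batches
    # Keep the original columns and add (or overwrite) the ones fn returned
    merged = dict(columns)
    merged.update(new_columns)
    return merged


data = {'sentence': ['a', 'b c', 'd e f'], 'label': [1, 0, 1]}
encoded = batched_map(
    data,
    lambda batch: {'n_words': [len(s.split()) for s in batch['sentence']]},
    batch_size=2,  # two batches: rows 0-1, then row 2
)
print(encoded)
# {'sentence': ['a', 'b c', 'd e f'], 'label': [1, 0, 1], 'n_words': [1, 2, 3]}
```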
Fine-tuning the pretrained model
This is a sequence classification task, so we need a model class that can solve it. We use AutoModelForSequenceClassification; like the tokenizer, its from_pretrained method downloads and loads the model for us:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# STS-B is a regression problem (one output); MNLI is a 3-class classification problem
num_labels = 3 if task.startswith('mnli') else 1 if task =='stsb' else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
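num_labels also decides which loss the classification head is trained with. For sequence-classification models in Transformers, num_labels=1 is treated as regression (mean squared error), while num_labels>1 with integer labels uses cross-entropy. The per-example sketch below mirrors that rule in plain Python; the function name is ours, and the library computes batched losses with PyTorch instead.

```python
import math


def sequence_classification_loss(logits, label):
    """Per-example loss for a head with len(logits) outputs.

    One output   -> regression: squared error against a float label (as for STS-B).
    More outputs -> classification: softmax cross-entropy against an integer label.
    """
    if len(logits) == 1:
        return (logits[0] - label) ** 2
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


print(sequence_classification_loss([3.5], 4.0))          # 0.25
print(sequence_classification_loss([0.0, 0.0], 1))       # ~0.693 (ln 2), two equal logits
print(sequence_classification_loss([0.0, 0.0, 0.0], 2))  # ~1.099 (ln 3), three equal logits
```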
To build a Trainer, we need three more elements. The most important one is the training configuration, TrainingArguments, which holds all the attributes that define the training process.
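# Metric used to pick the best checkpoint: Pearson for STS-B, Matthews correlation for CoLA, accuracy otherwise
metric_name = 'pearson' if task == 'stsb' else 'matthews_correlation' if task == 'cola' else 'accuracy'
args = TrainingArguments(
    'test-glue',                       # output directory for checkpoints and logs
    evaluation_strategy='epoch',       # evaluate at the end of every epoch
    save_strategy='epoch',             # save every epoch; must match evaluation_strategy when load_best_model_at_end=True
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,       # reload the best checkpoint when training finishes
    metric_for_best_model=metric_name,
)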
With evaluation_strategy='epoch', the model is evaluated on the validation set at the end of every epoch. (Recent Transformers releases rename this argument to eval_strategy.)
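With these settings, the number of optimizer steps follows from the dataset size: the 8,551 CoLA training examples in batches of 16 give 535 steps per epoch (the last batch is partial), so 5 epochs come to 2,675 steps, which is the "Total optimization steps" value in the training log further down. One evaluation and one checkpoint therefore happen every 535 steps. The arithmetic, assuming a single device and no gradient accumulation:

```python
import math

num_examples = 8551   # CoLA training rows (from the DatasetDict summary above)
batch_size = 16       # per_device_train_batch_size on a single device
num_train_epochs = 5

# A final partial batch still counts as a step, hence the ceiling
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, total_steps)  # 535 2675
```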
Because each task uses a different evaluation metric, we define a function that converts the model's raw predictions into the form the task's metric expects and then computes that metric:
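import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != 'stsb':
        # Classification: the predicted class is the index of the largest logit
        predictions = np.argmax(predictions, axis=1)
    else:
        # Regression (STS-B): the model has a single output column
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)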
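compute_metrics receives raw model outputs (logits), not class labels, which is why it converts them before calling the metric. Below is the same conversion written without NumPy so each step is visible; postprocess is a hypothetical helper name, and in the real function the result is then passed to metric.compute.

```python
def postprocess(predictions, is_regression):
    """Plain-Python version of the branch in compute_metrics.

    predictions holds one row of model outputs per example.
    """
    if is_regression:
        # STS-B: keep the single output column as the predicted score
        return [row[0] for row in predictions]
    # Classification: index of the largest logit (first index wins on ties, like np.argmax)
    return [max(range(len(row)), key=row.__getitem__) for row in predictions]


logits = [[-1.2, 0.8], [2.0, -0.5], [0.1, 0.3]]
print(postprocess(logits, is_regression=False))         # [1, 0, 1]
print(postprocess([[3.7], [1.2]], is_regression=True))  # [3.7, 1.2]
```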
Then pass everything to the Trainer:
validation_key = 'validation_mismatched' if task == 'mnli-mm' else 'validation_matched' if task == 'mnli' else 'validation'
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
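The task name drives two choices in this walkthrough: which GLUE configuration to load (actual_task) and which split to evaluate on (validation_key). MNLI has two validation sets, matched (same genres as the training data) and mismatched (different genres), which is why mnli-mm appears as a separate entry. The helper glue_splits below is a hypothetical consolidation of those two expressions:

```python
def glue_splits(task):
    """Return (GLUE config to load, validation split to evaluate on) for a task name."""
    dataset_config = 'mnli' if task == 'mnli-mm' else task
    if task == 'mnli':
        validation_split = 'validation_matched'
    elif task == 'mnli-mm':
        validation_split = 'validation_mismatched'
    else:
        validation_split = 'validation'
    return dataset_config, validation_split


for t in ('cola', 'mnli', 'mnli-mm'):
    print(t, glue_splits(t))
# cola ('cola', 'validation')
# mnli ('mnli', 'validation_matched')
# mnli-mm ('mnli', 'validation_mismatched')
```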
Start training:
trainer.train()
The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.
/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
FutureWarning,
***** Running training *****
Num examples = 8551
Num Epochs = 5
Instantaneous batch size per device = 16
Total train batch size (w. parallel, distributed & accumulation) = 16
Gradient Accumulation steps = 1
Total optimization steps = 2675
[2675/2675 03:05, Epoch 5/5]
Epoch Training Loss Validation Loss Matthews Correlation
1 0.524700 0.531839 0.397460
2 0.347900 0.517831 0.488779
3 0.235900 0.568724 0.520212
4 0.183700 0.776045 0.500206
5 0.138100 0.811927 0.521112
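Because load_best_model_at_end=True and metric_for_best_model='matthews_correlation', the Trainer keeps the checkpoint with the highest validation MCC, not the one with the lowest validation loss. In this run the two disagree: validation loss is lowest at epoch 2 and rises afterwards, while MCC peaks at epoch 5, consistent with the eval_matthews_correlation of about 0.521 that trainer.evaluate() reports next. A quick check on the logged values:

```python
# (epoch, validation loss, Matthews correlation) copied from the training log above
history = [
    (1, 0.531839, 0.397460),
    (2, 0.517831, 0.488779),
    (3, 0.568724, 0.520212),
    (4, 0.776045, 0.500206),
    (5, 0.811927, 0.521112),
]

best_by_metric = max(history, key=lambda row: row[2])[0]  # higher MCC is better
best_by_loss = min(history, key=lambda row: row[1])[0]    # lower loss is better
print(best_by_metric, best_by_loss)  # 5 2
```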
After training, we can evaluate the model:
trainer.evaluate()
{
'epoch': 5.0,
'eval_loss': 0.8119268417358398,
'eval_matthews_correlation': 0.5211120728046958,
'eval_runtime': 0.8952,
'eval_samples_per_second': 1165.111,
'eval_steps_per_second': 73.727}
Hyperparameter search
If you are not sure how to set the hyperparameters, you can run a hyperparameter search. Trainer supports hyperparameter search, but it requires the optuna or Ray Tune library:
!pip install optuna ray[tune]
During a hyperparameter search, the Trainer trains several models, one per trial. Instead of a model instance, we therefore pass a function that builds the model, so the Trainer can reinitialize it at the start of every trial:
def model_init():
    # Build a fresh pretrained model for each search trial
    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
The Trainer is created much as before, except that model_init replaces the model:
trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
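Then call hyperparameter_search. This can take a long time, so one option is to search on part of the data and then run full training with the best hyperparameters found. For example, to search on 1/10 of the training data, pass encoded_dataset['train'].shard(index=1, num_shards=10) as train_dataset. The call below runs 10 trials and maximizes the evaluation objective: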
best_run = trainer.hyperparameter_search(n_trials=10, direction='maximize')
It returns the best run, including its hyperparameters:
best_run
BestRun(run_id='2', objective=0.5408374954915984, hyperparameters={
'learning_rate': 1.8520571467952223e-05, 'num_train_epochs': 4, 'seed': 37, 'per_device_train_batch_size': 8})
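hyperparameter_search delegates the actual search to the chosen backend (optuna's default sampler is TPE, not plain random search). To show the shape of the loop, including why model_init is needed, here is a random-search sketch. The search space, the function names, and the fake objective are all assumptions for illustration; the space merely covers the four hyperparameters that appear in best_run above.

```python
import math
import random


def sample_hyperparameters(rng):
    # Assumed search space over the four hyperparameters reported in best_run
    return {
        'learning_rate': 10 ** rng.uniform(-6, -4),  # log-uniform between 1e-6 and 1e-4
        'num_train_epochs': rng.randint(1, 5),
        'seed': rng.randint(1, 40),
        'per_device_train_batch_size': rng.choice([4, 8, 16, 32, 64]),
    }


def random_search(model_init, train_and_evaluate, n_trials=10, direction='maximize', seed=0):
    """Plain random search: one fresh model per trial, keep the best objective."""
    rng = random.Random(seed)
    is_better = (lambda a, b: a > b) if direction == 'maximize' else (lambda a, b: a < b)
    best = None  # (trial_index, objective, hyperparameters)
    for trial in range(n_trials):
        hyperparameters = sample_hyperparameters(rng)
        model = model_init()  # every trial starts from freshly initialized weights
        objective = train_and_evaluate(model, hyperparameters)
        if best is None or is_better(objective, best[1]):
            best = (trial, objective, hyperparameters)
    return best


def fake_train_and_evaluate(model, hyperparameters):
    # Hypothetical objective standing in for training + evaluation:
    # it peaks when the learning rate is 2e-5 and ignores everything else
    return -abs(math.log10(hyperparameters['learning_rate']) - math.log10(2e-5))


best_trial, best_objective, best_hparams = random_search(lambda: {}, fake_train_and_evaluate)
print(best_trial, round(best_objective, 3), best_hparams)
```

Passing a factory rather than a model object is what guarantees that one trial's trained weights never leak into the next.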
To apply the best hyperparameters found by the search to the Trainer and then train with them:
for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)
trainer.train()