Simple dialogue system -- text classification using Transformers
2022-07-04 03:37:00 【Angry coke】
Introduction
We will show how to use models from the Transformers library to solve text classification tasks from the GLUE benchmark.
Task | Description | Type | Input | Labels | Metric |
---|---|---|---|---|---|
CoLA | Whether a sentence is grammatically acceptable | Classification | Single sentence | {0:no,1:yes} | Matthews correlation coefficient |
SST-2 | Binary sentiment of movie reviews | Classification | Single sentence | {0:negative,1:positive} | Accuracy |
MRPC | Whether a sentence pair is semantically equivalent | Classification | Sentence pair | {0:no,1:yes} | Accuracy / F1 |
STS-B | Similarity of a sentence pair, scored from 1 to 5 | Regression | Sentence pair | [1.0,5.0] | Pearson / Spearman correlation |
QQP | Whether two Quora questions are semantically equivalent | Classification | Sentence pair | {0:not similar,1:similar} | Accuracy / F1 |
MNLI | Whether the premise entails, is neutral toward, or contradicts the hypothesis (three labels) | Classification | Sentence pair | {0:entailment,1:neutral,2:contradiction} | Accuracy |
QNLI | Question/sentence pairs drawn from the SQuAD reading-comprehension data; whether the sentence answers the question | Classification | Sentence pair | {0:entailment,1:not_entailment} | Accuracy |
RTE | Whether the first sentence entails the second | Classification | Sentence pair | {0:entailment,1:not_entailment} | Accuracy |
WNLI | Whether the pronoun in the second sentence has been resolved correctly (anaphora resolution) | Classification | Sentence pair | {0:no,1:yes} | Accuracy |
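CoLA is scored with the Matthews correlation coefficient (MCC) rather than accuracy. MCC ranges from -1 to 1, with 0 corresponding to chance-level prediction. For binary labels it can be computed from the four confusion-matrix counts. The following minimal sketch is for illustration only; the training code below uses the GLUE metric loaded from the library instead:

```python
import math

def matthews_corrcoef(y_true, y_pred):
    """Binary Matthews correlation coefficient from two lists of 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention here, return 0 when any marginal count is 0 (denominator is 0)
    return (tp * tn - fp * fn) / denom if denom else 0.0

print(round(matthews_corrcoef([1, 1, 0, 0], [1, 0, 0, 0]), 4))  # 0.5774
```

A perfect prediction gives 1.0, a fully inverted one gives -1.0, and predicting a single class for every example gives 0.0.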
For each of the tasks above, we show how to load the dataset with the simple Datasets library and how to fine-tune a pretrained model with the Trainer API from Transformers.
Data loading
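from datasets import load_dataset, load_metric  # recent datasets releases removed load_metric; evaluate.load('glue', task) replaces it
task = 'cola'  # one of: cola, sst2, mrpc, stsb, qqp, mnli, mnli-mm, qnli, rte, wnli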
model_checkpoint = 'distilbert-base-uncased'
batch_size = 16
Except for mnli-mm (MNLI evaluated on its mismatched validation set), every task can be loaded directly by its task name.
actual_task = 'mnli' if task == 'mnli-mm' else task
dataset = load_dataset('glue', actual_task)  # load the dataset
metric = load_metric('glue', actual_task)  # load the evaluation metric(s) for this GLUE task
dataset
DatasetDict({
train: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 8551
})
validation: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1043
})
test: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1063
})
})
dataset['train'][0]
{
'idx': 0,
'label': 1,
'sentence': "Our friends won't buy this analysis, let alone the next one we propose."}
Data preprocessing
Before feeding the data to the model we need to preprocess it; the tool for this is called a Tokenizer. The tokenizer first splits the input into tokens, then converts the tokens into their corresponding token IDs, and finally puts them into the input format the model expects.
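The three steps (split into tokens, map tokens to IDs, add the model's special tokens) can be illustrated with a toy word-level tokenizer. This is only a sketch under simplifying assumptions: real tokenizers such as DistilBERT's use WordPiece subwords and a vocabulary of roughly 30k entries. TOY_VOCAB, toy_tokenize and toy_encode are hypothetical names; the IDs are copied from the tokenizer output shown later in this post, and 100 is assumed to be [UNK] as in the BERT uncased vocabulary.

```python
import re

# Hypothetical toy vocabulary; IDs taken from the real tokenizer output in this post
TOY_VOCAB = {"[UNK]": 100, "[CLS]": 101, "[SEP]": 102, "hello": 7592, ",": 1010,
             "this": 2023, "one": 2028, "sentence": 6251, "!": 999}

def toy_tokenize(text):
    # Step 1: lowercase (like an "uncased" model) and split into words and punctuation
    return re.findall(r"\w+|[^\w\s]", text.lower())

def toy_encode(text):
    tokens = toy_tokenize(text)
    # Step 2: map each token to its ID, falling back to [UNK] for unknown tokens
    ids = [TOY_VOCAB.get(tok, TOY_VOCAB["[UNK]"]) for tok in tokens]
    # Step 3: wrap with special tokens and build an attention mask of all ones
    input_ids = [TOY_VOCAB["[CLS]"]] + ids + [TOY_VOCAB["[SEP]"]]
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}

print(toy_encode("Hello,this one sentence!")["input_ids"])
# [101, 7592, 1010, 2023, 2028, 6251, 999, 102]
```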
To do this preprocessing, we instantiate our tokenizer with AutoTokenizer.from_pretrained. This ensures that:
- we get a tokenizer that corresponds exactly to the pretrained model, and
- when we use the tokenizer for the specified model checkpoint, we also download the vocabulary the model was trained with (its token vocabulary).
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)  # use_fast=True loads the fast (Rust-based) tokenizer if the model has one, otherwise the Python one
The tokenizer can preprocess a single text or a pair of texts:
tokenizer('Hello,this one sentence!', 'And this sentence goes with it.')
{
'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
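For a pair of texts the tokenizer builds a single sequence of the form [CLS] A [SEP] B [SEP], with 101 as [CLS] and 102 as [SEP] in this vocabulary. The short sketch below recovers the two segments from the input_ids printed above; split_pair is a hypothetical helper. DistilBERT does not use token type IDs, which is why none appear in the output; the segment_ids list only illustrates what BERT-style models would receive.

```python
SEP_ID = 102  # [SEP] in the (distil)bert uncased vocabulary

def split_pair(input_ids):
    """Split [CLS] A [SEP] B [SEP] into the IDs of segment A and segment B."""
    first_sep = input_ids.index(SEP_ID)
    seg_a = input_ids[1:first_sep]       # skip the leading [CLS]
    seg_b = input_ids[first_sep + 1:-1]  # skip the trailing [SEP]
    # BERT-style segment IDs: 0 up to and including the first [SEP], 1 afterwards
    segment_ids = [0] * (first_sep + 1) + [1] * (len(input_ids) - first_sep - 1)
    return seg_a, seg_b, segment_ids

ids = [101, 7592, 1010, 2023, 2028, 6251, 999, 102,
       1998, 2023, 6251, 3632, 2007, 2009, 1012, 102]
seg_a, seg_b, segment_ids = split_pair(ids)
print(len(seg_a), len(seg_b))  # 6 7
```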
To preprocess the data, we need to know which columns hold the input text for each task, so we define the following dict:
task_to_keys = {
'cola':('sentence', None),
'mnli':('premise', 'hypothesis'),
'mnli-mm':('premise', 'hypothesis'),
'mrpc':('sentence1', 'sentence2'),
'qnli':('question', 'sentence'),
'qqp':('question1', 'question2'),
'rte':('sentence1', 'sentence2'),
'sst2':('sentence', None),
'stsb':('sentence1', 'sentence2'),
'wnli':('sentence1', 'sentence2')
}
Put the preprocessing code into a function:
sentence1_key, sentence2_key = task_to_keys[task]
def preprocess(examples):
    # Single-sentence tasks: tokenize only the first column; sentence-pair tasks: tokenize both.
    # truncation=True cuts inputs that exceed the model's maximum length.
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)
We can test it on the first five training examples:
preprocess(dataset['train'][:5])
{
'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
Next, we preprocess every sample in the dataset by using the map method to apply the preprocess function to all samples.
encoded_dataset = dataset.map(preprocess, batched=True)
The results are cached automatically, so they do not need to be recomputed next time. To recompute instead of loading from the cache, pass the load_from_cache_file=False parameter.
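With batched=True, map passes the function a dict of column lists (a batch of examples, 1,000 by default) instead of one example at a time, which lets the fast tokenizer process many texts per call. The pure-Python sketch below only mimics that calling convention on a plain dict of columns; simulate_batched_map and fake_preprocess are hypothetical helpers, not the Datasets implementation.

```python
def simulate_batched_map(columns, fn, batch_size=1000):
    """Apply fn to slices of a dict-of-lists and merge the new columns back in."""
    n_rows = len(next(iter(columns.values())))
    outputs = {}
    for start in range(0, n_rows, batch_size):
        # Each call receives a dict of lists, like dataset[start:start + batch_size]
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            outputs.setdefault(name, []).extend(values)
    # Original columns are kept and the new ones are added alongside them
    return {**columns, **outputs}

def fake_preprocess(batch):
    # Hypothetical stand-in for the tokenizer: count the words in each sentence
    return {"n_words": [len(s.split()) for s in batch["sentence"]]}

data = {"sentence": ["a b c", "d e", "f"], "label": [1, 0, 1]}
result = simulate_batched_map(data, fake_preprocess, batch_size=2)
print(result["n_words"])  # [3, 2, 1]
```

Keeping the original columns is why the training log later reports that the idx and sentence columns are ignored: the model's forward method does not accept them.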
Fine-tuning the pretrained model
Since this is a sequence classification task, we need a model class that can handle it. We use the AutoModelForSequenceClassification class. Like the tokenizer, its from_pretrained method downloads and loads the model for us:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# STS-B is a regression task (1 output), MNLI is 3-class classification, the rest are binary
num_labels = 3 if task.startswith('mnli') else 1 if task =='stsb' else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
To instantiate a Trainer we need three more things, the most important of which is the training configuration, TrainingArguments. It holds every attribute that defines the training process.
metric_name ='pearson' if task=='stsb' else 'matthews_correlation' if task =='cola' else 'accuracy'
args = TrainingArguments(
    'test-glue',                  # output directory for checkpoints
    evaluation_strategy='epoch',  # called eval_strategy in newer transformers releases
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)
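With evaluation_strategy='epoch', the model is evaluated on the validation set at the end of every epoch. load_best_model_at_end=True requires the save strategy to match the evaluation strategy, so save_strategy is also 'epoch'. At the end of training, the checkpoint with the best metric_for_best_model value is reloaded.
Because different tasks use different metrics, we define a function that computes the metric for the current task: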
import numpy as np
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != 'stsb':
        # classification: the predicted class is the index of the largest logit
        predictions = np.argmax(predictions, axis=1)
    else:
        # regression: the model has a single output column
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)
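compute_metrics receives raw model outputs (logits) with shape (num_examples, num_labels). The conversion step can be checked on its own with NumPy. to_predictions is a hypothetical helper that mirrors the branch inside compute_metrics, without the metric call:

```python
import numpy as np

def to_predictions(logits, task):
    """Turn raw model outputs into the values the GLUE metric expects."""
    logits = np.asarray(logits)
    if task != "stsb":
        # Classification: pick the index of the highest logit for each example
        return np.argmax(logits, axis=1)
    # Regression (num_labels=1): take the single output column as the score
    return logits[:, 0]

print(to_predictions([[0.1, 0.9], [2.0, -1.0]], "cola"))  # [1 0]
print(to_predictions([[3.2], [1.5]], "stsb"))             # [3.2 1.5]
```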
Then pass everything to the Trainer:
validation_key = 'validation_mismatched' if task =='mnli-mm' else 'validation_matched' if task == 'mnli' else 'validation'
trainer = Trainer(
model,
args,
train_dataset=encoded_dataset['train'],
eval_dataset=encoded_dataset[validation_key],
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
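This post spreads its per-task choices across several places: the dataset config name, the validation split, the number of labels, and the metric used for model selection. glue_task_settings is a hypothetical helper that gathers them in one function. It only restates the conditional expressions used above and adds no new behavior:

```python
def glue_task_settings(task):
    """Collect the per-task settings used in this walkthrough."""
    actual_task = "mnli" if task == "mnli-mm" else task
    if task == "mnli-mm":
        validation_key = "validation_mismatched"
    elif task == "mnli":
        validation_key = "validation_matched"
    else:
        validation_key = "validation"
    # STS-B is regression (1 output), MNLI has 3 classes, the rest are binary
    num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2
    if task == "stsb":
        metric_name = "pearson"
    elif task == "cola":
        metric_name = "matthews_correlation"
    else:
        metric_name = "accuracy"
    return {"actual_task": actual_task, "validation_key": validation_key,
            "num_labels": num_labels, "metric_name": metric_name}

print(glue_task_settings("mnli-mm"))
```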
Start training:
trainer.train()
The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.
/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
FutureWarning,
***** Running training *****
Num examples = 8551
Num Epochs = 5
Instantaneous batch size per device = 16
Total train batch size (w. parallel, distributed & accumulation) = 16
Gradient Accumulation steps = 1
Total optimization steps = 2675
[2675/2675 03:05, Epoch 5/5]
Epoch Training Loss Validation Loss Matthews Correlation
1 0.524700 0.531839 0.397460
2 0.347900 0.517831 0.488779
3 0.235900 0.568724 0.520212
4 0.183700 0.776045 0.500206
5 0.138100 0.811927 0.521112
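Two figures in this log can be checked by hand. First, with one device and no gradient accumulation, each epoch has ceil(8551 / 16) = 535 batches, and 535 × 5 epochs = 2675 optimization steps. Second, because load_best_model_at_end=True and metric_for_best_model='matthews_correlation', the checkpoint with the highest validation MCC is kept. Here that is epoch 5, so the evaluation that follows reports the epoch-5 numbers. The sketch below repeats both calculations with the values copied from the log. It assumes a simple batches-per-epoch formula, not Trainer's exact internal code path.

```python
import math

num_examples, batch_size, num_epochs, grad_accum = 8551, 16, 5, 1
batches_per_epoch = math.ceil(num_examples / batch_size)  # last partial batch is kept
steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
total_steps = steps_per_epoch * num_epochs

# Validation Matthews correlation per epoch, copied from the training log above
mcc_by_epoch = {1: 0.397460, 2: 0.488779, 3: 0.520212, 4: 0.500206, 5: 0.521112}
best_epoch = max(mcc_by_epoch, key=mcc_by_epoch.get)

print(steps_per_epoch, total_steps, best_epoch)  # 535 2675 5
```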
After training, evaluate the model on the validation set:
trainer.evaluate()
{
'epoch': 5.0,
'eval_loss': 0.8119268417358398,
'eval_matthews_correlation': 0.5211120728046958,
'eval_runtime': 0.8952,
'eval_samples_per_second': 1165.111,
'eval_steps_per_second': 73.727}
Hyperparameter search
If you are unsure how to set the hyperparameters, you can run a hyperparameter search. The Trainer supports this, but it requires either the optuna or the Ray Tune library.
!pip install optuna ray[tune]
During a hyperparameter search the Trainer trains several models, one per trial. So instead of a model instance, we pass a model_init function that the Trainer can call to re-create a fresh model for each trial:
def model_init():
    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
The Trainer is then created as before, except that model_init is passed instead of a model:
trainer = Trainer(
model_init=model_init,
args=args,
train_dataset=encoded_dataset['train'],
eval_dataset=encoded_dataset[validation_key],
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
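Then call hyperparameter_search. The search can take a long time, so one option is to search on part of the data and then train on the full dataset with the best hyperparameters found. The call below runs 10 trials and maximizes the objective, which for CoLA defaults to the validation Matthews correlation: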
best_run=trainer.hyperparameter_search(n_trials=10,direction='maximize')
It returns the ID, objective value, and hyperparameters of the best run:
best_run
BestRun(run_id='2', objective=0.5408374954915984, hyperparameters={
'learning_rate': 1.8520571467952223e-05, 'num_train_epochs': 4, 'seed': 37, 'per_device_train_batch_size': 8})
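The hyperparameter_search call above trains on the full encoded_dataset['train']. To search on roughly 1/10 of the data instead, one option is to pass a shard of the training set as train_dataset when building the search Trainer. The Datasets library provides Dataset.shard(num_shards=..., index=...) for this, for example encoded_dataset['train'].shard(num_shards=10, index=0). The pure-Python sketch below only illustrates how a strided 1-of-10 subset of the 8,551 CoLA training rows would be chosen. strided_shard is a hypothetical helper, and whether Dataset.shard picks contiguous or strided rows depends on its contiguous argument.

```python
def strided_shard(rows, num_shards, index):
    """Keep every num_shards-th row, starting at position index."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return rows[index::num_shards]

train_rows = list(range(8551))  # stand-in for the 8,551 CoLA training examples
subset = strided_shard(train_rows, num_shards=10, index=0)
print(len(subset))  # 856
```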
To train with the best hyperparameters found by the search, set them on the Trainer's arguments and train again:
for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)
trainer.train()
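The setattr loop simply overwrites the matching fields of the existing training arguments before training again. The sketch below reproduces this pattern with the BestRun values printed above. ToyArgs is a hypothetical stand-in dataclass, not the transformers TrainingArguments class; its defaults mirror the settings used in this post, plus an assumed default seed of 42.

```python
from dataclasses import dataclass, asdict

@dataclass
class ToyArgs:
    # Hypothetical stand-in for TrainingArguments, limited to the fields the search tuned
    learning_rate: float = 2e-5
    num_train_epochs: int = 5
    seed: int = 42
    per_device_train_batch_size: int = 16

args = ToyArgs()
best_hyperparameters = {"learning_rate": 1.8520571467952223e-05, "num_train_epochs": 4,
                        "seed": 37, "per_device_train_batch_size": 8}
for name, value in best_hyperparameters.items():
    setattr(args, name, value)  # same pattern as the loop over best_run.hyperparameters

print(asdict(args))
```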