Dialogue Systems Demystified: Text Classification with Transformers
2022-07-04 03:32:00 【愤怒的可乐】
Introduction
In this post we'll show how to use models from the Transformers library to solve text classification tasks taken from the GLUE Benchmark.
Task | Description | Type | Single/paired input | Labels | Metric
---|---|---|---|---|---
CoLA | Is the input sentence grammatically acceptable? | Classification | Single | {0: no, 1: yes} | Matthews correlation
SST-2 | Binary sentiment of movie reviews | Classification | Single | {0: negative, 1: positive} | Accuracy
MRPC | Are the two sentences semantically equivalent? | Classification | Paired | {0: no, 1: yes} | Accuracy / F1
STS-B | Score the similarity of the sentence pair on a scale from 1 to 5 | Regression | Paired | [1.0, 5.0] | Pearson / Spearman correlation
QQP | Are the two Quora questions semantically equivalent? | Classification | Paired | {0: not similar, 1: similar} | Accuracy / F1
MNLI | Does the hypothesis (second sentence) follow from the premise: entailment, neutral, or contradiction? | Classification | Paired | {0: entailment, 1: neutral, 2: contradiction} | Accuracy
QNLI | Question/sentence pairs extracted from SQuAD: does the sentence answer the question? | Classification | Paired | {0: entailment, 1: not entailment} | Accuracy
RTE | Does the first sentence entail the second? | Classification | Paired | {0: entailment, 1: not entailment} | Accuracy
WNLI | Does the second sentence resolve the pronoun in the first sentence correctly? | Classification | Paired | {0: no, 1: yes} | Accuracy
For these tasks, we'll show how to load the data with the lightweight Datasets library and how to fine-tune a pretrained model on it with the `Trainer` API from Transformers.
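If Datasets and Transformers are not installed yet, a typical setup looks like this (a sketch; pin versions as needed):

```
!pip install datasets transformers
```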
Loading the data
```python
from datasets import load_dataset, load_metric

task = 'cola'
model_checkpoint = 'distilbert-base-uncased'
batch_size = 16
```
Except for mnli-mm, every task can be loaded directly by its name:
```python
actual_task = 'mnli' if task == 'mnli-mm' else task
dataset = load_dataset('glue', actual_task)  # load the dataset
metric = load_metric('glue', actual_task)    # load the metric that goes with it

dataset
```
```
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})
```
```python
dataset['train'][0]
```

```
{'idx': 0,
 'label': 1,
 'sentence': "Our friends won't buy this analysis, let alone the next one we propose."}
```
Data preprocessing
The data has to be preprocessed before it can go into the model, and the tool for that is the `Tokenizer`. It first splits the input into tokens, then maps those tokens to the token IDs the pretrained model expects, and finally packs everything into the input format the model requires.

We instantiate our tokenizer with `AutoTokenizer.from_pretrained`, which guarantees that:

- we get a tokenizer that corresponds one-to-one to the pretrained model we want to use;
- when we fetch the tokenizer for the given model checkpoint, we also download the vocabulary the model needs (more precisely, its token vocabulary).
```python
from transformers import AutoTokenizer

# use_fast=True loads the fast (Rust-backed) tokenizer; some models don't provide one
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
```
The tokenizer can preprocess a single sentence as well as a sentence pair:
```python
tokenizer('Hello,this one sentence!', 'And this sentence goes with it.')
```

```
{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
```
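To see where the special tokens went, we can map the IDs back to tokens. A quick check (the exact tokens depend on the checkpoint, here distilbert-base-uncased):

```python
encoded = tokenizer('Hello,this one sentence!', 'And this sentence goes with it.')
# The pair is joined as [CLS] sentence1 [SEP] sentence2 [SEP]
print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))
# ['[CLS]', 'hello', ',', 'this', 'one', 'sentence', '!', '[SEP]',
#  'and', 'this', 'sentence', 'goes', 'with', 'it', '.', '[SEP]']
```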
To preprocess the data we need to know which columns hold the text for each task, so we define the following dict:
```python
task_to_keys = {
    'cola': ('sentence', None),
    'mnli': ('premise', 'hypothesis'),
    'mnli-mm': ('premise', 'hypothesis'),
    'mrpc': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'qqp': ('question1', 'question2'),
    'rte': ('sentence1', 'sentence2'),
    'sst2': ('sentence', None),
    'stsb': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2'),
}
```
Let's put the preprocessing code into a function:
```python
sentence1_key, sentence2_key = task_to_keys[task]

def preprocess(examples):
    # If the input is a sentence pair, tokenize both; otherwise only the first sentence.
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)
```
We can test it on the first few training samples:
```python
preprocess(dataset['train'][:5])
```

```
{'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
```
Next we preprocess every sample in the dataset by calling `map`, which applies our `preprocess` function to all of them:
```python
encoded_dataset = dataset.map(preprocess, batched=True)
```
The result is cached, so subsequent calls don't recompute it. If you don't want to load from the cache, pass the `load_from_cache_file=False` argument.
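For example, to force a recompute (a sketch; `load_from_cache_file` is a standard parameter of the Datasets `map` method):

```python
# Ignore any cached result and run the preprocessing again
encoded_dataset = dataset.map(preprocess, batched=True, load_from_cache_file=False)
```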
Fine-tuning the pretrained model
Since this is a sequence classification task, we need a model class that can solve it: `AutoModelForSequenceClassification`. As with the tokenizer, its `from_pretrained` method downloads and loads the model for us:
```python
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# STS-B is a regression task, MNLI a 3-way classification task
num_labels = 3 if task.startswith('mnli') else 1 if task == 'stsb' else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
```

To build a `Trainer` we need three more ingredients, the most important of which is the training configuration `TrainingArguments`. It holds all the attributes that define the training run.
```python
metric_name = 'pearson' if task == 'stsb' else 'matthews_correlation' if task == 'cola' else 'accuracy'

args = TrainingArguments(
    'test-glue',
    evaluation_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)
```
`evaluation_strategy='epoch'` means an evaluation run is performed at the end of every epoch.
Different tasks require different metrics, so we define a function that computes the right one based on the task name:
```python
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != 'stsb':
        predictions = np.argmax(predictions, axis=1)  # classification: pick the highest logit
    else:
        predictions = predictions[:, 0]               # regression: take the raw score
    return metric.compute(predictions=predictions, references=labels)
```
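As a quick sanity check we can feed the metric random predictions; for a binary task like CoLA the score should then hover around 0 (illustrative only, the exact number will vary):

```python
# np and metric come from the cells above
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
# Matthews correlation of random guesses is close to 0
metric.compute(predictions=fake_preds, references=fake_labels)
```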
Then pass everything to the `Trainer`:
```python
validation_key = 'validation_mismatched' if task == 'mnli-mm' else 'validation_matched' if task == 'mnli' else 'validation'

trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
```
Start training:
```python
trainer.train()
```
```
The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.
/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
***** Running training *****
  Num examples = 8551
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 2675
[2675/2675 03:05, Epoch 5/5]
```

Epoch | Training Loss | Validation Loss | Matthews Correlation
---|---|---|---
1 | 0.524700 | 0.531839 | 0.397460
2 | 0.347900 | 0.517831 | 0.488779
3 | 0.235900 | 0.568724 | 0.520212
4 | 0.183700 | 0.776045 | 0.500206
5 | 0.138100 | 0.811927 | 0.521112
Once training finishes, we can evaluate:
```python
trainer.evaluate()
```

```
{'epoch': 5.0,
 'eval_loss': 0.8119268417358398,
 'eval_matthews_correlation': 0.5211120728046958,
 'eval_runtime': 0.8952,
 'eval_samples_per_second': 1165.111,
 'eval_steps_per_second': 73.727}
```
Hyperparameter search
If we're unsure how to set the hyperparameters, we can run a hyperparameter search. `Trainer` supports this, but it requires the optuna or Ray Tune libraries:
```
!pip install optuna ray[tune]
```
During the search, `Trainer` trains and returns several models, so instead of a single model instance we pass in a model-building function that lets it reinitialize the model for every trial:
```python
def model_init():
    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
```
The `Trainer` is constructed much like before:
```python
trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
```
Now call the `hyperparameter_search` method. A full search takes a long time, so it's common to search on a subset of the data first and then do the full training with the best parameters, for example searching on 1/10 of the data as sketched below.
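The subsampling step isn't shown in the original code; a minimal sketch, using the standard `shard` method of a `datasets.Dataset`:

```python
# Keep 1/10 of the training set for the search; `index` selects which shard to keep.
trainer.train_dataset = encoded_dataset['train'].shard(index=1, num_shards=10)
```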
```python
best_run = trainer.hyperparameter_search(n_trials=10, direction='maximize')
```
It returns the hyperparameters of the best run it found:
```python
best_run
```

```
BestRun(run_id='2', objective=0.5408374954915984, hyperparameters={'learning_rate': 1.8520571467952223e-05, 'num_train_epochs': 4, 'seed': 37, 'per_device_train_batch_size': 8})
```
To configure the `Trainer` with the best parameters found and train with them:
```python
for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)

trainer.train()
```
References

- Greedy AI (贪心学院) course