Dialogue Systems Made Simple: Text Classification with Transformers
2022-07-04 03:32:00 【愤怒的可乐】
Introduction
We will show how to use models from the Transformers library to solve text classification tasks from the GLUE Benchmark:
| Task | Description | Type | Single/pair input | Labels | Metric |
|---|---|---|---|---|---|
| CoLA | Whether the input sentence is grammatically acceptable | Classification | Single | {0: no, 1: yes} | Matthews correlation |
| SST-2 | Binary sentiment classification of movie reviews | Classification | Single | {0: negative, 1: positive} | Accuracy |
| MRPC | Whether the sentence pair is semantically equivalent | Classification | Pair | {0: no, 1: yes} | Accuracy / F1 |
| STS-B | Similarity of the sentence pair, scored on a 0 to 5 scale | Regression | Pair | [0.0, 5.0] | Pearson / Spearman correlation |
| QQP | Whether two Quora questions are semantically equivalent | Classification | Pair | {0: not similar, 1: similar} | Accuracy / F1 |
| MNLI | Whether the premise entails, is neutral to, or contradicts the hypothesis | Classification | Pair | {0: entailment, 1: neutral, 2: contradiction} | Accuracy |
| QNLI | Question/sentence pairs drawn from SQuAD; whether the sentence contains the answer to the question | Classification | Pair | {0: entailment, 1: not entailment} | Accuracy |
| RTE | Whether the first sentence of the pair entails the second | Classification | Pair | {0: entailment, 1: not entailment} | Accuracy |
| WNLI | Whether the second sentence correctly resolves the pronoun in the first sentence | Classification | Pair | {0: no, 1: yes} | Accuracy |
For each of these tasks, we will show how to load the data with the Datasets library and how to fine-tune a pretrained model with the Trainer interface from Transformers.
Loading the data
from datasets import load_dataset, load_metric
task = 'cola'
model_checkpoint = 'distilbert-base-uncased'
batch_size = 16
Except for mnli-mm, every task can be loaded directly by its name (mnli-mm reuses the mnli data):
actual_task = 'mnli' if task == 'mnli-mm' else task
dataset = load_dataset('glue', actual_task)  # load the dataset
metric = load_metric('glue', actual_task)  # load the evaluation metric for this task
dataset
DatasetDict({
train: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 8551
})
validation: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1043
})
test: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1063
})
})
dataset['train'][0]
{
'idx': 0,
'label': 1,
'sentence': "Our friends won't buy this analysis, let alone the next one we propose."}
Data preprocessing
We need to preprocess the data with a tool called a tokenizer. It first tokenizes the input into tokens, then maps those tokens to the token IDs used by the pretrained model, and finally builds the input format the model expects.
To do this we instantiate our tokenizer with AutoTokenizer.from_pretrained, which guarantees that:
- we get a tokenizer that corresponds one-to-one to the pretrained model;
- when we use the tokenizer matching the specified model checkpoint, the vocabulary the model needs (more precisely, its token vocabulary) is downloaded as well.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)  # use_fast selects the fast tokenizer implementation; some models may not provide one
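The three steps described above can also be run one at a time. The snippet below is only an illustrative sketch (the example sentence is arbitrary); in practice, calling the tokenizer directly does all of this in one go:

```python
# Step 1: split the raw text into sub-word tokens
tokens = tokenizer.tokenize("Our friends won't buy this analysis.")
# Step 2: map each token to its ID in the model's vocabulary
token_ids = tokenizer.convert_tokens_to_ids(tokens)
# Step 3: calling the tokenizer directly additionally adds the special tokens
# ([CLS], [SEP]) and builds the attention mask the model expects
encoded = tokenizer("Our friends won't buy this analysis.")
print(tokens)
print(token_ids)
print(encoded['input_ids'])
```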
The tokenizer can preprocess either a single text or a pair of texts:
tokenizer('Hello,this one sentence!', 'And this sentence goes with it.')
{
'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
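To see how the two sentences are packed together, we can decode the IDs back into text (an illustrative check; the exact special tokens depend on the tokenizer):

```python
# Decoding reveals the [CLS]/[SEP] special tokens a BERT-style tokenizer inserts
encoded = tokenizer('Hello,this one sentence!', 'And this sentence goes with it.')
print(tokenizer.decode(encoded['input_ids']))
# e.g. "[CLS] hello, this one sentence! [SEP] and this sentence goes with it. [SEP]"
```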
To preprocess our data we need to know which fields each dataset uses, so we define the following dict:
task_to_keys = {
    'cola': ('sentence', None),
    'mnli': ('premise', 'hypothesis'),
    'mnli-mm': ('premise', 'hypothesis'),
    'mrpc': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'qqp': ('question1', 'question2'),
    'rte': ('sentence1', 'sentence2'),
    'sst2': ('sentence', None),
    'stsb': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2')
}
Put the preprocessing code into a function:
sentence1_key, sentence2_key = task_to_keys[task]

def preprocess(examples):
    # If the input is a pair of sentences, tokenize both; otherwise tokenize only the first
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)
We can try it out:
preprocess(dataset['train'][:5])
{
'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
Next we preprocess every sample in the dataset by calling map, which applies the preprocess function to all samples.
encoded_dataset = dataset.map(preprocess, batched=True)
The result is cached, so it does not need to be recomputed the next time. To bypass the cache and recompute, pass load_from_cache_file=False, as shown below.
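For example, a minimal sketch of forcing re-tokenization (useful after changing the preprocess function):

```python
# Re-run preprocessing instead of loading the cached result
encoded_dataset = dataset.map(preprocess, batched=True, load_from_cache_file=False)
```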
Fine-tuning the pretrained model
Since this is a sequence classification task, we need a model class that can solve it. We use AutoModelForSequenceClassification; like the tokenizer, its from_pretrained method downloads and loads the model for us:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# STS-B is a regression problem and MNLI is a 3-way classification problem
num_labels = 3 if task.startswith('mnli') else 1 if task == 'stsb' else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
To build a Trainer we still need three more ingredients, the most important of which is the training configuration, TrainingArguments. It holds all the attributes that define the training process.
metric_name = 'pearson' if task == 'stsb' else 'matthews_correlation' if task == 'cola' else 'accuracy'
args = TrainingArguments(
    'test-glue',
    evaluation_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)
evaluation_strategy='epoch' means that a validation run is performed at the end of every epoch.
Because different tasks need different evaluation metrics, we define a function that computes the right metric for the current task:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != 'stsb':
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)
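As a quick sanity check (not part of the original walkthrough), metric.compute can be called directly on dummy predictions to see what compute_metrics will return for the current task; this reuses np and the metric loaded earlier:

```python
# Random 0/1 predictions and labels, just to inspect the metric's output format;
# for CoLA this returns a dict like {'matthews_correlation': ...}
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
print(metric.compute(predictions=fake_preds, references=fake_labels))
```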
Then we pass everything to the Trainer:
validation_key = 'validation_mismatched' if task == 'mnli-mm' else 'validation_matched' if task == 'mnli' else 'validation'
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
Start training:
trainer.train()
The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.
/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
FutureWarning,
***** Running training *****
Num examples = 8551
Num Epochs = 5
Instantaneous batch size per device = 16
Total train batch size (w. parallel, distributed & accumulation) = 16
Gradient Accumulation steps = 1
Total optimization steps = 2675
[2675/2675 03:05, Epoch 5/5]
| Epoch | Training Loss | Validation Loss | Matthews Correlation |
|---|---|---|---|
| 1 | 0.524700 | 0.531839 | 0.397460 |
| 2 | 0.347900 | 0.517831 | 0.488779 |
| 3 | 0.235900 | 0.568724 | 0.520212 |
| 4 | 0.183700 | 0.776045 | 0.500206 |
| 5 | 0.138100 | 0.811927 | 0.521112 |
After training finishes we can run an evaluation:
trainer.evaluate()
{
'epoch': 5.0,
'eval_loss': 0.8119268417358398,
'eval_matthews_correlation': 0.5211120728046958,
'eval_runtime': 0.8952,
'eval_samples_per_second': 1165.111,
'eval_steps_per_second': 73.727}
Hyperparameter search
If we are not sure how to set the hyperparameters, we can run a hyperparameter search. The Trainer supports this, but it requires the optuna or Ray Tune libraries:
!pip install optuna ray[tune]
During a hyperparameter search the Trainer trains many models, so instead of a fixed model instance we pass a model_init function that lets the Trainer re-initialize the model for every trial:
def model_init():
    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
Then we build the Trainer as before:
trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
Then we call hyperparameter_search. This can take a long time, so it is common to search on a subset of the data first and then run the full training with the best parameters found. For example, we can search on roughly 1/10 of the data.
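One way to do that (a sketch; the shard indices here are only illustrative) is to hand the Trainer a shard of the training set before searching:

```python
# Keep only 1/10 of the training data for the search;
# Dataset.shard(index=0, num_shards=10) returns every 10th example
trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded_dataset['train'].shard(index=0, num_shards=10),
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
```

With that in place, the search is launched as before: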
best_run = trainer.hyperparameter_search(n_trials=10, direction='maximize')
It returns the hyperparameters of the run that performed best:
best_run
BestRun(run_id='2', objective=0.5408374954915984, hyperparameters={
'learning_rate': 1.8520571467952223e-05, 'num_train_epochs': 4, 'seed': 37, 'per_device_train_batch_size': 8})
To configure the Trainer with the best hyperparameters found and train with them, we can do the following:
for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)

trainer.train()
References
- 贪心学院 course