Chinese Semantic Matching
2022-07-29 23:59:00 [论搬砖的艺术]
Dataset

Loading the dataset
# Load the lcqmc Chinese semantic matching dataset
from paddlenlp.datasets import load_dataset
train_ds, dev_ds, test_ds = load_dataset("lcqmc", splits=["train", "dev", "test"])
# The dataset is returned as a MapDataset
print("Data type:", type(train_ds))
# label is the class label; the test set contains no labels
print("Training sample:", train_ds[0])
print("Dev sample:", dev_ds[0])
print("Test sample:", test_ds[0])

Loading the pretrained model and tokenizer
The Auto classes in PaddleNLP (AutoModel, AutoTokenizer, and the classes for various downstream tasks) provide a convenient interface: pretrained models with different network architectures can be loaded without specifying the concrete model class. Any PaddleNLP pretrained model can be loaded with the from_pretrained() method; the Transformer pretrained model zoo covers more than 40 mainstream pretrained models and more than 500 sets of model weights.
AutoModelForSequenceClassification suits point-wise binary semantic matching: the pretrained model encodes the input text pair (query-title), and the resulting representation is then classified. PaddleNLP ships the ERNIE 3.0 pretrained models, so the model and its tokenizer can each be loaded with a single line of code.
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
model_name = "ernie-3.0-medium-zh"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=len(train_ds.label_list))
tokenizer = AutoTokenizer.from_pretrained(model_name)
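For a text pair, the tokenizer concatenates the two texts as [CLS] query [SEP] title [SEP] and marks the segments with token_type_ids (0 for the first segment, 1 for the second). A pure-Python sketch of that segment layout (illustrative only; the real AutoTokenizer also performs subword tokenization, id lookup, and truncation):

```python
def pair_token_type_ids(query_tokens, title_tokens):
    # [CLS] + query + [SEP] belong to segment 0; title + [SEP] to segment 1.
    first = [0] * (1 + len(query_tokens) + 1)
    second = [1] * (len(title_tokens) + 1)
    return first + second

# Two query tokens and one title token -> four 0s followed by two 1s.
segment_ids = pair_token_type_ids(["喜欢", "打篮球"], ["篮球"])
```

These segment ids are what the model receives as the token_type_ids input alongside input_ids.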
Data processing for the pretrained model
A Dataset holds raw examples, which must be preprocessed and then sampled into batches.
Use the Dataset's map function with the tokenizer to join each query text and title text, converting raw text into model inputs.
Define a paddle.io.BatchSampler and a collate_fn to build a paddle.io.DataLoader.
In practice, tune the batch size batch_size and the maximum sequence length max_seq_length to fit the available GPU memory.
import functools
import numpy as np
from paddle.io import DataLoader, BatchSampler
from paddlenlp.data import DataCollatorWithPadding

# Preprocessing function: the tokenizer converts text into integer id sequences
def preprocess_function(examples, tokenizer, max_seq_length, is_test=False):
    result = tokenizer(text=examples["query"], text_pair=examples["title"], max_seq_len=max_seq_length)
    if not is_test:
        result["labels"] = examples["label"]
    return result

trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_seq_length=128)
train_ds = train_ds.map(trans_func)
dev_ds = dev_ds.map(trans_func)

# collate_fn pads each sequence to the longest length in the batch, then stacks the batch
collate_fn = DataCollatorWithPadding(tokenizer)

# Define the BatchSamplers (batch size, shuffling) and build the DataLoaders
train_batch_sampler = BatchSampler(train_ds, batch_size=64, shuffle=True)
dev_batch_sampler = BatchSampler(dev_ds, batch_size=128, shuffle=False)
train_data_loader = DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=collate_fn)
dev_data_loader = DataLoader(dataset=dev_ds, batch_sampler=dev_batch_sampler, collate_fn=collate_fn)
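What DataCollatorWithPadding does to a batch can be illustrated without Paddle: a minimal pure-Python sketch that pads every id sequence to the longest one in the batch (the real collator also pads token_type_ids, stacks labels, and returns tensors):

```python
def pad_batch(sequences, pad_id=0):
    # Pad every sequence to the length of the longest sequence in the batch.
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]

batch = [[101, 8, 9, 102], [101, 7, 102]]
padded = pad_batch(batch)  # both rows now have length 4
```

Padding per batch (rather than to a global max_seq_length) keeps batches compact when most pairs are short.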
Training and evaluation
With the optimizer, loss function, and evaluation metric defined, fine-tuning of the pretrained model can begin.
# AdamW optimizer, cross-entropy loss, accuracy metric
import paddle
optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters())
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
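The cross-entropy criterion can be made concrete with a small hand computation (a conceptual sketch in plain Python, not the paddle.nn implementation): the loss is the negative log-probability that softmax assigns to the true class.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    # Negative log-probability of the true class.
    return -math.log(softmax(logits)[label])

# A confident correct prediction yields a smaller loss than an uncertain one.
confident = cross_entropy([4.0, -4.0], 0)
uncertain = cross_entropy([0.1, -0.1], 0)
```

For uniform logits over two classes the loss is log(2), the usual starting point of an untrained binary classifier.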
# Start training
import os
import time
import paddle.nn.functional as F
from eval import evaluate

epochs = 1  # number of training epochs
ckpt_dir = "ernie_ckpt"  # directory for saving model parameters during training
best_acc = 0
best_step = 0
global_step = 0  # number of iterations
tic_train = time.time()
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        input_ids, token_type_ids, labels = batch['input_ids'], batch['token_type_ids'], batch['labels']

        # Forward pass: logits, loss, class probabilities, accuracy
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        probs = F.softmax(logits, axis=1)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()

        # Every 10 iterations, print the loss, accuracy, and training speed
        global_step += 1
        if global_step % 10 == 0:
            print(
                "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                % (global_step, epoch, step, loss, acc,
                   10 / (time.time() - tic_train)))
            tic_train = time.time()

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        # Every 100 iterations, evaluate the model and save the best parameters and the tokenizer vocabulary
        if global_step % 100 == 0:
            save_dir = ckpt_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            print("global step", global_step, end=' ')
            acc_eval = evaluate(model, criterion, metric, dev_data_loader)
            if acc_eval > best_acc:
                best_acc = acc_eval
                best_step = global_step
                model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)
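The evaluate helper comes from a local eval module that is not shown here. Below is a minimal sketch of the logic its call site implies, written framework-agnostically; the real helper would additionally switch the model to eval mode, wrap the loop in paddle.no_grad(), apply F.softmax before computing accuracy (argmax is unchanged by softmax, so this sketch compares logits directly), and restore train mode afterwards.

```python
def evaluate(model, criterion, metric, data_loader):
    # Accumulate loss and accuracy over the whole evaluation set.
    metric.reset()
    losses = []
    for batch in data_loader:
        logits = model(batch['input_ids'], batch['token_type_ids'])
        losses.append(float(criterion(logits, batch['labels'])))
        correct = metric.compute(logits, batch['labels'])
        metric.update(correct)
    accu = metric.accumulate()
    print("eval loss: %.5f, accu: %.5f" % (sum(losses) / len(losses), accu))
    return accu
```

In the training loop, the returned accuracy is compared against best_acc to decide whether the current checkpoint is worth saving.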
Prediction
# Preprocess the test set: the tokenizer converts text into integer id sequences
trans_func_test = functools.partial(preprocess_function, tokenizer=tokenizer, max_seq_length=128, is_test=True)
test_ds_trans = test_ds.map(trans_func_test)

# Sample into batches
collate_fn_test = DataCollatorWithPadding(tokenizer)
test_batch_sampler = BatchSampler(test_ds_trans, batch_size=32, shuffle=False)
test_data_loader = DataLoader(dataset=test_ds_trans, batch_sampler=test_batch_sampler, collate_fn=collate_fn_test)
# Predict labels with the model
import paddle
import paddle.nn.functional as F

label_map = {0: '不相似', 1: '相似'}  # 0: dissimilar, 1: similar
results = []
model.eval()
for batch in test_data_loader:
    input_ids, token_type_ids = batch['input_ids'], batch['token_type_ids']
    logits = model(input_ids, token_type_ids)
    probs = F.softmax(logits, axis=-1)
    idx = paddle.argmax(probs, axis=1).numpy()
    idx = idx.tolist()
    preds = [label_map[i] for i in idx]
    results.extend(preds)

# Save the LCQMC predictions
test_ds = load_dataset("lcqmc", splits=["test"])
res_dir = "./results"
if not os.path.exists(res_dir):
    os.makedirs(res_dir)
with open(os.path.join(res_dir, "lcqmc.tsv"), 'w', encoding="utf8") as f:
    f.write("label\tquery\ttitle\n")
    for i, pred in enumerate(results):
        f.write(pred + "\t" + test_ds[i]['query'] + "\t" + test_ds[i]['title'] + "\n")
