当前位置:网站首页>中文语义匹配
中文语义匹配
2022-07-29 23:59:00 【论搬砖的艺术】
数据集

加载数据集
# Load the LCQMC Chinese semantic-matching dataset.
from paddlenlp.datasets import load_dataset

train_ds, dev_ds, test_ds = load_dataset("lcqmc", splits=["train", "dev", "test"])

# Each split comes back as a MapDataset.
print("数据类型:", type(train_ds))
# "label" is the class label; the test split carries no label field.
print("训练集样例:", train_ds[0])
print("验证集样例:", dev_ds[0])
print("测试集样例:", test_ds[0])

加载训练模型和分词器
PaddleNLP中Auto模块(包括AutoModel, AutoTokenizer及各种下游任务类)提供了方便易用的接口,无需指定模型类别,即可调用不同网络结构的预训练模型。PaddleNLP的预训练模型可以很容易地通过from_pretrained()方法加载,Transformer预训练模型汇总包含了40多个主流预训练模型,500多个模型权重。
AutoModelForSequenceClassification可用于Point-wise方式的二分类语义匹配任务,通过预训练模型获取输入文本对(query-title)的表示,之后将文本表示进行分类。PaddleNLP已经实现了ERNIE 3.0预训练模型,可以通过一行代码实现ERNIE 3.0预训练模型和分词器的加载。
# Load the ERNIE 3.0 pre-trained classifier and its tokenizer through the Auto classes.
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "ernie-3.0-medium-zh"
num_labels = len(train_ds.label_list)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=num_labels)
tokenizer = AutoTokenizer.from_pretrained(model_name)
基于预训练模型的数据处理
Dataset中通常为原始数据,需要经过一定的数据处理并进行采样组batch。
通过Dataset的map函数,使用分词器将数据集中query文本和title文本拼接,从原始文本处理成模型的输入。
定义paddle.io.BatchSampler和collate_fn构建 paddle.io.DataLoader。
实际训练中,根据显存大小调整批大小batch_size和文本最大长度max_seq_length。
import functools
import numpy as np
from paddle.io import DataLoader, BatchSampler
from paddlenlp.data import DataCollatorWithPadding
# Preprocessing function: use the tokenizer to turn a text pair into integer id sequences.
def preprocess_function(examples, tokenizer, max_seq_length, is_test=False):
    """Tokenize one query/title example into model inputs.

    Args:
        examples: dict with keys "query", "title" and, when not a test
            example, "label".
        tokenizer: callable invoked as tokenizer(text=..., text_pair=...,
            max_seq_len=...); returns a dict of model inputs.
        max_seq_length: maximum sequence length forwarded to the tokenizer.
        is_test: when True the example has no "label" key and none is attached.

    Returns:
        The tokenizer's output dict, with a "labels" entry added for
        non-test examples.
    """
    # NOTE(review): original scraped source had lost its indentation; the
    # body below restores the obvious structure without changing logic.
    result = tokenizer(text=examples["query"], text_pair=examples["title"], max_seq_len=max_seq_length)
    if not is_test:
        result["labels"] = examples["label"]
    return result
# Bind the tokenizer and max length into the mapping function, then tokenize both splits.
trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_seq_length=128)
train_ds = train_ds.map(trans_func)
dev_ds = dev_ds.map(trans_func)

# Collator: pads every sequence to the longest one in its batch, then stacks them.
collate_fn = DataCollatorWithPadding(tokenizer)

# Samplers fix batch size and shuffling; DataLoaders drive the iteration.
train_batch_sampler = BatchSampler(train_ds, batch_size=64, shuffle=True)
dev_batch_sampler = BatchSampler(dev_ds, batch_size=128, shuffle=False)
train_data_loader = DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=collate_fn)
dev_data_loader = DataLoader(dataset=dev_ds, batch_sampler=dev_batch_sampler, collate_fn=collate_fn)
训练和评估
定义训练所需的优化器、损失函数、评价指标等,就可以开始进行预训练模型微调任务。
# FIX: `paddle` is referenced below (and later in this file) but was never
# imported anywhere in the original source.
import paddle

# AdamW optimizer, cross-entropy loss and accuracy metric for fine-tuning.
optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters())
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
# 开始训练
# Training loop.
import os  # FIX: used below for checkpoint dirs but never imported in the original
import time

import paddle.nn.functional as F
from eval import evaluate

epochs = 1  # number of training epochs
ckpt_dir = "ernie_ckpt"  # directory where checkpoints are saved
best_acc = 0
best_step = 0
global_step = 0  # optimizer iterations so far
tic_train = time.time()
# NOTE(review): the scraped source had lost all indentation; the structure
# below restores the standard PaddleNLP fine-tuning loop without logic changes.
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        input_ids, token_type_ids, labels = batch['input_ids'], batch['token_type_ids'], batch['labels']
        # Forward pass: logits, loss, class probabilities, running accuracy.
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        probs = F.softmax(logits, axis=1)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()

        # Log loss / accuracy / throughput every 10 steps.
        global_step += 1
        if global_step % 10 == 0:
            print(
                "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                % (global_step, epoch, step, loss, acc,
                   10 / (time.time() - tic_train)))
            tic_train = time.time()

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        # Every 100 steps: evaluate on dev and keep the best checkpoint
        # (model weights plus tokenizer vocab).
        if global_step % 100 == 0:
            save_dir = ckpt_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            print("global step", global_step, end=' ')
            acc_eval = evaluate(model, criterion, metric, dev_data_loader)
            if acc_eval > best_acc:
                best_acc = acc_eval
                best_step = global_step
                model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)
预测
# Tokenize the test split (the test set has no labels, hence is_test=True).
trans_func_test = functools.partial(preprocess_function, tokenizer=tokenizer, max_seq_length=128, is_test=True)
test_ds_trans = test_ds.map(trans_func_test)

# Batch the test data; no shuffling at inference time.
collate_fn_test = DataCollatorWithPadding(tokenizer)
test_batch_sampler = BatchSampler(test_ds_trans, batch_size=32, shuffle=False)
test_data_loader = DataLoader(dataset=test_ds_trans, batch_sampler=test_batch_sampler, collate_fn=collate_fn_test)
# Run the model over the test set and map class ids to human-readable labels.
import paddle  # FIX: needed for paddle.argmax; never imported in the original
import paddle.nn.functional as F

label_map = {
    0: '不相似', 1: '相似'}
results = []
model.eval()  # inference mode (disables dropout etc.)
# NOTE(review): restored the loop indentation lost in the scraped source.
for batch in test_data_loader:
    input_ids, token_type_ids = batch['input_ids'], batch['token_type_ids']
    # FIX: the original re-indexed batch['input_ids'] / batch['token_type_ids']
    # here despite having just bound them to locals; use the locals.
    logits = model(input_ids, token_type_ids)
    probs = F.softmax(logits, axis=-1)
    idx = paddle.argmax(probs, axis=1).numpy()
    idx = idx.tolist()
    preds = [label_map[i] for i in idx]
    results.extend(preds)
# Persist the LCQMC predictions as a TSV file (label, query, title per row).
import os  # FIX: used below but never imported in the original source

# Reload the raw test split so query/title text is available alongside predictions.
test_ds = load_dataset("lcqmc", splits=["test"])
res_dir = "./results"
if not os.path.exists(res_dir):
    os.makedirs(res_dir)
# NOTE(review): restored the with/for indentation lost in the scraped source.
with open(os.path.join(res_dir, "lcqmc.tsv"), 'w', encoding="utf8") as f:
    f.write("label\tquery\ttitle\n")
    for i, pred in enumerate(results):
        f.write(pred+"\t"+test_ds[i]['query']+"\t"+test_ds[i]['title']+"\n")

边栏推荐
- NumPy(二)
- 月薪15k的阿里测试岗,面试原来这么简单
- Framework 到底该怎么学习?
- Add, delete, modify and query the database
- The basic parallel I/O port of single chip microcomputer development
- codeforces每日5题(均1600)-第二十六天
- Paper Intensive Reading - YOLOv3: An Incremental Improvement
- ZLMediaKit源码分析 - NotifyCenter
- Some personal understandings about MySQL indexes (partially refer to MySQL45 lectures)
- Vulkan与OpenGL对比——Vulkan的全新渲染架构
猜你喜欢

Install PyCharm on Raspberry Pi

EA&UML日拱一卒-多任务编程超入门-(2)进程和线程

EA&UML日拱一卒-多任务编程超入门-(7)关于mutex,你必须知道的

MySQL 用 BETWEEN AND 日期查询包含范围边界

从面试官角度分析:面试功能测试工程师主要考察哪些能力?

Apache Doris 1.1 特性揭秘:Flink 实时写入如何兼顾高吞吐和低延时

学会使用MySQL的Explain执行计划,SQL性能调优从此不再困难

微信小程序获取手机号getPhoneNumber接口报错44002

Brute force recursion to dynamic programming 04 (digital string conversion)

【openlayers】Map【1】
随机推荐
ZLMediaKit源码学习——UDP
BEVDetNet:Bird‘s Eye View LiDAR Point Cloud based Real-time 3D Object Detection for Autonomous Drivi
codeforces 线段树题单
全国双非院校考研信息汇总整理 Part.3
Install PyCharm on Raspberry Pi
29岁从事功能测试被辞,面试2个月都找不到工作吗?
WIN2008的IIS上下载文件大小限制之修改
Adaptive feature fusion pyramid network for multi-classes agriculturalpest detection
Override and customize dependent native Bean methods
Some personal understandings about MySQL indexes (partially refer to MySQL45 lectures)
【集训DAY18】Welcome J and Z 【动态规划】
经典论文-SqueezeNet论文及实践
WLAN笔记
C陷阱与缺陷 第5章 库函数 5.1 返回整数的getchar函数
能源企业数字化转型背景下的数据安全治理实践路径
Elephant Swap:借助ePLATO提供加密市场的套利空间
容器化数据库必经之道
Music theory & guitar skills
1326. 灌溉花园的最少水龙头数目 动态规划
2022/7/29 Exam Summary