李沐动手学深度学习V2: BERT Fine-Tuning and Code Implementation
2022-08-03 20:05:00 【cv_lhp】
I. BERT Fine-Tuning
1. Introduction
Natural language inference is a sequence-level text-pair classification problem. Fine-tuning BERT for it only requires adding an extra MLP-based head; the pretrained BERT weights are then fine-tuned together with this head. Below we download a small pretrained version of BERT and fine-tune it for natural language inference on the SNLI dataset.
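Conceptually, the extra head is just a small MLP applied to the '<cls>' position of BERT's output. The toy sketch below is purely illustrative (made-up tensors with dimensions matching the bert.small model loaded later, not the actual d2l classes) and only shows the shape flow:
import torch
from torch import nn
# Hypothetical stand-in for BERT's output: batch of 2 sequences, length 128, hidden size 256
encoded = torch.randn(2, 128, 256)
# The extra head: hidden layer + output layer over the '<cls>' representation
head = nn.Sequential(nn.Linear(256, 256), nn.Tanh(), nn.Linear(256, 3))
logits = head(encoded[:, 0, :])   # use only the '<cls>' position
print(logits.shape)               # torch.Size([2, 3]): entailment / contradiction / neutral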
2. Loading the Pretrained BERT
The earlier posts 李沐动手学深度学习V2-bert预训练数据集和代码实现 (BERT pretraining part 2) and 李沐动手学深度学习V2-BERT预训练和代码实现 (BERT pretraining part 3) covered pretraining BERT (note that the original BERT model was pretrained on much larger corpora and has hundreds of millions of parameters). Two pretrained BERT versions are provided below: "bert.base" is as large as the original BERT base model and requires substantial compute to fine-tune, while "bert.small" is a small version intended for easy demonstration.
import os
import torch
from torch import nn
import d2l.torch
import json
import multiprocessing
d2l.torch.DATA_HUB['bert.base'] = (d2l.torch.DATA_URL + 'bert.base.torch.zip',
'225d66f04cae318b841a13d32af3acc165f253ac')
d2l.torch.DATA_HUB['bert.small'] = (d2l.torch.DATA_URL + 'bert.small.torch.zip',
'c72329e68a732bef0452e4b96a1c341c8910f81f')
Both pretrained BERT models include a "vocab.json" file that defines the vocabulary and a "pretrained.params" file containing the pretrained BERT parameters. The load_pretrained_model function below loads them.
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.torch.download_extract(pretrained_model)
    # Define an empty vocabulary, then load the predefined vocabulary
    vocab = d2l.torch.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}
    bert = d2l.torch.BERTModel(len(vocab), num_hiddens=num_hiddens, norm_shape=[256],
                               ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                               num_heads=num_heads, num_layers=num_layers, dropout=dropout,
                               max_len=max_len, key_size=256, query_size=256, value_size=256,
                               hid_in_features=256, mlm_in_features=256, nsp_in_features=256)
    # Alternative: wrap with DataParallel before loading the parameters
    # bert = nn.DataParallel(bert, device_ids=devices).to(devices[0])
    # bert.module.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')), strict=False)
    # Load the pretrained BERT parameters
    bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))
    return bert, vocab
To make the demo run on most machines, we load and fine-tune the small pretrained version, "bert.small".
devices = d2l.torch.try_all_gpus()[2:4]
bert,vocab = load_pretrained_model('bert.small',num_hiddens=256,ffn_num_hiddens=512,num_heads=4,num_layers=2,dropout=0.1,max_len=512,devices=devices)
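As a quick, optional sanity check (not part of the original post), we can inspect the loaded vocabulary and model before fine-tuning; the snippet only uses standard attributes of the objects returned above.
# Optional sanity check of the loaded bert.small model and its vocabulary
print('vocab size:', len(vocab))
print('parameters in bert.small:', sum(p.numel() for p in bert.parameters()))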
3. The Dataset for Fine-Tuning BERT
For the downstream task of natural language inference on the SNLI dataset, we define a customized dataset class SNLIBERTDataset. In each example, the premise and hypothesis form a pair of text sequences that are packed into a single BERT input sequence, and segment indices distinguish the premise from the hypothesis within that input. Given the predefined maximum length of a BERT input sequence (max_len), the last token of the longer of the two texts is removed repeatedly until max_len is satisfied. To speed up generation of the SNLI dataset for fine-tuning BERT, four worker processes generate training or test examples in parallel.
class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premises_hypotheses_tokens = [
            [p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.torch.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]
        self.vocab = vocab
        self.max_len = max_len
        self.labels = torch.tensor(dataset[2])
        self.all_tokens_id, self.all_segments, self.all_valid_lens = self._preprocess(all_premises_hypotheses_tokens)
        print(f'read {len(self.all_tokens_id)} examples')

    def _preprocess(self, all_premises_hypotheses_tokens):
        pool = multiprocessing.Pool(4)  # Use 4 worker processes
        out = pool.map(self._mp_worker, all_premises_hypotheses_tokens)
        all_tokens_id = [tokens_id for tokens_id, segments, valid_len in out]
        all_segments = [segments for tokens_id, segments, valid_len in out]
        all_valid_lens = [valid_len for tokens_id, segments, valid_len in out]
        return (torch.tensor(all_tokens_id, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(all_valid_lens))

    def _mp_worker(self, premises_hypotheses_tokens):
        p_tokens, h_tokens = premises_hypotheses_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.torch.get_tokens_and_segments(p_tokens, h_tokens)
        valid_len = len(tokens)
        tokens_id = self.vocab[tokens] + [self.vocab['<pad>']] * (self.max_len - valid_len)
        segments = segments + [0] * (self.max_len - valid_len)
        return (tokens_id, segments, valid_len)

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for the '<cls>', '<sep>' and '<sep>' tokens in the BERT input
        while (len(p_tokens) + len(h_tokens)) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_tokens_id[idx], self.all_segments[idx], self.all_valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_tokens_id)
After downloading the SNLI dataset, we generate training and test examples by instantiating the SNLIBERTDataset class. These examples will be read in minibatches during training and testing of natural language inference.
# In the original BERT model, max_len = 512
batch_size,max_len,num_workers = 512,128,d2l.torch.get_dataloader_workers()
data_dir = d2l.torch.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.torch.read_snli(data_dir,is_train=True),max_len,vocab)
test_set = SNLIBERTDataset(d2l.torch.read_snli(data_dir,is_train=False),max_len,vocab)
train_iter = torch.utils.data.DataLoader(train_set,batch_size,num_workers=num_workers,shuffle=True)
test_iter = torch.utils.data.DataLoader(test_set,batch_size,num_workers=num_workers,shuffle=False)
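To see what the data loader actually yields, the short snippet below (an illustrative addition, not from the original post) prints the shapes of one minibatch; each batch provides token ids, segment ids, and valid lengths together with the labels.
# Peek at one minibatch; shapes assume batch_size=512 and max_len=128 as set above
for (tokens_X, segments_X, valid_lens_X), y in train_iter:
    print(tokens_X.shape)      # torch.Size([512, 128]) token ids, padded with '<pad>'
    print(segments_X.shape)    # torch.Size([512, 128]) segment ids: 0 for premise, 1 for hypothesis
    print(valid_lens_X.shape)  # torch.Size([512]) valid lengths before padding
    print(y.shape)             # torch.Size([512]) labels
    break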
4. Fine-Tuning BERT
**Fine-tuning BERT for natural language inference only requires an extra multilayer perceptron consisting of two fully connected layers,** which has the same structure as the self.hidden and self.output MLP used for next-sentence prediction (NSP) in the BERT implementation from the earlier post BERT预训练第一篇:李沐动手学深度学习V2-bert和代码实现. This MLP transforms the BERT representation of the special '<cls>' token, which encodes the information of both the premise and the hypothesis, into the three output classes of natural language inference: entailment, contradiction, and neutral.
class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_X = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_X)
        # Feed the BERT representation of the '<cls>' token (position 0) to the MLP head
        return self.output(self.hidden(encoded_X[:, 0, :]))
Next, the pretrained BERT model bert is fed into the BERTClassifier instance net for the downstream application. In common implementations of BERT fine-tuning, only the parameters of the output layer of the extra MLP (net.output) are learned from scratch, while all parameters of the pretrained BERT encoder (net.encoder) and the hidden layer of the extra MLP (net.hidden) are fine-tuned.
net = BERTClassifier(bert)
During BERT pretraining, the MaskLM and NextSentencePred classes each have parameters in their own MLPs. These parameters belong to the pretrained BERT model bert, but they are only used to compute the masked language modeling loss and the next sentence prediction loss during pretraining. Since those two losses are irrelevant to fine-tuning downstream applications, the parameters of the MLPs in MaskLM and NextSentencePred are not updated (they remain stale) when BERT is fine-tuned.
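To make this concrete, the check below (an illustrative addition, assuming the d2l BERTModel stores its pretraining heads as bert.mlm and bert.nsp) lists the pretrained parameters that are not part of net and therefore stay stale during fine-tuning.
# Illustrative check: which pretrained parameters are NOT fine-tuned?
finetuned_params = {id(p) for p in net.parameters()}
stale = [name for name, p in bert.named_parameters() if id(p) not in finetuned_params]
print(stale)  # expected: only parameters under bert.mlm.* and bert.nsp.*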
We train and evaluate the net model on the SNLI training set (train_iter) and test set (test_iter) with the d2l.torch.train_ch13() function, which plots the loss and accuracy curves as training progresses.
lr,num_epochs = 1e-4,5
optim = torch.optim.Adam(params=net.parameters(),lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.torch.train_ch13(net,train_iter,test_iter,loss,optim,num_epochs,devices)
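After training, a single premise/hypothesis pair can be classified with a small helper like the one below. This is a hedged sketch, not part of the original post: it reuses d2l.torch.get_tokens_and_segments, assumes the label order entailment/contradiction/neutral used by d2l.torch.read_snli, and assumes net's parameters sit on devices[0] after train_ch13.
# Hypothetical inference helper (illustrative sketch)
def predict_snli_bert(net, vocab, premise, hypothesis, max_len=128, device=None):
    net.eval()
    p_tokens = d2l.torch.tokenize([premise.lower()])[0]
    h_tokens = d2l.torch.tokenize([hypothesis.lower()])[0]
    tokens, segments = d2l.torch.get_tokens_and_segments(p_tokens, h_tokens)
    pad_len = max_len - len(tokens)
    tokens_id = torch.tensor([vocab[tokens] + [vocab['<pad>']] * pad_len], device=device)
    segments = torch.tensor([segments + [0] * pad_len], device=device)
    valid_len = torch.tensor([len(tokens)], device=device)
    with torch.no_grad():
        logits = net((tokens_id, segments, valid_len))
    # Label order assumed to match d2l.torch.read_snli: 0=entailment, 1=contradiction, 2=neutral
    return ['entailment', 'contradiction', 'neutral'][int(logits.argmax(dim=1))]

# Example usage (the prediction depends on the fine-tuned model):
# predict_snli_bert(net, vocab, 'two women are hugging.', 'the women are sleeping.', device=devices[0])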

5. Summary
- A pretrained BERT model can be fine-tuned for downstream applications, such as natural language inference on the SNLI dataset.
- During fine-tuning, the BERT model becomes part of the downstream model; together with the extra MLP it is trained and evaluated on the downstream task.
6. Fine-Tuning with the Original-Sized Pretrained BERT
We now fine-tune a larger pretrained BERT model that is as big as the original BERT base model. Modify the argument settings of the load_pretrained_model function: replace "bert.small" with "bert.base", and increase num_hiddens=256, ffn_num_hiddens=512, num_heads=4, and num_layers=2 to 768, 3072, 12, and 12, respectively. Also change the output layer of the MLP to nn.Linear(768, 3), since the output feature dimension of the BERT model is now 768, and increase the number of fine-tuning epochs. The code is shown below.
import os
import torch
from torch import nn
import d2l.torch
import json
import multiprocessing
d2l.torch.DATA_HUB['bert.base'] = (d2l.torch.DATA_URL + 'bert.base.torch.zip',
'225d66f04cae318b841a13d32af3acc165f253ac')
d2l.torch.DATA_HUB['bert.small'] = (d2l.torch.DATA_URL + 'bert.small.torch.zip',
'c72329e68a732bef0452e4b96a1c341c8910f81f')
devices = d2l.torch.try_all_gpus()
def load_pretrained_model1(pretrained_model, num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.torch.download_extract(pretrained_model)
    # Define an empty vocabulary, then load the predefined vocabulary
    vocab = d2l.torch.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}
    bert = d2l.torch.BERTModel(len(vocab), num_hiddens=num_hiddens, norm_shape=[768],
                               ffn_num_input=768, ffn_num_hiddens=ffn_num_hiddens,
                               num_heads=num_heads, num_layers=num_layers, dropout=dropout,
                               max_len=max_len, key_size=768, query_size=768, value_size=768,
                               hid_in_features=768, mlm_in_features=768, nsp_in_features=768)
    # bert = nn.DataParallel(bert, device_ids=devices).to(devices[0])
    # bert.module.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')), strict=False)
    # Load the pretrained BERT parameters
    bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))
    return bert, vocab
bert,vocab = load_pretrained_model1('bert.base',num_hiddens=768,ffn_num_hiddens=3072,num_heads=12,num_layers=12,dropout=0.1,max_len=512,devices=devices)
class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premises_hypotheses_tokens = [
            [p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.torch.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]
        self.vocab = vocab
        self.max_len = max_len
        self.labels = torch.tensor(dataset[2])
        self.all_tokens_id, self.all_segments, self.all_valid_lens = self._preprocess(all_premises_hypotheses_tokens)
        print(f'read {len(self.all_tokens_id)} examples')

    def _preprocess(self, all_premises_hypotheses_tokens):
        pool = multiprocessing.Pool(4)  # Use 4 worker processes
        out = pool.map(self._mp_worker, all_premises_hypotheses_tokens)
        all_tokens_id = [tokens_id for tokens_id, segments, valid_len in out]
        all_segments = [segments for tokens_id, segments, valid_len in out]
        all_valid_lens = [valid_len for tokens_id, segments, valid_len in out]
        return (torch.tensor(all_tokens_id, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(all_valid_lens))

    def _mp_worker(self, premises_hypotheses_tokens):
        p_tokens, h_tokens = premises_hypotheses_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.torch.get_tokens_and_segments(p_tokens, h_tokens)
        valid_len = len(tokens)
        tokens_id = self.vocab[tokens] + [self.vocab['<pad>']] * (self.max_len - valid_len)
        segments = segments + [0] * (self.max_len - valid_len)
        return (tokens_id, segments, valid_len)

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for the '<cls>', '<sep>' and '<sep>' tokens in the BERT input
        while (len(p_tokens) + len(h_tokens)) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_tokens_id[idx], self.all_segments[idx], self.all_valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_tokens_id)
# In the original BERT model, max_len = 512
batch_size, max_len, num_workers = 512, 128, d2l.torch.get_dataloader_workers()
data_dir = d2l.torch.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.torch.read_snli(data_dir, is_train=True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.torch.read_snli(data_dir, is_train=False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, num_workers=num_workers, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_set, batch_size, num_workers=num_workers, shuffle=False)
class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(768, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_X = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_X)
        return self.output(self.hidden(encoded_X[:, 0, :]))
net = BERTClassifier(bert)
lr, num_epochs = 1e-4, 20
optim = torch.optim.Adam(params=net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.torch.train_ch13(net, train_iter, test_iter, loss, optim, num_epochs, devices)
7. Complete Code
import os
import torch
from torch import nn
import d2l.torch
import json
import multiprocessing
d2l.torch.DATA_HUB['bert.base'] = (d2l.torch.DATA_URL + 'bert.base.torch.zip',
'225d66f04cae318b841a13d32af3acc165f253ac')
d2l.torch.DATA_HUB['bert.small'] = (d2l.torch.DATA_URL + 'bert.small.torch.zip',
'c72329e68a732bef0452e4b96a1c341c8910f81f')
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.torch.download_extract(pretrained_model)
    # Define an empty vocabulary, then load the predefined vocabulary
    vocab = d2l.torch.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}
    bert = d2l.torch.BERTModel(len(vocab), num_hiddens=num_hiddens, norm_shape=[256],
                               ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                               num_heads=num_heads, num_layers=num_layers, dropout=dropout,
                               max_len=max_len, key_size=256, query_size=256, value_size=256,
                               hid_in_features=256, mlm_in_features=256, nsp_in_features=256)
    # bert = nn.DataParallel(bert, device_ids=devices).to(devices[0])
    # bert.module.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')), strict=False)
    # Load the pretrained BERT parameters
    bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))
    return bert, vocab
devices = d2l.torch.try_all_gpus()[2:4]
bert, vocab = load_pretrained_model('bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4, num_layers=2,
dropout=0.1, max_len=512, devices=devices)
class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premises_hypotheses_tokens = [
            [p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.torch.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]
        self.vocab = vocab
        self.max_len = max_len
        self.labels = torch.tensor(dataset[2])
        self.all_tokens_id, self.all_segments, self.all_valid_lens = self._preprocess(all_premises_hypotheses_tokens)
        print(f'read {len(self.all_tokens_id)} examples')

    def _preprocess(self, all_premises_hypotheses_tokens):
        pool = multiprocessing.Pool(4)  # Use 4 worker processes
        out = pool.map(self._mp_worker, all_premises_hypotheses_tokens)
        all_tokens_id = [tokens_id for tokens_id, segments, valid_len in out]
        all_segments = [segments for tokens_id, segments, valid_len in out]
        all_valid_lens = [valid_len for tokens_id, segments, valid_len in out]
        return (torch.tensor(all_tokens_id, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(all_valid_lens))

    def _mp_worker(self, premises_hypotheses_tokens):
        p_tokens, h_tokens = premises_hypotheses_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.torch.get_tokens_and_segments(p_tokens, h_tokens)
        valid_len = len(tokens)
        tokens_id = self.vocab[tokens] + [self.vocab['<pad>']] * (self.max_len - valid_len)
        segments = segments + [0] * (self.max_len - valid_len)
        return (tokens_id, segments, valid_len)

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for the '<cls>', '<sep>' and '<sep>' tokens in the BERT input
        while (len(p_tokens) + len(h_tokens)) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_tokens_id[idx], self.all_segments[idx], self.all_valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_tokens_id)
# In the original BERT model, max_len = 512
batch_size, max_len, num_workers = 512, 128, d2l.torch.get_dataloader_workers()
data_dir = d2l.torch.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.torch.read_snli(data_dir, is_train=True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.torch.read_snli(data_dir, is_train=False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, num_workers=num_workers, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_set, batch_size, num_workers=num_workers, shuffle=False)
class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_X = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_X)
        return self.output(self.hidden(encoded_X[:, 0, :]))
net = BERTClassifier(bert)
lr, num_epochs = 1e-4, 5
optim = torch.optim.Adam(params=net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.torch.train_ch13(net, train_iter, test_iter, loss, optim, num_epochs, devices)
8. Related Links
BERT Pretraining Part 1: 李沐动手学深度学习V2-bert和代码实现
BERT Pretraining Part 2: 李沐动手学深度学习V2-bert预训练数据集和代码实现
BERT Pretraining Part 3: 李沐动手学深度学习V2-BERT预训练和代码实现
BERT Fine-Tuning Part 1: 李沐动手学深度学习V2-自然语言推断与数据集SNLI和代码实现
BERT Fine-Tuning Part 2: 李沐动手学深度学习V2-BERT微调和代码实现