当前位置:网站首页>李沐动手学深度学习V2-自然语言推断与数据集SNLI和代码实现
李沐动手学深度学习V2-自然语言推断与数据集SNLI和代码实现
2022-08-03 20:05:00 【cv_lhp】
一. 斯坦福自然语言推断(SNLI)数据集
1. 介绍
自然语言推断(natural language inference)主要研究
假设(hypothesis)是否可以从前提(premise)中推断出来,
其中两者都是文本序列。
换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:
- 蕴涵(entailment):假设可以从前提中推断出来。
- 矛盾(contradiction):假设的否定可以从前提中推断出来。
- 中性(neutral):所有其他情况。
自然语言推断也被称为识别文本蕴涵任务。
例如,下面的一个文本对将被贴上“蕴涵”的标签,因为假设中的“表白”可以从前提中的“拥抱”中推断出来。
前提:两个女人拥抱在一起。
假设:两个女人在示爱。
下面是一个“矛盾”的例子,因为“运行编码示例”表示“不睡觉”,而不是“睡觉”。
前提:一名男子正在运行Dive Into Deep Learning的编码示例。
假设:该男子正在睡觉。
第三个例子显示了一种“中性”关系,因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。
前提:音乐家们正在为我们表演。
假设:音乐家很有名。
自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用,从信息检索到开放领域的问答。为了研究这个问题,我们将首先研究一个流行的自然语言推断基准数据集。
2. 下载SNLI数据集
SNLI是由500000多个带标签的英语句子对组成的集合,在路径…/data/snli_1.0中下载并存储提取的SNLI数据集。
import torch
import torch.nn
import d2l.torch
import os
import re
d2l.torch.DATA_HUB['SNLI'] = (
'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
'9fcde07509c7e87ec61c640c1b2753d9041758e4')
data_dir = d2l.torch.download_extract('SNLI')
3. 数据集读取
原始的SNLI数据集包含的信息比我们在实验中真正需要的信息丰富得多。因此,我们定义函数read_snli以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。
def read_snli(data_dir,is_train=True):
"""将SNLI数据集解析为前提、假设和标签"""
def extract_text(s):
# 删除我们不会使用的信息
s = re.sub('\\(','',s)
s = re.sub('\\)','',s)
# 用一个空格替换两个或多个连续的空格
s = re.sub('\\s{2,}',' ',s)
return s.strip()
label_set = {
'entailment':0,'contradiction':1,'neutral':2}
file_path = os.path.join(data_dir,'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')
with open(file_path,'r') as f :
rows = [row.split('\t') for row in f.readlines()[1:]] # rows是一个list of list嵌套列表
premises = [extract_text(row[1]) for row in rows if row[0] in label_set] # premises是一个列表,里面元素是一个每一个样本的前提
hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set] #hypotheses是一个列表,里面元素是一个每一个样本(每一行)的假设
labels = [label_set[row[0]] for row in rows if row[0] in label_set] #labels是一个列表,里面元素是一个每一个样本的label,为0,1,2标签
return premises,hypotheses,labels
打印前3对前提和假设,以及它们的标签(“0”、“1”和“2”分别对应于“蕴涵”、“矛盾”和“中性”)。
train_data = read_snli(data_dir,is_train=True)
for x0,x1,y in zip(train_data[0][:3],train_data[1][:3],train_data[2][:3]):
print('premise:',x0)
print('hypothesis:',x1)
print('label:',y)
输出结果如下:
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is training his horse for a competition .
label: 2
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is at a diner , ordering an omelette .
label: 1
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is outdoors , on a horse .
label: 0
训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个标签“蕴涵”、“矛盾”和“中性”是平衡的。
test_data = read_snli(data_dir,is_train=False)
for data in [train_data,test_data]:
print([[label for label in data[2]].count(i) for i in range(3)])
输出结果如下:
[183416, 183187, 182764]
[3368, 3237, 3219]
4. 定义用于加载数据集的类
下面定义一个用于加载SNLI数据集的类。类构造函数中的变量num_steps指定文本序列的长度,使得每个小批量序列将具有相同的形状。也即是在较长序列中的前num_steps个标记之后的标记被截断,而特殊标记“”将被附加到较短的序列后,直到它们的长度变为num_steps。通过实现__getitem__功能,我们可以任意访问带有索引idx的前提、假设和标签。
class SNLIDataset(torch.utils.data.Dataset):
"""用于加载SNLI数据集的自定义数据集"""
def __init__(self,dataset,num_steps,vocab=None):
self.num_steps = num_steps
all_premises_tokens = d2l.torch.tokenize(dataset[0],token='word') # all_premises_tokens为一个list of list嵌套列表,列表里面每个元素是每个样本的token词元列表
all_hypotheses_tokens = d2l.torch.tokenize(dataset[1],token='word') # all_hypotheses_tokens为一个list of list嵌套列表,列表里面每个元素是每个样本的token词元列表
if vocab is None:
self.vocab = d2l.torch.Vocab(tokens=all_premises_tokens+all_hypotheses_tokens,min_freq=5,reserved_tokens=['<pad>'])
else:
self.vocab = vocab
self.all_premises_tokens = self._pad(all_premises_tokens)
self.all_hypotheses_tokens = self._pad(all_hypotheses_tokens)
self.all_labels = torch.tensor(dataset[2])
print(f'read {len(self.all_premises_tokens)} examples')
def _pad(self,lines):
return torch.tensor([d2l.torch.truncate_pad(self.vocab[line],self.num_steps,self.vocab['<pad>']) for line in lines])
def __getitem__(self, idx):
return (self.all_premises_tokens[idx],self.all_hypotheses_tokens[idx]),self.all_labels[idx]
def __len__(self):
return len(self.all_premises_tokens)
5. 整合代码
调用read_snli函数和SNLIDataset类来下载SNLI数据集,并返回训练集和测试集的DataLoader实例,以及训练集的词表。注意我们必须使用从训练集构造的词表作为测试集的词表。因此在训练集中训练的模型将不知道来自测试集的任何新词元。
def load_data_snli(batch_size,num_steps = 50):
"""下载SNLI数据集并返回数据迭代器和词表"""
num_workers = d2l.torch.get_dataloader_workers()
data_dir = d2l.torch.download_extract('SNLI')
train_data = read_snli(data_dir,is_train=True)
test_data = read_snli(data_dir,is_train=False)
train_dataset = SNLIDataset(train_data,num_steps,vocab=None)#训练集需要构建自己的vocab
test_dataset = SNLIDataset(test_data,num_steps,vocab=train_dataset.vocab)#注意测试集需要使用train_dataset训练集里面的vocab
train_iter = torch.utils.data.DataLoader(train_dataset,batch_size,shuffle = True,num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_dataset,batch_size,shuffle = False,num_workers=num_workers)#测试集里面的shuffle必须是False,也即是每一轮epoch加载一个批量样本时,样本顺序不会改变
return train_iter,test_iter,train_dataset.vocab
将批量大小设置为128,序列长度设置为50,并调用load_data_snli函数来获取数据迭代器和词表,然后打印词表大小。
train_iter,test_iter,vocab = load_data_snli(128,50)
len(vocab)
输出结果如下:
read 549367 examples
read 9824 examples
18678
打印第一个小批量的形状,前提和假设作为两个输入X[0]和X[1]。
#打印第一个批量的相关输入数据和label数据
for X,Y in train_iter:
print(X[0].shape) #前提的一个批量序列数据
print(X[1].shape) #假设的一个批量序列数据
print(Y.shape) # label的一个批量label数据
break
输出结果如下:
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])
6. 小结
- 自然语言推断研究“假设”是否可以从“前提”推断出来,其中两者都是文本序列。
- 在自然语言推断中,前提和假设之间的关系包括蕴涵关系、矛盾关系和中性关系。
- 斯坦福自然语言推断(SNLI)语料库是一个比较流行的自然语言推断基准数据集。
7. 全部代码
import torch
import torch.nn
import d2l.torch
import os
import re
d2l.torch.DATA_HUB['SNLI'] = (
'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
'9fcde07509c7e87ec61c640c1b2753d9041758e4')
data_dir = d2l.torch.download_extract('SNLI')
def read_snli(data_dir, is_train=True):
"""将SNLI数据集解析为前提、假设和标签"""
def extract_text(s):
# 删除我们不会使用的信息
s = re.sub('\\(', '', s)
s = re.sub('\\)', '', s)
# 用一个空格替换两个或多个连续的空格
s = re.sub('\\s{2,}', ' ', s)
return s.strip()
label_set = {
'entailment': 0, 'contradiction': 1, 'neutral': 2}
file_path = os.path.join(data_dir, 'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')
with open(file_path, 'r') as f:
rows = [row.split('\t') for row in f.readlines()[1:]] # rows是一个list of list嵌套列表
premises = [extract_text(row[1]) for row in rows if row[0] in label_set] # premises是一个列表,里面元素是一个每一个样本的前提
hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set] #hypotheses是一个列表,里面元素是一个每一个样本(每一行)的假设
labels = [label_set[row[0]] for row in rows if row[0] in label_set] #labels是一个列表,里面元素是一个每一个样本的label,为0,1,2标签
return premises, hypotheses, labels
train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
print('premise:', x0)
print('hypothesis:', x1)
print('label:', y)
test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
print([[label for label in data[2]].count(i) for i in range(3)])
[[1, 2, 3], [5, 8, 9]] + [[1, 2, 3], [5, 8, 9]]
class SNLIDataset(torch.utils.data.Dataset):
"""用于加载SNLI数据集的自定义数据集"""
def __init__(self, dataset, num_steps, vocab=None):
self.num_steps = num_steps
all_premises_tokens = d2l.torch.tokenize(dataset[0],
token='word') # all_premises_tokens为一个list of list嵌套列表,列表里面每个元素是每个样本的token词元列表
all_hypotheses_tokens = d2l.torch.tokenize(dataset[1],
token='word') # all_hypotheses_tokens为一个list of list嵌套列表,列表里面每个元素是每个样本的token词元列表
if vocab is None:
self.vocab = d2l.torch.Vocab(tokens=all_premises_tokens + all_hypotheses_tokens, min_freq=5,
reserved_tokens=['<pad>'])
else:
self.vocab = vocab
self.all_premises_tokens = self._pad(all_premises_tokens)
self.all_hypotheses_tokens = self._pad(all_hypotheses_tokens)
self.all_labels = torch.tensor(dataset[2])
print(f'read {len(self.all_premises_tokens)} examples')
def _pad(self, lines):
return torch.tensor(
[d2l.torch.truncate_pad(self.vocab[line], self.num_steps, self.vocab['<pad>']) for line in lines])
def __getitem__(self, idx):
return (self.all_premises_tokens[idx], self.all_hypotheses_tokens[idx]), self.all_labels[idx]
def __len__(self):
return len(self.all_premises_tokens)
def load_data_snli(batch_size, num_steps=50):
"""下载SNLI数据集并返回数据迭代器和词表"""
num_workers = d2l.torch.get_dataloader_workers()
data_dir = d2l.torch.download_extract('SNLI')
train_data = read_snli(data_dir, is_train=True)
test_data = read_snli(data_dir, is_train=False)
train_dataset = SNLIDataset(train_data, num_steps, vocab=None) #训练集需要构建自己的vocab
test_dataset = SNLIDataset(test_data, num_steps, vocab=train_dataset.vocab) #注意测试集需要使用train_dataset训练集里面的vocab
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False,
num_workers=num_workers) #测试集里面的shuffle必须是False,也即是每一轮epoch加载一个批量样本时,样本顺序不会改变
return train_iter, test_iter, train_dataset.vocab
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
#打印第一个批量的相关输入数据和label数据
for X, Y in train_iter:
print(X[0].shape) #前提的一个批量序列数据
print(X[1].shape) #假设的一个批量序列数据
print(Y.shape) # label的一个批量label数据
break
8. 相关链接
BERT预训练第一篇:李沐动手学深度学习V2-bert和代码实现
BERT预训练第二篇:李沐动手学深度学习V2-bert预训练数据集和代码实现
BERT预训练第三篇:李沐动手学深度学习V2-BERT预训练和代码实现
BERT微调第一篇:李沐动手学深度学习V2-自然语言推断与数据集SNLI和代码实现
BERT微调第二篇:李沐动手学深度学习V2-BERT微调和代码实现
边栏推荐
- DeepMCP网络详解
- 【leetcode】剑指 Offer II 009. 乘积小于 K 的子数组(滑动窗口、双指针)
- Node version switching tool NVM and npm source manager nrm
- Alexa染料标记RNA核糖核酸|RNA-Alexa 514|RNA-Alexa 488|RNA-Alexa 430
- 染料修饰核酸RNA|[email protected] 610/[email protected] 594/Alexa 56
- Matlab paper illustration drawing template No. 42 - bubble matrix diagram (correlation coefficient matrix diagram)
- Detailed explanation of JWT
- 涨薪5K必学高并发核心编程,限流原理与实战,分布式计数器限流
- Pytorch GPU 训练环境搭建
- net-snmp编译报错:/usr/bin/ld: cannot find crti.o: No such file or directory
猜你喜欢
EasyCVR平台海康摄像头语音对讲功能配置的3个注意事项
async 和 await 原来这么简单
MapReduce介绍及执行过程
Detailed demonstration pytorch framework implementations old photo repair (GPU)
华为设备配置VRRP负载分担
【STM32】标准库-自定义BootLoader
云服务器如何安全使用本地的AD/LDAP?
那些年我写过的语言
- [email protected] 594/[email prote"/>
RNA核糖核酸修饰Alexa 568/[email protected] 594/[email prote
Go语言类型与接口的关系
随机推荐
php根据两点经纬度计算距离
JWT详解
调用EasyCVR云台控制接口时,因网络延迟导致云台操作异常该如何解决?
染料修饰核酸RNA|[email protected] 610/[email protected] 594/Alexa 56
CSDN帐号管理规范
揭秘5名运维如何轻松管理数亿级流量系统
NNLM、RNNLM等语言模型 实现 下一单词预测(next-word prediction)
机器学习中专业术语的个人理解与总结(纯小白)
alicloud3搭建wordpress
ERROR: You don‘t have the SNMP perl module installed.
redis常用命令,HSET,XADD,XREAD,DEL等
Matlab paper illustration drawing template No. 42 - bubble matrix diagram (correlation coefficient matrix diagram)
List类的超详细解析!(超2w+字)
【飞控开发高级教程4】疯壳·开源编队无人机-360 度翻滚
转运RNA(tRNA)甲基化修饰7-甲基胞嘧啶(m7C)|tRNA-m7G
MySQL Basics
Alexa染料标记RNA核糖核酸|RNA-Alexa 514|RNA-Alexa 488|RNA-Alexa 430
PHP according to the longitude and latitude calculated distance two points
async 和 await 原来这么简单
Mapper输出数据中文乱码