当前位置:网站首页>Bert的使用方法
Bert的使用方法
2022-07-28 05:22:00 【Alan and fish】
1.导入Bert库
我在写代码的时候看到很多代码有的使用以下这种方式导入
from pytorch_pretrained_bert import BertTokenizer,BertModel
有的使用transformer的方式导入的,所有我就有的时候有点郁闷究竟使用那种方式导入.
from transformers import BertTokenizer,BertConfig,BertModel
根据这个博主的博文,https://blog.csdn.net/qq_43391414/article/details/118252012
知道transerformers包包又名pytorch-transformers或者pytorch-pretrained-bert”
但是根据一些了解,实际上transformers库是最新的版本(以前称为pytorch-transformers和pytorch-pretrained-bert)
所以它在前两者的基础上对一些函数与方法进行了改进,包括一些函数可能只有在transformers库里才能使用,所以使用transformers库比较方便。
它提供了一系列的STOA(最先进)模型的实现,包括(Bert、XLNet、RoBERTa等)。
所以导入bert模型的时候推荐使用以下方式比较好
from transformers import BertTokenizer,BertModel
2.bert模型中究竟要输入什么样的格式的数据
我在写代码的时候遇到的情况是,有的人的代码直接把句子分词,处理成id的格式输入到bert模型中,有的人要把数据处理成input_ids,mask_attention,token…各种各样的格式,五花八门的,作为一个小白进入干进入深度学习领域,感觉不是很友好.数据预处理这块,一个人的代码就有一千种写法,真的不知道相信谁.
所以我就看到了bert模型的源码,看到了他的foward函数:
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
这些就是我们要使用bert模型的输入数据格式:
1.input_ids:
- 数据类型:torch.Tensor
- 表示的内容:经过 tokenizer 分词后的 subword 对应的下标列表
2.attention_mask
- 数据类型:torch.Tensor
- 表示的内容:在 self-attention 过程中,这一块 mask 用于标记 subword 所处句子和 padding 的区别,将有用的信息用1表示,将 padding 部分填充为 0;
3.token_type_ids:
- 数据类型:torch.Tensor
- 表示的内容:标记 subword 当前所处句子(第一句/第二句/ padding),如果只有一句话使用0填充
4.position_ids:
- 数据类型:torch.Tensor
- 表示的内容:标记当前词所在句子的位置下标;用1表示padding出来的值
5.head_mask:
- 数据类型:torch.Tensor
- 表示的内容:用于将某些层的某些注意力计算无效化;
6.inputs_embeds:
- 数据类型:torch.Tensor
- 表示的内容:如果提供了,那就不需要input_ids,跨过 embedding lookup 过程直接作为 Embedding 进入 Encoder 计算;
7.encoder_hidden_states:
- 数据类型:torch.Tensor
- 表示的内容:这一部分在 BertModel 配置为 decoder 时起作用,将执行 cross-attention 而不是 self-attention;
8.encoder_attention_mask:
- 数据类型:torch.Tensor
- 表示的内容:这个参数貌似是把预先计算好的 K-V 乘积传入,以降低 cross-attention 的开销(因为原本这部分是重复计算);
9.past_key_values:
- 数据类型:List[torch.FloatTensor]
- 表示的内容:这个参数貌似是把预先计算好的 K-V 乘积传入,以降低 cross-attention 的开销(因为原本这部分是重复计算);
10.use_cache:
- 数据类型:bool
- 表示的内容:将保存上一个参数并传回,加速 decoding;
11.output_attentions:
- 数据类型:bool
- 表示的内容:是否返回中间每层的 attention 输出;
12.output_hidden_states:
- 数据类型:bool
- 表示的内容:是否返回中间每层的输出;
13.return_dict:
- 数据类型:bool
- 表示的内容:是否按键值对的形式(ModelOutput 类,也可以当作 tuple 用)返回输出,默认为真。
====================================================
在这里可以使用Dataset先将数据处理好,放到这里面,然后将Dataset放到DataLoader中,设置好批次,一个批次一个批次的加载数据.
详细的写法可以看我写的笔记,Dataset和DataLoader的使用方法.
3.Bert模型的输出
输入Bert模型中只要输入,input_ids,attention_mask,token_type_ids就可以了,下面只是我的部分代码.
out = self.bert(x['input_ids'], x['attention_mask'], x['token_type_ids'])
输入之后的输出out包括以下四个数据
- last_hidden_state:
torch.FloatTensor类型的,最后一个隐藏层的序列的输出。大小是(batch_size, sequence_length, hidden_size) sequence_length是我们截取的句子的长度,hidden_size是768. - pooler_output:
torch.FloatTensor类型的,[CLS]的这个token的输出,输出的大小是(batch_size, hidden_size) - hidden_states :
tuple(torch.FloatTensor)这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size) - attentions:
这也是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值
边栏推荐
猜你喜欢

Hit your face ins? Mars digital collection platform explores digital collection light social networking

XShell突然间无法连接虚拟机

Distributed cluster architecture scenario optimization solution: session sharing problem
Sqoop安装及使用

小程序开发流程详细是什么呢?

4个角度教你选小程序开发工具?

服务可靠性保障-watchdog

Structured streaming in spark

微信小程序手机号正则校验规则

数藏如何实现WEB3.0社交
随机推荐
Idempotent component
微服务架构认知、服务治理-Eureka
文旅头部结合数字藏品效应显著,但如何支撑用户持续购买力
CertPathValidatorException:validity check failed
微信上的小程序店铺怎么做?
Installation and use of sqoop
小程序开发系统有哪些优点?为什么要选择它?
数字藏品以虚强实,赋能实体经济发展
CertPathValidatorException:validity check failed
微信小程序开发语言一般有哪些?
小程序商城制作一个需要多少钱?一般包括哪些费用?
【二】redis基础命令与使用场景
There is a problem with MySQL paging
使用pycharm创建虚拟环境
MarsNFT :个人如何发行数字藏品?
Sorting and paging, multi table query after class exercise
3:Mysql 主从复制搭建
2:为什么要读写分离
小程序开发流程详细是什么呢?
raise RuntimeError(‘DataLoader worker (pid(s) {}) exited unexpectedly‘.format(pids_str))RuntimeErro