Getting Started with Hugging Face
2022-07-28 05:24:00 【SCHLAU_tono】
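Setting up the development environment
Besides a basic Python installation, install the datasets and transformers packages in Anaconda, or run the following directly in a Jupyter Notebook: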
!pip install datasets
!pip install transformers
Import the classes we need from these two packages:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, AutoModelForSequenceClassification
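Load the data and the pretrained model. Here we load the Rotten Tomatoes (rotten_tomatoes) movie-review dataset and Microsoft's distilled BERT model microsoft/xtremedistil-l6-h256-uncased.
Rotten Tomatoes is a binary classification dataset, so we set num_labels=2 when loading the model, and we set hidden_dropout_prob to reduce overfitting.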
dataset = load_dataset("rotten_tomatoes")
tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/xtremedistil-l6-h256-uncased",
    num_labels=2,             # binary classification
    hidden_dropout_prob=0.2)  # dropout probability, must lie in [0, 1] (0.2 is an assumed value)
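To make hidden_dropout_prob concrete, here is a minimal NumPy sketch of inverted dropout, assuming the standard formulation used by PyTorch's nn.Dropout (which BERT-style models apply to their hidden states). The function name dropout is hypothetical. It also shows why the value has to be a probability: something like 2 is rejected outright.

```python
import numpy as np

def dropout(x: np.ndarray, p: float, rng: np.random.Generator, training: bool = True) -> np.ndarray:
    """Inverted dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    if not 0.0 <= p <= 1.0:
        # Mirrors the range check in PyTorch's nn.Dropout
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    if not training or p == 0.0:
        return x  # dropout is switched off at evaluation time
    if p == 1.0:
        return np.zeros_like(x)
    keep_mask = rng.random(x.shape) >= p  # True with probability 1 - p
    return x * keep_mask / (1.0 - p)      # rescale so the expected value is unchanged

rng = np.random.default_rng(0)
hidden = np.ones((2, 8))
out = dropout(hidden, p=0.2, rng=rng)
print(out)  # every entry is either 0.0 or 1.25
```

The 1/(1-p) rescaling during training is what lets the model be used as-is at evaluation time, when no units are dropped.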
Note that the model-loading code given on the official model card is:
tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
model = AutoModel.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
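In my own use, however, loading with AutoModel makes trainer.train() fail with the error TypeError: forward() got an unexpected keyword argument 'label' (raised in main). From what I found, the cause is that AutoModel loads the bare encoder without a classification head, so its forward() has no label parameter, yet computing the loss requires the labels. Stack Overflow has threads with proposed fixes (Post 1, Post 2), but neither solved my problem. In the end I followed a simple TrainEmotion example on Colab: the fix is to replace AutoModel with AutoModelForSequenceClassification.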
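The failure can be reproduced in miniature. In the NumPy sketch below, two hypothetical functions stand in for the two model classes: encoder_forward plays the bare AutoModel (no head, no labels parameter) and classifier_forward plays a model with a sequence-classification head. The Trainer calls the model with the labels included in its inputs so that it can read back a loss, and only the function with a labels parameter can accept them.

```python
import numpy as np

def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy over a batch, using a numerically stable log-softmax."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

def classifier_forward(pooled: np.ndarray, weight: np.ndarray, labels=None) -> dict:
    """Hypothetical head: a linear layer on pooled features, plus a loss only when labels are given."""
    logits = pooled @ weight
    output = {"logits": logits}
    if labels is not None:
        output["loss"] = cross_entropy(logits, labels)
    return output

def encoder_forward(pooled: np.ndarray) -> dict:
    """Hypothetical bare encoder: no head and no labels parameter, hence no loss."""
    return {"pooled": pooled}

rng = np.random.default_rng(0)
features = rng.normal(size=(4, 3))
weight = rng.normal(size=(3, 2))  # 2 output classes, like num_labels=2
labels = np.array([0, 1, 1, 0])

print(classifier_forward(features, weight, labels=labels)["loss"])
try:
    encoder_forward(features, labels=labels)
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'labels'
```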
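Preprocessing the data
Preprocessing involves several steps:
- Tokenize the input sentences
- Build a vocabulary and convert the tokens into id vectors
- Pad and truncate every sequence to a fixed length
- Split the data into train, validation, and test sets
- Prepare a metric for evaluating model performance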
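A pretrained tokenizer hides the first three steps. A toy pure-Python version shows what happens underneath; the whitespace tokenizer, the [PAD]/[UNK] ids and the function names are assumptions for illustration, whereas the real tokenizer uses the fixed subword vocabulary shipped with the model.

```python
def tokenize(text: str) -> list[str]:
    """Toy tokenizer: lowercase and split on whitespace (real tokenizers use subwords)."""
    return text.lower().split()

def build_vocab(texts: list[str]) -> dict[str, int]:
    """Assign each new token the next free id; 0 is reserved for padding, 1 for unknown tokens."""
    vocab = {"[PAD]": 0, "[UNK]": 1}
    for text in texts:
        for token in tokenize(text):
            vocab.setdefault(token, len(vocab))
    return vocab

def encode(text: str, vocab: dict[str, int], max_length: int) -> dict[str, list[int]]:
    """Vectorize, truncate to max_length, then pad up to max_length."""
    ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokenize(text)][:max_length]
    attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))  # 1 = real token, 0 = padding
    ids = ids + [vocab["[PAD]"]] * (max_length - len(ids))
    return {"input_ids": ids, "attention_mask": attention_mask}

train_texts = ["a gripping and funny film", "dull and far too long"]
vocab = build_vocab(train_texts)
print(encode("a funny but long film", vocab, max_length=4))
# {'input_ids': [2, 5, 1, 10], 'attention_mask': [1, 1, 1, 1]}
print(encode("dull film", vocab, max_length=4))
# {'input_ids': [7, 6, 0, 0], 'attention_mask': [1, 1, 0, 0]}
```

The first call is truncated (the fifth word is dropped) and maps the unseen word "but" to [UNK]; the second is padded, and the attention mask tells the model to ignore the padded positions.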
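Because we use the pretrained tokenizer, preprocessing takes only a few lines of code, and the dataset already comes with train, validation, and test splits:
def tokenize_fc(dataset):
    # Called on a batch of examples (batched=True); returns input_ids and attention_mask (plus token_type_ids for BERT-style tokenizers)
    return tokenizer(dataset['text'], max_length=256, padding='max_length', truncation=True)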
tokenized_datasets=dataset.map(tokenize_fc,batched=True)
tokenized_datasets=tokenized_datasets.remove_columns(['text'])
tokenized_datasets.set_format('torch')
SEED=123
small_trainset=tokenized_datasets['train'].shuffle(seed=SEED).select(range(2000))
small_validset=tokenized_datasets['validation'].shuffle(seed=SEED).select(range(200))
small_testset=tokenized_datasets['test'].shuffle(seed=SEED)
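shuffle(seed=SEED).select(range(2000)) takes a random but reproducible subset of the training split. The same idea in pure Python, assuming only that a fixed seed yields a fixed permutation; shuffled_subset is a hypothetical helper built on Python's random module, so the indices it picks differ from those datasets itself would pick.

```python
import random

def shuffled_subset(rows: list, n: int, seed: int) -> list:
    """Shuffle a copy of rows with a fixed seed, then keep the first n rows."""
    rng = random.Random(seed)
    shuffled = rows[:]  # leave the original order untouched
    rng.shuffle(shuffled)
    return shuffled[:n]

rows = list(range(10_000))  # stand-in row indices
first = shuffled_subset(rows, 2000, seed=123)
second = shuffled_subset(rows, 2000, seed=123)
print(first == second, len(first))  # True 2000
```

Fixing the seed is what makes the 2000-example training subset, and therefore the experiment, repeatable across runs.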
After processing, you can print tokenized_datasets to inspect the resulting splits and columns.
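Load the metric (load_metric has been deprecated and dropped from recent datasets releases; its replacement is the evaluate library, installed with pip install evaluate):
import numpy as np
import evaluate
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    # eval_pred unpacks into the model's logits and the true labels
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # predicted class = index of the largest logit
    return metric.compute(predictions=predictions, references=labels)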
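The Trainer calls compute_metrics with an EvalPrediction that unpacks into (logits, labels) as NumPy arrays. A NumPy-only sketch with made-up logits shows what the function returns; compute_accuracy is a hypothetical stand-in that computes accuracy directly instead of through the metric object, which is the only difference.

```python
import numpy as np

def compute_accuracy(eval_pred):
    """Same logic as compute_metrics, with accuracy computed by hand."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # index of the larger logit = predicted class
    return {"accuracy": float((predictions == labels).mean())}

# Made-up logits for 4 examples and 2 classes
logits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.4]])
labels = np.array([0, 1, 0, 0])
print(compute_accuracy((logits, labels)))  # {'accuracy': 0.75}
```

The predictions are [0, 1, 1, 0], so three of the four match the labels.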
Training the model
First, move the model to the GPU if one is available:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
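Note that when training a model you normally have to move not only the model but also the input data to the GPU. Hugging Face's Trainer, however, automatically moves each batch to the GPU when one is available, so there is no need to do this yourself; according to the quote below it places the model on the GPU as well, which makes the model.to(device) call above optional. I verified this myself: even without manually moving the input data to the GPU, training ran far faster on the GPU than on the CPU, confirming that the batches were being processed on the GPU.
By default, the Trainer will use the GPU if it is available. It will automatically put the model on the GPU as well as each batch as soon as that’s necessary. So just remove all .to() calls that you made manually. –BramVanroy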
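Hugging Face officially recommends using Trainer() to standardize how models are trained and evaluated. Here I use Trainer() to train the model; later I may train with a customized Trainer.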
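training_args = TrainingArguments(
    output_dir="./results",        # checkpoint directory (assumed path); required by many transformers versions
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=20,
    learning_rate=2e-05,
    weight_decay=0.001,
    evaluation_strategy="epoch",   # evaluate on eval_dataset after every epoch; newer releases rename this to eval_strategy
)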
trainer = Trainer(
    model=model,                      # the instantiated Transformers model to be trained
    args=training_args,               # training arguments, defined above
    train_dataset=small_trainset,     # training dataset
    eval_dataset=small_validset,      # evaluation dataset
    compute_metrics=compute_metrics,
)
trainer.train()
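With these arguments it helps to know how long training will run. A back-of-the-envelope sketch, assuming a single GPU, no gradient accumulation, and the Trainer's default linear learning-rate decay with no warmup; linear_lr is a hypothetical helper written to mirror that schedule.

```python
import math

train_size = 2000   # len(small_trainset)
batch_size = 8      # per_device_train_batch_size, assuming a single GPU
epochs = 20         # num_train_epochs
base_lr = 2e-05     # learning_rate

steps_per_epoch = math.ceil(train_size / batch_size)  # assumes no gradient accumulation
total_steps = steps_per_epoch * epochs

def linear_lr(step: int) -> float:
    """Learning rate decaying linearly from base_lr to 0 over total_steps, no warmup."""
    return base_lr * max(0.0, (total_steps - step) / total_steps)

print(steps_per_epoch, total_steps)   # 250 5000
print(linear_lr(0), linear_lr(2500))  # 2e-05 1e-05
```

So the run takes 5000 optimizer steps and, with evaluation_strategy="epoch", 20 passes over the 200-example validation subset; halfway through, the learning rate has dropped to half its initial value.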
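Run in order, the snippets above make up the complete code.
Side notes on pitfalls:
I will record the problems I run into while training models in follow-up posts.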