Getting Started with Hugging Face
2022-07-28 05:24:00 [SCHLAU_tono]
Setting up the development environment
Besides a basic Python installation, install the datasets and transformers packages in Anaconda, or run the following directly in a Jupyter Notebook:
```
!pip install datasets
!pip install transformers
```
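Several APIs used later in this post (for example load_metric and the evaluation_strategy argument) changed between releases, so it helps to know which versions you actually have. A minimal sketch using only the standard library's importlib.metadata; the helper name installed_version is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for pkg in ("datasets", "transformers", "torch"):
    print(pkg, installed_version(pkg) or "not installed")
```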
Import the classes we need from these two packages:
```python
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModel, Trainer, TrainingArguments,
                          AutoModelForSequenceClassification)
```
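Load the data and the pretrained model. Here we load the Rotten Tomatoes movie-review dataset (rotten_tomatoes) and microsoft/xtremedistil-l6-h256-uncased, a distilled BERT model from Microsoft.
Rotten Tomatoes is a binary classification dataset, so we set num_labels=2 when loading the model, and set hidden_dropout_prob (a probability between 0 and 1) to reduce overfitting.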
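```python
dataset = load_dataset("rotten_tomatoes")
tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
# num_labels=2 for binary sentiment; hidden_dropout_prob must be a probability in [0, 1] (0.2 is an assumed value)
model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/xtremedistil-l6-h256-uncased", num_labels=2, hidden_dropout_prob=0.2)
```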
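hidden_dropout_prob is the dropout probability applied to the model's hidden layers, so it has to lie between 0 and 1; PyTorch's nn.Dropout raises a ValueError for values outside [0, 1], such as 2. As a rough illustration of what the parameter controls, here is a NumPy sketch of inverted dropout, the variant PyTorch applies during training. It is a toy, not the library's implementation, and the function name and the value 0.2 are only examples:

```python
import numpy as np


def inverted_dropout(x, p, rng):
    """Toy inverted dropout: zero each element with probability p, rescale survivors by 1/(1-p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError(f"dropout probability must be in [0, 1), got {p}")
    keep = rng.random(x.shape) >= p          # True with probability 1 - p
    return np.where(keep, x / (1.0 - p), 0.0)


rng = np.random.default_rng(0)
x = np.ones(100_000)
y = inverted_dropout(x, 0.2, rng)
print((y == 0).mean(), y.mean())   # roughly 0.2 and roughly 1.0
```

The rescaling keeps the expected activation unchanged, which is why no extra scaling is needed at evaluation time, when dropout is switched off.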
Note that the loading code given in the official model card is:
```python
tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
model = AutoModel.from_pretrained("microsoft/xtremedistil-l6-h256-uncased")
```
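In my own use, however, AutoModel raised TypeError: forward() got an unexpected keyword argument 'label' when calling trainer.train(). From what I found, the cause is that AutoModel loads only the bare encoder: its forward() has no label parameter and no classification layer, yet computing the loss requires the labels. There are two Stack Overflow posts (Post 1, Post 2) with workarounds, but neither solved my problem. In the end I followed a simple TrainEmotion example on Colab, whose fix is to replace AutoModel with AutoModelForSequenceClassification.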
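Why swapping the class helps: AutoModelForSequenceClassification puts a classification layer on top of the encoder and returns a cross-entropy loss when labels are passed, while AutoModel stops at the hidden states. The NumPy sketch below imitates such a head in simplified form; it uses the first-token ([CLS]) vector directly and skips the pooler and dropout layers the real BERT head has. All names here are hypothetical:

```python
import numpy as np


def classification_head(hidden_states, weight, bias, labels=None):
    """Toy stand-in for a sequence-classification head.

    hidden_states: (batch, seq_len, hidden) encoder output, i.e. what AutoModel gives you
    weight: (hidden, num_labels), bias: (num_labels,)
    """
    cls = hidden_states[:, 0, :]              # first-token ([CLS]) representation
    logits = cls @ weight + bias              # (batch, num_labels)
    if labels is None:
        return logits, None                   # no labels -> no loss
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    loss = -log_probs[np.arange(len(labels)), labels].mean()  # mean cross-entropy
    return logits, loss


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, 256))      # 2 sentences, 5 tokens, hidden size 256
W, b = rng.normal(size=(256, 2)) * 0.02, np.zeros(2)
logits, loss = classification_head(hidden, W, b, labels=np.array([1, 0]))
print(logits.shape, loss)
```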
Preprocessing the data
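Preprocessing involves several steps:
- Tokenize the input sentences
- Build a vocabulary and convert tokens to ids (vectorization)
- Pad and truncate to a fixed length
- Split into train, validation and test sets
- Prepare a metric for evaluating model performance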
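To make the first three steps concrete, here is a toy whitespace tokenizer that builds a vocabulary, maps words to ids, and pads or truncates to a fixed length with an attention mask. It only illustrates the idea: a pretrained tokenizer splits text into WordPiece sub-words and ships with its own vocabulary. The special-token names [PAD] and [UNK] follow BERT's convention; the helper names are hypothetical:

```python
def build_vocab(sentences, specials=("[PAD]", "[UNK]")):
    """Assign an integer id to every distinct lower-cased word, after the special tokens."""
    vocab = {token: i for i, token in enumerate(specials)}
    for sentence in sentences:
        for word in sentence.lower().split():      # naive whitespace tokenization
            vocab.setdefault(word, len(vocab))
    return vocab


def encode(sentence, vocab, max_length):
    """Tokenize, map to ids, then truncate/pad to exactly max_length."""
    ids = [vocab.get(word, vocab["[UNK]"]) for word in sentence.lower().split()]
    ids = ids[:max_length]                         # truncation
    attention_mask = [1] * len(ids)                # 1 = real token
    padding = max_length - len(ids)
    ids += [vocab["[PAD]"]] * padding              # padding
    attention_mask += [0] * padding                # 0 = padding, ignored by attention
    return {"input_ids": ids, "attention_mask": attention_mask}


vocab = build_vocab(["A good movie", "a bad movie"])
print(encode("a great movie", vocab, max_length=5))
# {'input_ids': [2, 1, 4, 0, 0], 'attention_mask': [1, 1, 1, 0, 0]}
```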
Because we use a pretrained tokenizer, which already comes with its own vocabulary, the preprocessing takes only a few lines of code:
```python
def tokenize_fc(batch):
    # With batched=True, batch['text'] is a list of sentences
    return tokenizer(batch['text'], max_length=256, padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_fc, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['text'])  # drop raw strings; keep label and token ids
tokenized_datasets.set_format('torch')                           # return PyTorch tensors

SEED = 123
small_trainset = tokenized_datasets['train'].shuffle(seed=SEED).select(range(2000))
small_validset = tokenized_datasets['validation'].shuffle(seed=SEED).select(range(200))
small_testset = tokenized_datasets['test'].shuffle(seed=SEED)
```
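With batched=True, map calls the tokenizing function on a slice of rows at once (1,000 rows by default), passing a dict of column lists. That is why tokenize_fc can hand the 'text' column, a list of strings, straight to the tokenizer. A plain-Python emulation of that calling convention, simplified and hypothetical (the real map also caches results and supports many more options):

```python
def map_batched(rows, fn, batch_size=1000):
    """Toy emulation of Dataset.map(fn, batched=True) over a list of row dicts."""
    result = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Column-oriented view of the chunk, e.g. {"text": [...], "label": [...]}
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        batch.update(fn(batch))        # columns returned by fn are added (or overwrite existing ones)
        result.extend({key: values[i] for key, values in batch.items()}
                      for i in range(len(chunk)))
    return result


def count_words(batch):
    return {"n_words": [len(text.split()) for text in batch["text"]]}


rows = [{"text": "a gripping film", "label": 1},
        {"text": "dull", "label": 0},
        {"text": "not bad at all", "label": 1}]
print([row["n_words"] for row in map_batched(rows, count_words, batch_size=2)])  # [3, 1, 4]
```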
The processed dataset tokenized_datasets looks like this:
Loading the metric:
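```python
import numpy as np
from datasets import load_metric  # deprecated in later datasets releases; newer code uses evaluate.load("accuracy")

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # predicted class = index of the largest logit
    return metric.compute(predictions=predictions, references=labels)
```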
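Accuracy is simply the fraction of examples whose argmax prediction equals the label. If load_metric is unavailable in your datasets version, a dependency-free NumPy sketch like the one below returns the same {'accuracy': ...} dict shape as metric.compute and could be passed as compute_metrics=compute_accuracy; the function name is hypothetical:

```python
import numpy as np


def compute_accuracy(eval_pred):
    """Accuracy from (logits, labels), shaped like the output of metric.compute()."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)           # index of the highest-scoring class
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}


logits = np.array([[2.0, -1.0], [0.1, 0.3], [1.0, 0.0], [-0.5, 0.5]])
labels = np.array([0, 1, 1, 1])
print(compute_accuracy((logits, labels)))  # {'accuracy': 0.75}
```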
Training the model
First, move the model to the GPU if one is available:
```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
```
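Note that when training a model you normally have to move not only the model but also the input data to the GPU. With Hugging Face's Trainer, however, each batch is moved to the GPU automatically whenever one is available, so there is no need to do it yourself. I verified this myself: even without manually moving the input data to the GPU, training times on GPU and CPU differed dramatically. According to the quote below, the Trainer places the model on the GPU as well, so the explicit model.to(device) call above is optional.
"By default, the Trainer will use the GPU if it is available. It will automatically put the model on the GPU as well as each batch as soon as that's necessary. So just remove all .to() calls that you made manually." (BramVanroy)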
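For comparison, in a hand-written PyTorch loop you would move each batch yourself. The helper below is a rough, hypothetical sketch of that step (the Trainer does something similar internally): it recursively calls .to(device) on anything that has that method, such as tensors, inside dicts, lists and tuples. It is written without importing torch so the recursion can be checked on its own:

```python
def move_to_device(batch, device):
    """Recursively move a (possibly nested) batch to `device`.

    Objects with a .to() method (e.g. torch tensors) are moved; plain dicts,
    lists and tuples are rebuilt with their contents moved; anything else
    is returned unchanged.
    """
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(value, device) for value in batch)
    return batch


# In a manual loop (hypothetical): batch = move_to_device(batch, device)
print(move_to_device({"label": 0, "ids": [1, 2]}, "cpu"))
```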
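Hugging Face officially recommends using Trainer() to standardize how models are trained and evaluated. Here I use Trainer() to train the model; later I may also train with a custom Trainer.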
```python
training_args = TrainingArguments(
    output_dir="./results",           # checkpoint directory (path is an assumption; required in older versions)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=20,
    learning_rate=2e-05,
    weight_decay=0.001,
    evaluation_strategy="epoch")      # renamed eval_strategy in newer transformers releases

trainer = Trainer(
    model=model,                      # the instantiated Transformers model to be trained
    args=training_args,               # training arguments, defined above
    train_dataset=small_trainset,     # training dataset
    eval_dataset=small_validset,      # evaluation dataset
    compute_metrics=compute_metrics,  # accuracy, computed at each evaluation (here once per epoch)
)
trainer.train()
```
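A back-of-the-envelope check of what these arguments imply: with 2,000 training examples and a batch size of 8 on a single device, each epoch is 250 optimizer steps, so 20 epochs are 5,000 steps, and with the default settings the learning rate decays linearly from 2e-05 to 0 over those steps with no warmup. The sketch below reproduces that arithmetic; it assumes one device, no gradient accumulation and the default linear scheduler, and the helper names are hypothetical:

```python
import math


def total_training_steps(num_examples, per_device_batch_size, num_epochs,
                         num_devices=1, grad_accum_steps=1):
    """Optimizer steps, assuming the last partial batch is kept (no drop_last)."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return steps_per_epoch * num_epochs


def linear_lr(step, total_steps, base_lr=2e-05, warmup_steps=0):
    """Linear warmup, then linear decay to 0: the shape of the default 'linear' schedule."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps = total_training_steps(num_examples=2000, per_device_batch_size=8, num_epochs=20)
print(steps)  # 5000
print(linear_lr(0, steps), linear_lr(steps // 2, steps), linear_lr(steps, steps))
```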
The complete code is as follows:
Side notes:
Any problems I run into while training models will be recorded in later posts.