当前位置:网站首页>Evaluate:huggingface评价指标模块入门详细介绍
Evaluate:huggingface评价指标模块入门详细介绍
2022-08-03 07:30:00 【木尧大兄弟】
一、介绍
evaluate 是huggingface在2022年5月底搞的一个用于评估机器学习模型和数据集的库,需 python 3.7 及以上。包含三种评估类型:
- Metric:用来通过预测值和参考值进行模型评价,是传统意义上的指标,比如 f1、bleu、rouge 等。
- Comparison:同一个测试集对两个(多个)模型评价,比如俩模型结果的 match 程度。
- Measurement:用来评价数据集,比如字数、去重后的词数等。
二、安装
pip安装:
pip install evaluate
源码安装:
git clone https://github.com/huggingface/evaluate.git
cd evaluate
pip install -e .
检查是否装好(会输出预测结果Dict):
python -c "import evaluate; print(evaluate.load('accuracy').compute(references=[1], predictions=[1]))"
三、使用
3.1 load方法
evaluate中的每个指标都是一个单独的Python模块,通过 evaluate.load()(点击查看文档) 函数快速加载,其中load函数的常用参数如下:
- path:必选,str类型。可以是指标名(如
accuracy
或 社区的铁汁们贡献 的muyaostudio/myeval
),如果源码安装还可以是路径名(如./metrics/rouge
或./metrics/rogue/rouge.py
)。我用的后者,因为直接传指标名会联网下载评价脚本,但单位的网不给力。 - config_name:可选,str类型。指标的配置(如 GLUE 指标的每个子集都有一个配置)
- module_type:上文三种评价类型之一,默认
metric
,可选comparison
或measurement
- cache_dir:可选,存储临时预测和引用的路径(默认为
~/.cache/huggingface/evaluate/
)
import evaluate
# module_type 默认为 'metric'
accuracy = evaluate.load("accuracy")
# module_type 显式指定 'metric','comparison','measurement',防止重名。
word_length = evaluate.load("word_length", module_type="measurement")
3.2 列出可用指标
list_evaluation_modules 列出官方(和社区)里有哪些指标,还能看到点赞信息,一共三个参数:
- module_type:要列出的评估模块类型,None是全部,可选
metric
,comparison
,measurement
。 - include_community:是否包含社区模块,默认
True
。 - with_details:返回指标的完整详细Dict信息,而不是str类型的指标名。默认
False
。
print(evaluate.list_evaluation_modules(
module_type="measurement",
include_community=True,
with_details=True)
)
3.3 评估模块都有的属性
所有评估模块都附带一系列有用的属性,这些属性有助于使用存储在 evaluate.EvaluationModuleInfo 对象中的模块,属性如下:
- description:指标介绍
- citation:latex参考文献
- features:输入格式和类型,比如predictions、references等
- inputs_description:输入参数描述文档
- homepage:指标官网
- license:指标许可证
- codebase_urls:指标基于的代码地址
- reference_urls:指标的参考地址
3.4 计算指标值(一次性计算/增量计算)
方式一:一次性计算
函数:EvaluationModule.compute(),传入list/array/tensor等类型的参数references
和predictions
。
>>> import evaluate
>>> metric_name = './evaluate/metrics/accuracy'
>>> accuracy = evaluate.load(metric_name)
>>> accuracy.compute(references=[0,1,0,1], predictions=[1,0,0,1])
{
'accuracy': 0.5} # 输出结果
方式二:单增量的增量计算
函数: EvaluationModule.add(),用于for循环一对一对
地里添加ref和pred,添加完退出循环之后统一计算指标。
>>> for ref, pred in zip([0,0,0,1], [0,0,0,1]):
... accuracy.add(references=ref, predictions=pred)
...
>>> accuracy.compute()
{
'accuracy': 1.0} # 输出结果
方式三:多增量的增量计算
函数: EvaluationModule.add_batch(),用于for循环多对多对
地里添加ref和pred(下面例子是一次添加3对),添加完退出循环之后统一计算指标。
>>> for refs, preds in zip([[0,1],[0,1],[0,1]], [[1,0],[0,1],[0,1]]):
... accuracy.add_batch(references=refs, predictions=preds)
...
>>> accuracy.compute()
{
'accuracy': 0.6666666666666666} # 输出结果
3.5 保存评价结果
函数:evaluate.save(),参数为 path_or_file
, 用于存储文件的路径或文件。如果只提供文件夹,则结果文件将以 result-%Y%m%d-%H%M%S.json
的格式保存;可传 dict 类型的关键字参数 **result
,**params
。
>>> result = accuracy.compute(references=[0,1,0,1], predictions=[1,0,0,1])
>>> hyperparams = {
"model": "bert-base-uncased"}
>>> evaluate.save("./results/", experiment="run 42", **result, **hyperparams)
3.6 自动评估
有点像 Trainer 的封装,可以直接把 evaluate.evaluator() 用做自动评估,且能通过strategy
参数的调整来计算置信区间和标准误差,有助于评估值的稳定性:
from transformers import pipeline
from datasets import load_dataset
from evaluate import evaluator
import evaluate
pipe = pipeline("text-classification", model="lvwerra/distilbert-imdb", device=0)
data = load_dataset("imdb", split="test").shuffle().select(range(1000))
metric = evaluate.load("accuracy")
results = eval.compute(model_or_pipeline=pipe, data=data, metric=metric,
label_mapping={
"NEGATIVE": 0, "POSITIVE": 1},
strategy="bootstrap", n_resamples=200)
print(results)
>>> {
'accuracy':
... {
... 'confidence_interval': (0.906, 0.9406749892841922),
... 'standard_error': 0.00865213251082787,
... 'score': 0.923
... }
... }
边栏推荐
- 加载properties文件,容器总结
- 一文搞懂什么是@Component和@Bean注解以及如何使用
- 剑指offer专项突击版第18天
- @Async注解的坑,小心
- Postman will return to results generated CSV file to the local interface
- consul理解
- Roson的Qt之旅#104 QML Image控件
- Detailed explanation of cause and effect diagram of test case design method
- 酷雷曼上新6大功能,全景营销持续加码
- How to choose a reliable and formal training institution for the exam in September?
猜你喜欢
随机推荐
亿流量大考(1):日增上亿数据,把MySQL直接搞宕机了...
Shell运维开发基础(一)
22-08-02 西安 尚医通(02)Vscode、ES6、nodejs、npm、Bable转码器
华为设备配置BFD状态与接口状态联动
pyspark df secondary sorting
Taro框架-微信小程序-内嵌h5页面
调用feign报错openfeign/feign-core/10.4.0/feign-core-10.4.0.jar
Haisi project summary
学习Glide 常用场景的写法 +
【着色器实现Glow可控局部发光效果_Shader效果第十三篇】
如何使用电子邮件营销在五个步骤中增加产品评论
Roson的Qt之旅#105 QML Image引用大尺寸图片
jolt语法
DSP Trick:向量长度估算
- display image API OpenCV 】 【 imshow () to a depth (data type) at different image processing methods
HCIP笔记整理 2022/7/20
10 分钟彻底理解 Redis 的持久化和主从复制
五、《图解HTTP》报文首部和HTTP缓存
华为设备配置BFD与接口联动(触发与BFD联动的接口物理状态变为Down)
训练正常&异常的GAN损失函数loss变化应该是怎么样的