Text classification with Hugging Face
2022-06-30 00:42:00 【这个利弗莫尔不太冷】
On the Hugging Face Hub you can find more than 27,000 models shared by the AI community, with state-of-the-art performance on tasks such as sentiment analysis, object detection, text generation, and speech recognition. Sentiment analysis works out of the box with the `pipeline` API:
```python
from transformers import pipeline

data = [
    "This is wonderful and easy to put together. My cats love it.",
    "This cat tree is almost perfect. I wanted a tall tree, and this one delivers. It reaches almost to the top of my 8' ceiling",
    "The super large box had disintegrated by the time it arrived to my doorstep & large portions were missing from a 89” solid wood cat tree. I took detailed pictures of the box before & after unpacking & laying out all contents. Several pieces were badly damaged & 3 crucial pieces were missing.<br/>A 45 minute phone call with Amazon resulted in Amazon requesting missing parts from Armarkat who never responded despite my repeated attempts to follow-through. Amazon offered for me to purchase another box, pack it & haul the box (weighs more than I weigh) to a place to be picked up. There’s no opportunity to do that where I live.<br/><br/>It’s a very expensive loss",
]

# With no model specified, pipeline() downloads the default
# sentiment-analysis checkpoint from the Hub.
sentiment_pipeline = pipeline("sentiment-analysis")
print(sentiment_pipeline(data))
```
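The pipeline returns one dict per input, with a `label` and a confidence `score`. A minimal sketch of post-processing that output to find negative reviews — the `results` values below are hand-written placeholders, not actual model output:

```python
# Hypothetical pipeline output for the three reviews above; real scores will differ.
results = [
    {"label": "POSITIVE", "score": 0.9998},
    {"label": "POSITIVE", "score": 0.9995},
    {"label": "NEGATIVE", "score": 0.9987},
]

# Keep the indices of reviews the model flags as negative with high confidence.
negative_reviews = [i for i, r in enumerate(results)
                    if r["label"] == "NEGATIVE" and r["score"] > 0.9]
print(negative_reviews)  # [2]
```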
Training on your own Amazon dataset
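The training script below loads per-sample JSON files with `load_dataset('json', ...)`, so each file needs a `text` field (fed to the tokenizer) and an integer `label` field (for the two-class head). A minimal sketch of writing one such file, assuming one JSON object per line (JSON Lines) — the sample texts and file name are made up:

```python
import json
import os
import tempfile

# Hypothetical samples; the field names must match what the script reads:
# "text" for the tokenizer and "label" (0 or 1) for classification.
samples = [
    {"text": "Sturdy cat tree, my cats love it.", "label": 1},
    {"text": "Arrived broken with parts missing.", "label": 0},
]

train_dir = os.path.join(tempfile.mkdtemp(), "train")
os.makedirs(train_dir)
path = os.path.join(train_dir, "sample_000.json")
with open(path, "w") as f:
    for sample in samples:  # one JSON object per line (JSON Lines)
        f.write(json.dumps(sample) + "\n")

with open(path) as f:
    print(f.read())
```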

```python
import os
import os.path as osp

import numpy as np
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# Candidate pretrained checkpoints:
#   distilbert-base-uncased / bert-base-uncased / gpt2 / distilgpt2


def get_list(path, file_list):
    """Prepend the directory path to every file name."""
    return [osp.join(path, sample) for sample in file_list]


def get_dataset(dataset_path):
    """Collect the JSON sample files from the train/, test/ and val/ subfolders."""
    test_path = osp.join(dataset_path, 'test/')
    train_path = osp.join(dataset_path, 'train/')
    val_path = osp.join(dataset_path, 'val/')
    test_list = get_list(test_path, os.listdir(test_path))
    train_list = get_list(train_path, os.listdir(train_path))
    val_list = get_list(val_path, os.listdir(val_path))
    return test_list, train_list, val_list


def check_the_wrong_sample(labels, predictions):
    """Copy misclassified validation samples to a separate folder for inspection.

    Note: this assumes os.listdir returns the files in the same order the
    validation set was loaded.
    """
    val_folder = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/cattree_product_quality/val'
    end_folder = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/cattree_product_quality/wrong_sample'
    sample_list = os.listdir(val_folder)
    for index, label in enumerate(labels):
        if label != predictions[index]:
            print(index, sample_list[index])
            wrong_sample_path = osp.join(val_folder, sample_list[index])
            end_sample_path = osp.join(end_folder, sample_list[index])
            os.system("cp {} {}".format(wrong_sample_path, end_sample_path))


def compute_metric(eval_pred):
    """Turn logits into class predictions and report accuracy."""
    metric = load_metric("accuracy")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    check_the_wrong_sample(labels, predictions)
    return metric.compute(predictions=predictions, references=labels)


def train(dataset_path):
    test_list, train_list, val_list = get_dataset(dataset_path)
    question_dataset = load_dataset(
        'json', data_files={'train': train_list, 'test': test_list, 'val': val_list})
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    tokenized_dataset = question_dataset.map(preprocess_function, batched=True)
    # Pad dynamically per batch instead of padding every sample to the max length.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2)
    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_steps=50,
        run_name="catree",
        save_strategy='no',
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["val"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metric,
    )
    trainer.train()
    trainer.evaluate()


if __name__ == '__main__':
    # dataset_path = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/catree_personality_2.0'
    dataset_path = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/cattree_product_quality'
    train(dataset_path)
```
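The accuracy reported by `compute_metric` boils down to an argmax over the logits followed by a comparison with the labels. A minimal NumPy sketch of that same calculation, using made-up logits rather than real model output:

```python
import numpy as np

# Made-up logits for 4 validation samples and 2 classes (not real model output).
logits = np.array([[2.1, -1.3],
                   [0.2,  0.9],
                   [-0.5, 1.7],
                   [1.0,  0.3]])
labels = np.array([0, 1, 0, 0])

# The predicted class is the index of the larger logit in each row.
predictions = np.argmax(logits, axis=-1)
accuracy = float((predictions == labels).mean())
print(predictions, accuracy)  # [0 1 1 0] 0.75
```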