Text classification with Hugging Face
2022-06-30 00:42:00 【这个利弗莫尔不太冷】
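On the Hugging Face Hub you can find more than 27,000 models shared by the AI community, with state-of-the-art performance on tasks such as sentiment analysis, object detection, text generation, and speech recognition.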
from transformers import pipeline

# Sample Amazon product reviews to classify
data = [
    "This is wonderful and easy to put together. My cats love it.",
    "This cat tree is almost perfect. I wanted a tall tree, and this one delivers. It reaches almost to the top of my 8\' ceiling",
    "The super large box had disintegrated by the time it arrived to my doorstep & large portions were missing from a 89” solid wood cat tree. I took detailed pictures of the box before & after unpacking & laying out all contents. Several pieces were badly damaged & 3 crucial pieces were missing.<br/>A 45 minute phone call with Amazon resulted in Amazon requesting missing parts from Armarkat who never responded despite my repeated attempts to follow-through. Amazon offered for me to purchase another box, pack it & haul the box (weighs more than I weigh) to a place to be picked up. There’s no opportunity to do that where I live.<br/><br/>It’s a very expensive loss"]

# Build the default sentiment-analysis pipeline (the model is downloaded on first use)
sentiment_pipeline = pipeline("sentiment-analysis")
# Each result is a dict with a "label" and a "score"
print(sentiment_pipeline(data))
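Each entry the pipeline returns is a dictionary with a `label` and a `score`. As a minimal sketch of how those results might be consumed, the helper below collects the reviews predicted negative above a confidence threshold. The function name `negative_reviews`, the `min_score` threshold, and the sample results are illustrative assumptions, not part of the original post. The `NEGATIVE` label is what the default English sentiment model reports; other models may use different label names.

```python
def negative_reviews(texts, results, min_score=0.5, negative_label="NEGATIVE"):
    """Return (text, score) pairs for reviews predicted negative with score >= min_score."""
    if len(texts) != len(results):
        raise ValueError("texts and results must have the same length")
    flagged = []
    for text, result in zip(texts, results):
        if result["label"] == negative_label and result["score"] >= min_score:
            flagged.append((text, result["score"]))
    return flagged


# Hypothetical results in the shape pipeline("sentiment-analysis") returns;
# the scores are made up for illustration, not real model output.
sample_texts = ["My cats love it.", "Several pieces were badly damaged."]
sample_results = [
    {"label": "POSITIVE", "score": 0.99},
    {"label": "NEGATIVE", "score": 0.98},
]
flagged = negative_reviews(sample_texts, sample_results, min_score=0.9)
print(flagged)
```

With the real pipeline, `results` would simply be `sentiment_pipeline(data)`.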
Training on your own Amazon dataset
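The training script in this section reads JSON files from `train/`, `val/`, and `test/` subfolders and tokenizes a `text` field. The post does not show the files themselves, so the following is only a sketch of one layout that would fit: one JSON object per file, with a `text` string and an integer `label` (0 or 1, matching `num_labels=2`). The helper name `write_samples`, the file naming, the meaning of the labels, and the sample reviews are all hypothetical.

```python
import json
import os
import tempfile


def write_samples(root, split, samples):
    """Write one JSON Lines file per sample under <root>/<split>/ and return the paths."""
    split_dir = os.path.join(root, split)
    os.makedirs(split_dir, exist_ok=True)
    paths = []
    for i, (text, label) in enumerate(samples):
        path = os.path.join(split_dir, f"{i:05d}.json")
        with open(path, "w", encoding="utf-8") as f:
            # One JSON object on a single line, so the file is valid JSON Lines
            f.write(json.dumps({"text": text, "label": label}) + "\n")
        paths.append(path)
    return paths


# Hypothetical reviews; label 1 = quality complaint, 0 = no complaint (assumed meaning).
dataset_root = tempfile.mkdtemp()
train_paths = write_samples(dataset_root, "train", [
    ("Several pieces were badly damaged.", 1),
    ("This is wonderful and easy to put together.", 0),
])
val_paths = write_samples(dataset_root, "val", [("The box arrived crushed.", 1)])
test_paths = write_samples(dataset_root, "test", [("My cats love it.", 0)])

# Read one record back to confirm the structure
with open(train_paths[0], encoding="utf-8") as f:
    first_record = json.loads(f.readline())
print(first_record)
```

Keeping one sample per file matches the way the script later maps a prediction index back to a file name in the validation folder.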

import os
import os.path as osp
import shutil

import numpy as np
# Note: load_metric was removed from newer releases of `datasets`;
# there, evaluate.load("accuracy") from the `evaluate` package replaces it.
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# Candidate pretrained models:
#   distilbert-base-uncased
#   bert-base-uncased
#   gpt2
#   distilgpt2


def get_list(path, file_list):
    """Return the full path of every file name in file_list."""
    return [osp.join(path, sample) for sample in file_list]


def get_dataset(dataset_path):
    """Collect the file paths under the test/, train/ and val/ subfolders."""
    test_path = osp.join(dataset_path, 'test/')
    train_path = osp.join(dataset_path, 'train/')
    val_path = osp.join(dataset_path, 'val/')
    # Sort so that the file order is deterministic
    test_list = get_list(test_path, sorted(os.listdir(test_path)))
    train_list = get_list(train_path, sorted(os.listdir(train_path)))
    val_list = get_list(val_path, sorted(os.listdir(val_path)))
    return test_list, train_list, val_list


def check_the_wrong_sample(labels, predictions):
    """Copy every misclassified validation file into a separate folder.

    This relies on each validation file holding exactly one sample and on
    the dataset keeping the same (sorted) file order as the folder listing.
    """
    val_folder = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/cattree_product_quality/val'
    end_folder = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/cattree_product_quality/wrong_sample'
    os.makedirs(end_folder, exist_ok=True)
    sample_list = sorted(os.listdir(val_folder))
    for index, label in enumerate(labels):
        if label != predictions[index]:
            print(index)
            print(sample_list[index])
            wrong_sample_path = osp.join(val_folder, sample_list[index])
            end_sample_path = osp.join(end_folder, sample_list[index])
            shutil.copy(wrong_sample_path, end_sample_path)


def compute_metric(eval_pred):
    """Compute accuracy and save the misclassified validation samples."""
    metric = load_metric("accuracy")
    logits, labels = eval_pred
    print(len(logits), len(labels))
    # The predicted class is the index of the largest logit
    predictions = np.argmax(logits, axis=-1)
    print('predictions')
    print(predictions)
    check_the_wrong_sample(labels, predictions)
    return metric.compute(predictions=predictions, references=labels)


def train(dataset_path):
    test_list, train_list, val_list = get_dataset(dataset_path)
    question_dataset = load_dataset('json', data_files={'train': train_list, 'test': test_list, 'val': val_list})
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def preprocess_function(examples):
        # Truncate reviews that exceed the model's maximum input length
        return tokenizer(examples["text"], truncation=True)

    tokenized_dataset = question_dataset.map(preprocess_function, batched=True)
    # Pad each batch dynamically to its longest sequence
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_steps=50,
        run_name="catree",
        save_strategy='no',  # no checkpoints are written during training
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["val"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metric,
    )
    trainer.train()
    # Print the evaluation metrics instead of discarding them
    print(trainer.evaluate())


if __name__ == '__main__':
    #dataset_path = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/catree_personality_2.0'
    dataset_path = '/cloud/cloud_disk/users/huh/dataset/nlp_dataset/question_dataset/process_data/cattree_product_quality'
    train(dataset_path)
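The metric function in the script turns logits into class predictions with `np.argmax` and compares them with the labels. A small worked sketch of that step, using made-up logits (the array values and the helper name `accuracy_and_errors` are illustrative assumptions), shows what the accuracy number means and how the indices of the misclassified samples can be recovered:

```python
import numpy as np


def accuracy_and_errors(logits, labels):
    """Return (accuracy, indices of misclassified samples) for raw logits."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    # The predicted class is the column with the highest logit in each row
    predictions = np.argmax(logits, axis=-1)
    wrong = np.flatnonzero(predictions != labels)
    accuracy = float(np.mean(predictions == labels))
    return accuracy, wrong.tolist()


# Made-up logits for four samples and two classes (not real model output).
logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.4]]
labels = [0, 1, 0, 0]
accuracy, wrong_indices = accuracy_and_errors(logits, labels)
print(accuracy, wrong_indices)  # 0.75 [2]
```

Here the third sample is predicted as class 1 while its label is 0, so three of four predictions are correct.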