当前位置:网站首页>深度学习区分不同种类的图片
深度学习区分不同种类的图片
2022-07-30 16:11:00 【这个利弗莫尔不太冷】
数据集格式
之前利用 ResNet 从头开始训练,效果比较差;后来利用谷歌的 ViT 预训练模型进行微调,达到了很好的效果。
训练代码如下:
from datasets import load_dataset
import numpy as np
from datasets import load_metric

# Load the image-folder dataset from the local path, then carve a 20%
# held-out test split out of its "train" portion.
scene = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")
dataset = scene['train']
scene = dataset.train_test_split(test_size=0.2)

# Build label <-> id lookup tables.  Ids are stored as *strings*, which is
# the convention the transformers model config expects.
labels = scene["train"].features["label"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}
from transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
# The feature extractor supplies the normalization stats and target input
# size that the pre-trained ViT checkpoint was trained with.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
# Normalize with the checkpoint's own mean/std so inputs match pre-training.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# Train-time pipeline: random crop/resize to the model input size, then
# PIL -> tensor conversion, then normalization.
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
def transforms(examples):
    """Batch transform for datasets.with_transform.

    Converts each PIL image in the batch to an RGB tensor through the
    module-level `_transforms` pipeline, stores the result under
    'pixel_values', and drops the raw 'image' column.
    """
    tensors = [_transforms(image.convert("RGB")) for image in examples["image"]]
    examples["pixel_values"] = tensors
    examples.pop("image")
    return examples
# Apply the transform lazily: it runs per-batch at access time rather than
# materializing tensors for the whole dataset up front.
scene = scene.with_transform(transforms)
from transformers import DefaultDataCollator
# The default collator simply stacks the per-example tensors into batches.
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
def compute_metric(eval_pred):
    """Compute accuracy for a transformers.Trainer evaluation step.

    Parameters
    ----------
    eval_pred : tuple
        (logits, labels) pair supplied by Trainer; logits has one row per
        example with one column per class.

    Returns
    -------
    dict
        The datasets 'accuracy' metric result, e.g. {'accuracy': 0.99}.
    """
    # NOTE: loading the metric on every call is slightly wasteful but
    # harmless; kept inside the function so the block is self-contained.
    metric = load_metric("accuracy")
    logits, labels = eval_pred
    # The highest logit per row is the predicted class id.  (The debug
    # prints that dumped full logits arrays each eval were removed.)
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
# Fine-tune the ViT backbone with a freshly initialized classification head
# sized to our label set; id/label maps end up in the saved model config.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
# Fine-tuning hyper-parameters.
# NOTE(review): save_strategy='no' disables checkpointing entirely, so
# save_steps and save_total_limit below have no effect; the model is only
# persisted by the explicit trainer.save_model() call at the end.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,  # fixed: was the string 'True'; the field is a bool
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,  # keep 'pixel_values' added by with_transform
    load_best_model_at_end=False,
    save_strategy='no',
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=scene["train"],
    eval_dataset=scene["test"],
    # Passing the feature extractor as `tokenizer` makes Trainer save the
    # preprocessor config alongside the model weights.
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
# Fine-tune, evaluate on the held-out split, and persist the final model.
trainer.train()
trainer.evaluate()
trainer.save_model('/home/huhao/script/model')
测试代码如下:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# The fine-tuned model was uploaded to the Hugging Face Hub, so it can be
# downloaded and used directly here.
extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")

from datasets import load_dataset
import numpy as np
from datasets import load_metric

# Local path of the evaluation dataset; split off the same 20% test set.
scene = load_dataset("/home/huhao/script/dataset")
dataset = scene['train']
scene = dataset.train_test_split(test_size=0.2)

# Label <-> string-id lookup tables, matching the training script.
labels = scene["train"].features["label"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}
from transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
# Use the preprocessor that was saved together with the fine-tuned checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# NOTE(review): RandomResizedCrop is a training-time augmentation; using it
# here makes evaluation slightly non-deterministic.  A deterministic
# Resize/CenterCrop would be the usual choice — confirm intent.
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
def transforms(examples):
    """Batch transform for datasets.with_transform.

    Runs every PIL image through the module-level `_transforms` pipeline
    (after forcing RGB), writes the tensors to 'pixel_values', and removes
    the original 'image' column.
    """
    tensors = [_transforms(image.convert("RGB")) for image in examples["image"]]
    examples["pixel_values"] = tensors
    examples.pop("image")
    return examples
# Apply the transform lazily, per batch, at access time.
scene = scene.with_transform(transforms)
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
# Arguments for the evaluation-only Trainer below.
# NOTE(review): the train-specific fields (batch size, epochs, lr,
# save/eval steps) are unused since only trainer.evaluate() is called.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,  # fixed: was the string 'True'; the field is a bool
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,  # keep 'pixel_values' added by with_transform
    load_best_model_at_end=False,
    save_strategy='no',
)
# Reload the fine-tuned checkpoint; num_labels and the id/label maps are
# rebuilt from the local dataset and must match the published model.
model = AutoModelForImageClassification.from_pretrained(
    "HaoHu/vit-base-patch16-224-in21k-classify-4scence",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
def compute_metric(eval_pred):
    """Compute macro-averaged F1 over the evaluation split.

    Parameters
    ----------
    eval_pred : tuple
        (logits, labels) pair supplied by transformers.Trainer.

    Returns
    -------
    dict
        The datasets 'f1' metric result, e.g. {'f1': 0.99}.
    """
    metric = load_metric("f1")
    logits, labels = eval_pred
    # Highest logit per row is the predicted class id.  (Debug prints of
    # raw labels/predictions were removed.)
    predictions = np.argmax(logits, axis=-1)
    # macro average: unweighted mean of per-class F1 scores.
    return metric.compute(predictions=predictions, references=labels, average='macro')
# Evaluation-only Trainer: note that no train_dataset is supplied.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
# evaluate() returns a dict of metrics, including the 'eval_f1' produced
# by compute_metric above.
compute_metrics = trainer.evaluate()
# Sample output from the original run:
# {'eval_loss': 0.04495017230510712, 'eval_accuracy': 0.9943181818181818, 'eval_runtime': 30.8715, 'eval_samples_per_second': 11.402, 'eval_steps_per_second': 1.425}
print('输出最后的结果eval_f1:')
print(compute_metrics['eval_f1'])
from doctest import Example  # NOTE(review): unused leftover import; safe to delete
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
import os

# Load the fine-tuned checkpoint from the Hugging Face Hub.
extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
from transformers import pipeline
#generator = ImageClassificationPipeline(model=model, tokenizer=extractor)
vision_classifier = pipeline(task="image-classification", model=model, feature_extractor=extractor)

# Map the model's string labels to the integer ids expected downstream.
result_dict = {'City_road': 0, 'fog': 1, 'rain': 2, 'snow': 3}
val_path = '/home/huhao/script/val/'
all_img = os.listdir(val_path)
for img in all_img:
    img_path = os.path.join(val_path, img)
    score_list = vision_classifier(img_path)
    # Pick the highest-confidence prediction.  This replaces the original
    # manual running-max loop seeded with end_label='', which would have
    # raised KeyError on an empty score list; max() fails fast instead.
    best = max(score_list, key=lambda entry: entry['score'])
    print(result_dict[best['label']])
边栏推荐
- How to use Redis for distributed applications in Golang
- 武汉星起航:海外仓基础建设成为跨境电商企业的一大出海利器
- Promise Notes (1)
- STM32F407定时器输入捕获
- [TypeScript]简介、开发环境搭建、基本类型
- 涨姿势了!原来这才是多线程正确实现方式
- rhce笔记1
- How to remove last character from string in php
- 谷歌工程师『代码补全』工具;『Transformers NLP』随书代码;FastAPI开发模板;PyTorch模型加速工具;前沿论文 | ShowMeAI资讯日报
- 动态规划 --- 状态压缩DP 详细解释
猜你喜欢
随机推荐
武汉星起航跨境电商有前景吗?亚马逊的未来趋势如何发展?
[HMS core] [FAQ] A collection of typical questions about push kit, analysis services, and video editing services 3
字符串加千分位符与递归数组求和
The service already exists! Solution
Image information extraction DEM
Pytorch 训练技巧
路遇又一个流量风口,民宿长期向好的逻辑选对了吗
Golang分布式应用之Redis怎么使用
【AGC】质量服务1-崩溃服务示例
武汉星起航:海外仓基础建设成为跨境电商企业的一大出海利器
SocialFi 何以成就 Web3 去中心化社交未来
【HMS core】【FAQ】push kit、WisePlay DRM、Location Kit、Health Kit、3D Modeling Kit、SignPal Kit典型问题合集4
Public Key Retrieval is not allowed error solution
为什么数据需要序列化
Placement Rules usage documentation
[TypeScript] Introduction, Development Environment Construction, Basic Types
[HMS core] [FAQ] Collection of typical problems of push kit, AR Engine, advertising service, scanning service 2
围绕用户思维,木鸟与途家如何实现乡村民宿下的用户运营
Array element inverse
Promise笔记(一)