当前位置:网站首页>深度学习区分不同种类的图片
深度学习区分不同种类的图片
2022-07-30 16:11:00 【这个利弗莫尔不太冷】
数据集格式

之前利用 ResNet 从头开始训练,效果比较差;后来利用谷歌的 ViT 预训练模型进行微调,达到了很好的效果。
训练代码如下:
# --- Training script: load the image-folder dataset and build label maps ---
from datasets import load_dataset
# Candidate dataset locations:
#   /home/huhao/TensorFlow2.0_ResNet/dataset
#   /home/huhao/dataset
import numpy as np
from datasets import load_metric

# Loading a directory yields a DatasetDict with a single "train" split.
scene = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")
dataset = scene['train']
# Hold out 20% of the images as a test split.
scene = dataset.train_test_split(test_size=0.2)
labels = scene["train"].features["label"].names

# Bidirectional name<->id maps in the form the transformers config expects
# (ids are strings).
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
from transformers import AutoFeatureExtractor

# Pretrained ViT checkpoint: supplies the preprocessing statistics here and is
# the fine-tuning starting point below.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Normalize with the extractor's mean/std so inputs match ViT pretraining.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])


def transforms(examples):
    """Convert a batch of PIL images into normalized "pixel_values" tensors.

    Replaces the "image" column with "pixel_values"; images are forced to RGB
    so grayscale/RGBA inputs do not break the tensor conversion.
    """
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


# Apply on the fly so transformed tensors are not cached to disk.
scene = scene.with_transform(transforms)
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

from transformers import AutoModelForImageClassification, TrainingArguments, Trainer


def compute_metric(eval_pred):
    """Compute accuracy for an EvalPrediction (logits, labels) pair.

    Called by Trainer during evaluation; predictions are the arg-max class
    over the logits' last axis.
    """
    metric = load_metric("accuracy")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
# Fine-tune the ViT backbone with a fresh classification head sized to our labels.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,  # fixed: was the string 'True', not a bool
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep non-model columns; the on-the-fly transform still needs them.
    remove_unused_columns=False,
    load_best_model_at_end=False,
    # NOTE(review): 'no' disables checkpointing entirely, so save_steps and
    # save_total_limit above have no effect — confirm this is intended.
    save_strategy='no',
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=scene["train"],
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
trainer.train()
trainer.evaluate()
trainer.save_model('/home/huhao/script/model')

# --- Test / evaluation script follows ---
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# The fine-tuned model was uploaded to the Hugging Face Hub; download and use it.
extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")

from datasets import load_dataset
import numpy as np
from datasets import load_metric

# Local dataset path.
scene = load_dataset("/home/huhao/script/dataset")
dataset = scene['train']
# NOTE(review): this re-split is unseeded, so the "test" split here is not
# guaranteed to match the one used during training — confirm before trusting
# the reported metrics.
scene = dataset.train_test_split(test_size=0.2)
labels = scene["train"].features["label"].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])


def transforms(examples):
    """Convert a batch of PIL images into normalized "pixel_values" tensors."""
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


scene = scene.with_transform(transforms)

from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Only evaluation is run below, but Trainer still requires TrainingArguments.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,  # fixed: was the string 'True', not a bool
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    load_best_model_at_end=False,
    save_strategy='no',
)

model = AutoModelForImageClassification.from_pretrained(
    "HaoHu/vit-base-patch16-224-in21k-classify-4scence",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)


def compute_metric(eval_pred):
    """Compute macro-averaged F1 for an EvalPrediction (logits, labels) pair."""
    metric = load_metric("f1")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels, average='macro')
# Evaluation-only Trainer: no train_dataset is supplied.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)

# Returns a metrics dict, e.g.:
# {'eval_loss': 0.0449..., 'eval_accuracy': 0.9943..., 'eval_runtime': 30.87,
#  'eval_samples_per_second': 11.402, 'eval_steps_per_second': 1.425}
# Renamed from `compute_metrics`, which shadowed the metric-function concept.
eval_results = trainer.evaluate()
print('输出最后的结果eval_f1:')
print(eval_results['eval_f1'])
# --- Inference over a validation directory using the pipeline API ---
# (removed dead `from doctest import Example` — an unused IDE auto-import)
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
import os

extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")

from transformers import pipeline

vision_classifier = pipeline(task="image-classification", model=model, feature_extractor=extractor)

# Map label names to the numeric ids printed for each image.
result_dict = {'City_road': 0, 'fog': 1, 'rain': 2, 'snow': 3}
val_path = '/home/huhao/script/val/'

for img in os.listdir(val_path):
    img_path = os.path.join(val_path, img)
    # The pipeline returns one {'score': ..., 'label': ...} dict per class;
    # pick the highest-scoring label.
    score_list = vision_classifier(img_path)
    best = max(score_list, key=lambda sample: sample['score'])
    print(result_dict[best['label']])
边栏推荐
- C# List<T> 模板的案例
- 游戏显示分辨率的逆向分析
- 数组元素逆置
- DTSE Tech Talk丨第2期:1小时深度解读SaaS应用系统设计
- arcpy tutorial
- 围绕用户思维,木鸟与途家如何实现乡村民宿下的用户运营
- Wuhan Star Sets Sail: Overseas warehouse infrastructure has become a major tool for cross-border e-commerce companies to go overseas
- C语言学习之旅 【函数(二)】
- Databases - create databases, tables, functions, etc.
- nodejs environment variable settings
猜你喜欢

【SOC FPGA】Peripheral KEY LED

23. 请你谈谈关于IO同步、异步、阻塞、非阻塞的区别

(1) Cloud computing technology learning - virtualized vSphere learning

【HMS core】【FAQ】Account、IAP、Location Kit and HarmonyOS典型问题合集1

路遇又一个流量风口,民宿长期向好的逻辑选对了吗

经典实例分割模型Mask RCNN原理与测试

【HMS core】【Media】【Video Editing Service】 The online material cannot be displayed, it is always in the loading state or the network is abnormal

如何在分面中添加数学表达式标签?

围绕用户思维,木鸟与途家如何实现乡村民宿下的用户运营

【C语言】指针和数组的深入理解(第二期)
随机推荐
【HMS core】【FAQ】A collection of typical questions about Account, IAP, Location Kit and HarmonyOS 1
php如何去除字符串最后一位字符
Huawei ADS reports an error when obtaining conversion tracking parameters: getInstallReferrer IOException: getInstallReferrer not found installreferrer
为什么数据需要序列化
Qt 动态库与静态库
如何快速拷贝整个网站所有网页
rhce笔记2
JVM学习----垃圾回收
Placement Rules usage documentation
PyQt5快速开发与实战 9.2 数据库处理
Golang分布式应用之Redis怎么使用
Scheduling_Channel_Access_Based_on_Target_Wake_Time_Mechanism_in_802.11ax_WLANs
【SOC FPGA】外设KEY点LED
(一)云计算技术学习--虚拟化vSphere学习
【HMS core】【Media】【Video Editing Service】 The online material cannot be displayed, it is always in the loading state or the network is abnormal
Jetpack Compose 到底优秀在哪里?| 开发者说·DTalk
php how to query string occurrence position
详解最实用的几种dll注入方式
Public Key Retrieval is not allowed error solution
PCIE入门