当前位置:网站首页>深度学习区分不同种类的图片
深度学习区分不同种类的图片
2022-07-30 16:11:00 【这个利弗莫尔不太冷】
数据集格式
之前利用resnet从一开始训练,效果比较差,后来利用谷歌的模型进行微调达到了很好的效果
训练代码如下:
from datasets import load_dataset
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric

# Load the image-folder dataset and carve out a 20% hold-out split.
scene = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")
dataset = scene['train']
scene = dataset.train_test_split(test_size=0.2)

# Build label <-> id lookup tables in the string form the HF config expects.
labels = scene["train"].features["label"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

from transformers import AutoFeatureExtractor

# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Augment + tensorize + normalize with the extractor's own statistics.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
def transforms(examples):
    """Map a batch of PIL images to model-ready pixel tensors (mutates in place).

    Pops the raw "image" column and replaces it with "pixel_values" so the
    collator only sees tensors.
    """
    images = examples.pop("image")
    examples["pixel_values"] = [_transforms(image.convert("RGB")) for image in images]
    return examples
# Apply the on-the-fly transform and set up a default batching collator.
from transformers import (AutoModelForImageClassification, DefaultDataCollator,
                          Trainer, TrainingArguments)

scene = scene.with_transform(transforms)
data_collator = DefaultDataCollator()
def compute_metric(eval_pred):
    """Compute accuracy for a Trainer evaluation step.

    Args:
        eval_pred: ``(logits, labels)`` tuple supplied by ``transformers.Trainer``.

    Returns:
        dict with an "accuracy" key, as produced by the HF accuracy metric.
    """
    metric = load_metric("accuracy")
    logits, labels = eval_pred
    # logits has shape (num_samples, num_labels); pick the arg-max class per sample.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
# Fine-tune the pretrained ViT backbone with a fresh classification head
# sized to this dataset's label set.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,  # fixed: was the string 'True' (truthy only by accident)
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the "pixel_values" column injected by with_transform.
    remove_unused_columns=False,
    load_best_model_at_end=False,
    # NOTE(review): save_strategy='no' means save_steps/save_total_limit above
    # are ignored; the only checkpoint is the explicit save_model below.
    save_strategy='no',
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=scene["train"],
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
trainer.train()
trainer.evaluate()
trainer.save_model('/home/huhao/script/model')
测试代码如下:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
# The trained model has been uploaded to the Hub; downloading it is enough to use it.

from datasets import load_dataset
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric

# Local path the evaluation dataset is loaded from.
scene = load_dataset("/home/huhao/script/dataset")
dataset = scene['train']
scene = dataset.train_test_split(test_size=0.2)

# Label <-> id lookup tables, mirroring the setup used at training time.
labels = scene["train"].features["label"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

from transformers import AutoFeatureExtractor

# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Same preprocessing pipeline as training, driven by the extractor's stats.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
def transforms(examples):
    """Map a batch of PIL images to model-ready pixel tensors (mutates in place).

    Pops the raw "image" column and replaces it with "pixel_values" so the
    collator only sees tensors.
    """
    images = examples.pop("image")
    examples["pixel_values"] = [_transforms(image.convert("RGB")) for image in images]
    return examples
# Apply the on-the-fly transform and set up a default batching collator.
from transformers import (AutoModelForImageClassification, DefaultDataCollator,
                          Trainer, TrainingArguments)

scene = scene.with_transform(transforms)
data_collator = DefaultDataCollator()
# Evaluation-only run, but Trainer still requires TrainingArguments.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,  # fixed: was the string 'True' (truthy only by accident)
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the "pixel_values" column injected by with_transform.
    remove_unused_columns=False,
    load_best_model_at_end=False,
    save_strategy='no',
)
# Load the fine-tuned checkpoint from the Hub.
model = AutoModelForImageClassification.from_pretrained(
    "HaoHu/vit-base-patch16-224-in21k-classify-4scence",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
def compute_metric(eval_pred):
    """Compute macro-averaged F1 for a Trainer evaluation step.

    Args:
        eval_pred: ``(logits, labels)`` tuple supplied by ``transformers.Trainer``.

    Returns:
        dict with an "f1" key, as produced by the HF f1 metric.
    """
    metric = load_metric("f1")
    logits, labels = eval_pred
    # logits has shape (num_samples, num_labels); pick the arg-max class per sample.
    predictions = np.argmax(logits, axis=-1)
    # Macro averaging weights every class equally regardless of support.
    return metric.compute(predictions=predictions, references=labels, average='macro')
# Run evaluation on the held-out split (no training performed).
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
# Renamed from "compute_metrics" to avoid shadowing the Trainer kwarg name.
eval_results = trainer.evaluate()
# {'eval_loss': 0.04495017230510712, 'eval_accuracy': 0.9943181818181818, 'eval_runtime': 30.8715, 'eval_samples_per_second': 11.402, 'eval_steps_per_second': 1.425}
print('输出最后的结果eval_f1:')
print(eval_results['eval_f1'])
from doctest import Example  # NOTE(review): unused; likely a stray editor auto-import
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
import os
from transformers import pipeline

# Build an image-classification pipeline around the fine-tuned Hub checkpoint.
extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
vision_classifier = pipeline(task="image-classification", model=model, feature_extractor=extractor)

# Map the model's string labels to the numeric ids used downstream.
result_dict = {'City_road': 0, 'fog': 1, 'rain': 2, 'snow': 3}
val_path = '/home/huhao/script/val/'
all_img = os.listdir(val_path)
for img in all_img:
    img_path = os.path.join(val_path, img)
    score_list = vision_classifier(img_path)
    # Pick the prediction with the highest confidence score
    # (replaces the original manual max-tracking loop).
    best = max(score_list, key=lambda sample: sample['score'])
    print(result_dict[best['label']])
边栏推荐
猜你喜欢
随机推荐
3D激光SLAM:LeGO-LOAM论文解读---实验对比
Leetcode 118. 杨辉三角
【开发者必看】【push kit】推送服务典型问题合集2
三维重建方法汇总
23. Please talk about the difference between IO synchronization, asynchronous, blocking and non-blocking
70 lines of code, a desktop automatic translation artifact
游戏窗口化的逆向分析
Redis 复习计划 - Redis 数据结构和持久化机制
【SOC】经典输出hello world
PCIE入门
Pytorch 训练技巧
绕开驱动层检测的无痕注入
Leetcode 118. Yanghui Triangle
Qt 动态库与静态库
支付系统架构设计详解,精彩!
谷歌工程师『代码补全』工具;『Transformers NLP』随书代码;FastAPI开发模板;PyTorch模型加速工具;前沿论文 | ShowMeAI资讯日报
LeetCode-283-移动零
基于STM32F407使用ADC采集电压实验
3D激光SLAM:LeGO-LOAM论文解读---激光雷达里程计与建图
Why is there no data reported when the application is connected to Huawei Analytics in the application debugging mode?