当前位置:网站首页>深度学习区分不同种类的图片
深度学习区分不同种类的图片
2022-07-30 16:11:00 【这个利弗莫尔不太冷】
数据集格式

之前使用 ResNet 从头开始训练,效果比较差;后来改为对谷歌的预训练模型(ViT:google/vit-base-patch16-224-in21k)进行微调,取得了很好的效果。
训练代码如下:
from datasets import load_dataset
# Other dataset paths (kept for reference):
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric

scene = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")
dataset = scene['train']
# Hold out 20% of the images for evaluation. The fixed seed makes the split
# reproducible, so the same held-out images can be recovered later by loading
# the same data and splitting with the same seed.
scene = dataset.train_test_split(test_size=0.2, seed=42)
labels = scene["train"].features["label"].names
# Label-name <-> id maps for the model config (ids are stored as strings).
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
from transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Older feature extractors expose `size` as an int; newer image processors
# expose a dict ({"shortest_edge": S} or {"height": H, "width": W}), which
# torchvision's crop/resize transforms do not accept.
crop_size = feature_extractor.size
if isinstance(crop_size, dict):
    if "shortest_edge" in crop_size:
        crop_size = crop_size["shortest_edge"]
    else:
        crop_size = (crop_size["height"], crop_size["width"])

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# Training split: random crop + resize as data augmentation.
_transforms = Compose([RandomResizedCrop(crop_size), ToTensor(), normalize])
# Evaluation split: deterministic resize + center crop, so the reported
# metrics do not depend on which random crop each test image happened to get.
_val_transforms = Compose([Resize(crop_size), CenterCrop(crop_size), ToTensor(), normalize])


def transforms(examples):
    """Batch transform for the training split.

    Adds augmented ``pixel_values`` tensors built from ``examples["image"]``
    and drops the raw ``image`` column.
    """
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


def val_transforms(examples):
    """Batch transform for the evaluation split (no random augmentation)."""
    examples["pixel_values"] = [_val_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


# Apply augmentation to the training split only; the held-out split is
# preprocessed deterministically.
scene["train"] = scene["train"].with_transform(transforms)
scene["test"] = scene["test"].with_transform(val_transforms)
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
def compute_metric(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(logits, label_ids)`` pair as handed over by ``Trainer``;
            the class dimension is expected to be the last axis of ``logits``.

    Returns:
        ``{"accuracy": float}``: the fraction of samples whose arg-max logit
        equals the reference label.
    """
    # Named label_ids so it does not shadow the module-level `labels` list.
    logits, label_ids = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Computed directly with NumPy instead of datasets.load_metric("accuracy").
    # load_metric is deprecated (removed in datasets 3.0), and it was re-loaded
    # on every evaluation pass (every eval_steps during training).
    return {"accuracy": float(np.mean(predictions == label_ids))}
# ViT backbone with a classification head sized to this dataset's labels.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
training_args = TrainingArguments(
    output_dir="./results",
    # Must be a bool: the string 'True' only worked because any non-empty
    # string is truthy ('False' would also have enabled overwriting).
    overwrite_output_dir=True,
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    # save_steps and save_total_limit have no effect while save_strategy='no';
    # they are kept so that enabling checkpointing only needs one change.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw "image" column: the dataset transform reads it to build
    # pixel_values, and the model's forward() does not accept it.
    remove_unused_columns=False,
    load_best_model_at_end=False,
    save_strategy='no',
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=scene["train"],
    eval_dataset=scene["test"],
    # Passed so the feature extractor is saved alongside the model weights.
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
trainer.train()
trainer.evaluate()
trainer.save_model('/home/huhao/script/model')
# Test code follows:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# The fine-tuned model has been uploaded to the Hugging Face Hub, so it can be
# downloaded and used directly. The feature extractor and the model are loaded
# further down in this script, right before they are used, so they are not
# downloaded/loaded a second time here.
from datasets import load_dataset
# Other dataset paths (kept for reference):
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric

# Path of the dataset to load.
scene = load_dataset("/home/huhao/script/dataset")
dataset = scene['train']
# The fixed seed makes the evaluation split reproducible from run to run.
# NOTE(review): this split only matches the training run's held-out split if
# that run used the same data and the same seed; otherwise it may contain
# images the model was trained on, inflating the score. Confirm.
scene = dataset.train_test_split(test_size=0.2, seed=42)
labels = scene["train"].features["label"].names
# Label-name <-> id maps for the model config (ids are stored as strings).
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
from transformers import AutoFeatureExtractor
# Feature extractor of the fine-tuned checkpoint
# (base model: google/vit-base-patch16-224-in21k).
feature_extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

# Older feature extractors expose `size` as an int; newer image processors
# expose a dict ({"shortest_edge": S} or {"height": H, "width": W}), which
# torchvision's crop/resize transforms do not accept.
crop_size = feature_extractor.size
if isinstance(crop_size, dict):
    if "shortest_edge" in crop_size:
        crop_size = crop_size["shortest_edge"]
    else:
        crop_size = (crop_size["height"], crop_size["width"])

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# This script only evaluates, so preprocessing is deterministic. It uses a
# fixed resize + center crop instead of RandomResizedCrop, which took a
# different random crop of every test image on each run and made the reported
# metric vary between runs.
_transforms = Compose([Resize(crop_size), CenterCrop(crop_size), ToTensor(), normalize])


def transforms(examples):
    """Batch transform: add ``pixel_values`` tensors and drop the raw ``image`` column."""
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


scene = scene.with_transform(transforms)
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
training_args = TrainingArguments(
    output_dir="./results",
    # Must be a bool: the string 'True' only worked because any non-empty
    # string is truthy.
    overwrite_output_dir=True,
    # This script only calls trainer.evaluate(). The batch size, epoch count,
    # learning rate and save/eval/logging step settings below were carried
    # over from the training script and only matter for trainer.train().
    # Evaluation batch size is governed by per_device_eval_batch_size.
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw "image" column: the dataset transform reads it to build
    # pixel_values.
    remove_unused_columns=False,
    load_best_model_at_end=False,
    save_strategy='no',
)
# NOTE(review): id2label/label2id are rebuilt from this dataset's label names
# and override the checkpoint's own mapping. This assumes the label names come
# out in the same order as during training. Confirm.
model = AutoModelForImageClassification.from_pretrained(
    "HaoHu/vit-base-patch16-224-in21k-classify-4scence",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
def compute_metric(eval_pred):
    """Compute the macro-averaged F1 score for ``Trainer.evaluate``.

    Args:
        eval_pred: ``(logits, label_ids)`` pair as handed over by ``Trainer``;
            the class dimension is expected to be the last axis of ``logits``.

    Returns:
        ``{"f1": float}``. ``Trainer`` reports it as ``eval_f1``.
    """
    # Named label_ids so it does not shadow the module-level `labels` list.
    logits, label_ids = eval_pred
    print(len(logits), len(label_ids))
    predictions = np.argmax(logits, axis=-1)
    print('对测试集进行评估')
    print('labels')
    print(label_ids)
    print('predictions')
    print(predictions)
    # Computed directly with NumPy instead of datasets.load_metric("f1"), which
    # is deprecated (removed in datasets 3.0). Macro F1 is averaged over every
    # class present in either the references or the predictions (the same
    # class set as sklearn's f1_score(average='macro')). For each class,
    # F1 = 2*TP / (2*TP + FP + FN), which is 0 when TP == 0.
    classes = np.union1d(label_ids, predictions)
    if classes.size == 0:
        return {"f1": 0.0}
    per_class_f1 = []
    for cls in classes:
        predicted = predictions == cls
        actual = label_ids == cls
        tp = np.sum(predicted & actual)
        fp = np.sum(predicted & ~actual)
        fn = np.sum(~predicted & actual)
        per_class_f1.append(2 * tp / (2 * tp + fp + fn))
    return {"f1": float(np.mean(per_class_f1))}
# Evaluation-only Trainer: no train_dataset is supplied; only evaluate() runs.
eval_only_kwargs = dict(
    args=training_args,
    model=model,
    eval_dataset=scene["test"],
    data_collator=data_collator,
    compute_metrics=compute_metric,
    tokenizer=feature_extractor,
)
trainer = Trainer(**eval_only_kwargs)
# evaluate() returns a dict of metrics whose keys carry the "eval_" prefix,
# so the F1 returned by compute_metric appears as 'eval_f1'.
compute_metrics = trainer.evaluate()
final_f1 = compute_metrics['eval_f1']
print('输出最后的结果eval_f1:')
print(final_f1)
from doctest import Example
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
import os
# Classify every image in a folder with the fine-tuned model from the Hugging
# Face Hub, printing one numeric class code per image.
extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
from transformers import pipeline

vision_classifier = pipeline(task="image-classification", model=model, feature_extractor=extractor)
# Numeric code printed for each predicted label name.
result_dict = {'City_road': 0, 'fog': 1, 'rain': 2, 'snow': 3}
val_path = '/home/huhao/script/val/'
# os.listdir returns entries in arbitrary order. Sorting makes the printed
# codes come out in a stable, file-name order.
all_img = sorted(os.listdir(val_path))
for img in all_img:
    img_path = os.path.join(val_path, img)
    if not os.path.isfile(img_path):
        # Sub-directories cannot be classified; skip them instead of crashing.
        continue
    score_list = vision_classifier(img_path)
    # Each entry is a {'label': ..., 'score': ...} dict. Take the
    # highest-scoring one; on ties the first one wins, as in a manual scan.
    best = max(score_list, key=lambda sample: sample['score'])
    print(result_dict[best['label']])
边栏推荐
- Windows MySQL 安装配置
- 游戏窗口化的逆向分析
- PyQt5快速开发与实战 9.2 数据库处理
- FME's scheme and operation process for reading and writing cass data
- node.js中怎么连接redis?
- Leetcode 118. Yanghui Triangle
- 新技术要去做新价值
- 武汉星起航跨境电商有前景吗?亚马逊的未来趋势如何发展?
- vivo announced to extend the product warranty period, the system launched a variety of functional services
- RobotStudio实现喷漆、打磨等功能(曲面路径生成与仿真)
猜你喜欢
随机推荐
Recent learning defragmentation (24)
五只小猪的案例(五只小猪 比较体重的大小)
【SOC】Classic output hello world
[NCTF2019] Fake XML cookbook-1|XXE vulnerability|XXE information introduction
3D激光SLAM:LeGO-LOAM论文解读---特征提取部分
php how to query string occurrence position
How to intercept the first few digits of a string in php
23. 请你谈谈关于IO同步、异步、阻塞、非阻塞的区别
Jetpack Compose 到底优秀在哪里?| 开发者说·DTalk
Public Key Retrieval is not allowed报错解决方案
C# List<T> 模板的案例
arcpy tutorial
【开发者必看】【push kit】推送服务典型问题合集2
rscsa笔记八
新技术要去做新价值
Large-scale integrated office management system source code (OA+HR+CRM) source code sharing for free
基于STM32F407使用ADC采集电压实验
Wuhan Star Sets Sail: Overseas warehouse infrastructure has become a major tool for cross-border e-commerce companies to go overseas
Nervegrowold d2l (7) kaggle housing forecast model, numerical stability and the initialization and activation function
武汉星起航:海外仓基础建设成为跨境电商企业的一大出海利器









