当前位置:网站首页>深度学习区分不同种类的图片
深度学习区分不同种类的图片
2022-07-30 16:11:00 【这个利弗莫尔不太冷】
数据集格式

之前利用 ResNet 从头开始训练,效果比较差;后来改为对谷歌的预训练模型(google/vit-base-patch16-224-in21k)进行微调,达到了很好的效果。
训练代码如下:
from datasets import load_dataset
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric
# Load the dataset from the local directory and hold out 20% of its
# "train" split for evaluation.
dataset = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")["train"]
scene = dataset.train_test_split(test_size=0.2)

# Class names as recorded in the "label" feature of the training split.
labels = scene["train"].features["label"].names

# Two-way mapping between class names and their stringified indices.
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}
from transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor


def _crop_size(size):
    """Convert a feature-extractor ``size`` into a RandomResizedCrop argument.

    Depending on the installed transformers version, ``size`` is either a
    plain int/sequence or a dict keyed by ``"shortest_edge"`` or by
    ``"height"``/``"width"``. RandomResizedCrop only accepts the former, so
    dict values are unpacked here.

    Args:
        size: the ``size`` attribute of the feature extractor.

    Returns:
        An int or ``(height, width)`` tuple usable as a crop size.
    """
    if isinstance(size, dict):
        if "shortest_edge" in size:
            return size["shortest_edge"]
        return (size["height"], size["width"])
    return size


# Normalise with the statistics stored on the feature extractor.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# Random crop -> tensor -> normalisation.
_transforms = Compose([RandomResizedCrop(_crop_size(feature_extractor.size)), ToTensor(), normalize])
def transforms(examples):
    """Replace the ``image`` column of a batch with ``pixel_values``.

    Every entry of ``examples["image"]`` is converted to RGB and passed
    through ``_transforms``. The batch is modified in place and returned.
    """
    rgb_images = (picture.convert("RGB") for picture in examples["image"])
    examples["pixel_values"] = list(map(_transforms, rgb_images))
    # The source images are dropped once they have been converted.
    del examples["image"]
    return examples
# Register `transforms` on the dataset splits through `with_transform`.
scene = scene.with_transform(transforms)
from transformers import DefaultDataCollator
# Collator instance that is handed to the Trainer further down.
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
# Loaded once up front; the original reloaded the metric on every
# evaluation pass (evaluation runs every `eval_steps` during training).
_accuracy_metric = load_metric("accuracy")


def compute_metric(eval_pred):
    """Compute accuracy from the arg-max of the logits.

    Args:
        eval_pred: a pair that unpacks into ``(logits, labels)``.

    Returns:
        Whatever the accuracy metric's ``compute`` returns for the arg-max
        predictions measured against ``labels``.
    """
    logits, labels = eval_pred
    # The highest logit along the last axis is the predicted class index.
    predictions = np.argmax(logits, axis=-1)
    return _accuracy_metric.compute(predictions=predictions, references=labels)
# Classification-head settings derived from the dataset's label names.
_head_kwargs = dict(
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
# Start from the pretrained checkpoint and attach the label configuration.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", **_head_kwargs
)
# Hyper-parameters for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./results",
    # Was the string 'True'. Any non-empty string is truthy, so 'False'
    # would have enabled overwriting too; pass a real bool.
    overwrite_output_dir=True,
    per_device_train_batch_size=16,
    # Evaluate every `eval_steps` optimisation steps.
    evaluation_strategy="steps",
    num_train_epochs=4,
    # NOTE(review): with save_strategy='no' below, `save_steps` and
    # `save_total_limit` should have no effect according to the
    # TrainingArguments documentation; kept so the call stays unchanged.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep columns that are not model inputs (the `image` column is still
    # needed by the on-the-fly transform).
    remove_unused_columns=False,
    load_best_model_at_end=False,
    save_strategy='no',
)
# Collect everything the Trainer needs, then build it in a single call.
_trainer_kwargs = dict(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=scene["train"],
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
trainer = Trainer(**_trainer_kwargs)

# Fine-tune, then run one evaluation pass on the held-out split.
trainer.train()
trainer.evaluate()
trainer.save_model('/home/huhao/script/model')
测试代码如下:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# The trained model has already been uploaded online, so it can simply be
# downloaded here and used directly.
_finetuned_checkpoint = "HaoHu/vit-base-patch16-224-in21k-classify-4scence"
extractor = AutoFeatureExtractor.from_pretrained(_finetuned_checkpoint)
model = AutoModelForImageClassification.from_pretrained(_finetuned_checkpoint)
from datasets import load_dataset
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric
# Path the dataset is loaded from; 20% of its "train" split is held out.
dataset = load_dataset("/home/huhao/script/dataset")["train"]
scene = dataset.train_test_split(test_size=0.2)

# Class names as recorded in the "label" feature of the training split.
labels = scene["train"].features["label"].names

# Two-way mapping between class names and their stringified indices.
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}
from transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, Resize, ToTensor


def _eval_size(size):
    """Return a feature-extractor ``size`` as an explicit ``(height, width)``.

    Depending on the installed transformers version, ``size`` is either a
    plain int/sequence or a dict keyed by ``"shortest_edge"`` or by
    ``"height"``/``"width"``. A fixed pair is returned in every case so that
    all evaluation images end up with the same shape and can be batched.

    Args:
        size: the ``size`` attribute of the feature extractor.

    Returns:
        A ``(height, width)`` tuple.
    """
    if isinstance(size, dict):
        if "shortest_edge" in size:
            edge = size["shortest_edge"]
            return (edge, edge)
        return (size["height"], size["width"])
    if isinstance(size, int):
        return (size, size)
    return tuple(size)


# Normalise with the statistics stored on the feature extractor.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
# This script only evaluates, so preprocessing must be deterministic. The
# original used RandomResizedCrop, a training-time augmentation that scores
# a different random sub-region of each image on every run. A plain resize
# to the model's input size gives repeatable metrics.
_transforms = Compose([Resize(_eval_size(feature_extractor.size)), ToTensor(), normalize])
def transforms(examples):
    """Replace the ``image`` column of a batch with ``pixel_values``.

    Every entry of ``examples["image"]`` is converted to RGB and passed
    through ``_transforms``. The batch is modified in place and returned.
    """
    rgb_images = (picture.convert("RGB") for picture in examples["image"])
    examples["pixel_values"] = list(map(_transforms, rgb_images))
    # The source images are dropped once they have been converted.
    del examples["image"]
    return examples
# Register `transforms` on the dataset splits through `with_transform`.
scene = scene.with_transform(transforms)
from transformers import DefaultDataCollator
# Collator instance that is handed to the Trainer further down.
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
# Arguments for the Trainer built further down. This script only calls
# `evaluate()`, but the values are kept the same as in the training script.
training_args = TrainingArguments(
    output_dir="./results",
    # Was the string 'True'. Any non-empty string is truthy, so 'False'
    # would have enabled overwriting too; pass a real bool.
    overwrite_output_dir=True,
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    # NOTE(review): with save_strategy='no' below, `save_steps` and
    # `save_total_limit` should have no effect according to the
    # TrainingArguments documentation; kept so the call stays unchanged.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep columns that are not model inputs (the `image` column is still
    # needed by the on-the-fly transform).
    remove_unused_columns=False,
    load_best_model_at_end=False,
    save_strategy='no',
)
# Label configuration derived from the local dataset's class names.
_head_kwargs = dict(
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
# Load the fine-tuned checkpoint with that label configuration applied.
model = AutoModelForImageClassification.from_pretrained(
    "HaoHu/vit-base-patch16-224-in21k-classify-4scence", **_head_kwargs
)
def compute_metric(eval_pred):
    """Compute macro-averaged F1 from the arg-max of the logits.

    The reference labels and the predicted class indices are printed
    before the metric is computed.

    Args:
        eval_pred: a pair that unpacks into ``(logits, labels)``.

    Returns:
        Whatever the F1 metric's ``compute`` returns with ``average='macro'``.
    """
    f1_metric = load_metric("f1")
    raw_logits, references = eval_pred
    print(len(raw_logits), len(references))
    # The highest logit along the last axis is the predicted class index.
    predicted_ids = np.argmax(raw_logits, axis=-1)
    for heading in ('对测试集进行评估', 'labels'):
        print(heading)
    print(references)
    print('predictions')
    print(predicted_ids)
    return f1_metric.compute(
        predictions=predicted_ids,
        references=references,
        average='macro',
    )
# Evaluation-only Trainer: no training split is supplied.
_evaluator_kwargs = dict(
    model=model,
    args=training_args,
    data_collator=data_collator,
    eval_dataset=scene["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metric,
)
trainer = Trainer(**_evaluator_kwargs)

# Run the evaluation pass and keep the returned metrics dictionary.
compute_metrics = trainer.evaluate()
# Sample output kept from the original notes. It shows an `eval_accuracy`
# key, whereas this script reads `eval_f1`:
# {'eval_loss': 0.04495017230510712, 'eval_accuracy': 0.9943181818181818, 'eval_runtime': 30.8715, 'eval_samples_per_second': 11.402, 'eval_steps_per_second': 1.425}
print('输出最后的结果eval_f1:')
print(compute_metrics['eval_f1'])
from doctest import Example
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
import os
# Load the fine-tuned checkpoint's feature extractor and model.
_pipeline_checkpoint = "HaoHu/vit-base-patch16-224-in21k-classify-4scence"
extractor = AutoFeatureExtractor.from_pretrained(_pipeline_checkpoint)
model = AutoModelForImageClassification.from_pretrained(_pipeline_checkpoint)
from transformers import pipeline
#generator = ImageClassificationPipeline(model=model, tokenizer=extractor)
# Wrap both in an image-classification pipeline.
vision_classifier = pipeline(
    task="image-classification",
    model=model,
    feature_extractor=extractor,
)
# Maps each class name reported by the classifier to the integer id that is
# printed for it.
result_dict = {'City_road':0,'fog':1,'rain':2,'snow':3}
val_path = '/home/huhao/script/val/'
# os.listdir returns entries in arbitrary order. Only the class id is
# printed per image, so sort the listing to make the output order
# reproducible and attributable to files.
all_img = sorted(os.listdir(val_path))
for img in all_img:
    img_path = os.path.join(val_path, img)
    # Each entry of score_list is a mapping with 'score' and 'label' keys.
    score_list = vision_classifier(img_path)
    # Keep the highest-scoring candidate; max() returns the first one on
    # ties, matching the original strict `<` comparison.
    best = max(score_list, key=lambda sample: sample['score'])
    end_label = best['label']
    print(result_dict[end_label])
边栏推荐
- Data Analysis Tools - DDL operations & DML operations in HQL
- 华为ADS获取转化跟踪参数报错:getInstallReferrer IOException: getInstallReferrer not found installreferrer
- Overview of TiDB Tool Functions
- Wuhan Star Sets Sail: Overseas warehouse infrastructure has become a major tool for cross-border e-commerce companies to go overseas
- js 切换数据源的时候该缓存checkbox选中结果并回显?
- L2-007 Family property (use of vector, set, map)
- rhce笔记2
- 【AGC】质量服务1-崩溃服务示例
- 大型综合办公管理系统源码(OA+HR+CRM)源码免费分享
- PMP每日一练 | 考试不迷路-7.30(包含敏捷+多选)
猜你喜欢
随机推荐
如何写一份高可读性的软件工程设计文档
游戏显示分辨率的逆向分析
hcip--ospf综合实验
Public Key Retrieval is not allowed报错解决方案
How to intercept the first few digits of a string in php
Goland opens file saving and automatically formats
Why is there no data reported when the application is connected to Huawei Analytics in the application debugging mode?
vivo宣布延长产品保修期限 系统上线多种功能服务
数组和指针(2)
动态规划 --- 状态压缩DP 详细解释
支付系统架构设计详解,精彩!
【AGC】开放式测试示例
影像信息提取DEM
【HMS core】【Media】【视频编辑服务】 在线素材无法展示,一直Loading状态或是网络异常
【HMS core】【FAQ】A collection of typical questions about Account, IAP, Location Kit and HarmonyOS 1
Moonbeam创始人解读多链新概念Connected Contract
基于STM32F407使用ADC采集电压实验
Nervegrowold d2l (7) kaggle housing forecast model, numerical stability and the initialization and activation function
Image information extraction DEM
在 Chrome 浏览器中安装 JSON 显示插件





![[AGC] Quality Service 1 - Example of Crash Service](/img/d8/e6b365889449745a61597b668dc89b.png)



