当前位置:网站首页>mmclassification 训练自定义数据
mmclassification 训练自定义数据
2022-07-05 11:49:00 【Coding的叶子】
1 mmclassification 安装
如果环境已安装mmclassification,请跳过该步骤。mmclassification框架安装与调试验证请参考博客:mmclassification安装与调试_Coding的叶子的博客-CSDN博客_mmclassification 安装。
2 数据集准备
mmclassification 的数据集目录主要由标注文件和图片样本组成,其中标注文件存储在meta文件夹中,图片样本存在train、val、test文件夹下,即分别是用于训练、验证和测试的图片样本。图片样本文件按照类别存储在train、val、test文件夹下,同一类别图片存储在同一个子文件夹中,子文件夹的名称为图片所属类别名称。
meta文件夹中主要包含了train.txt、val.txt和test.txt文件。txt文件中的每一行分别存储了图片样本路径和类别id,如下图所示。
如果没有meta标注文件,请参考博客:mmclassification 标注文件生成_Coding的叶子的博客-CSDN博客,生成meta文件夹及其文件夹下的txt文件。
本文示例数据来源于minist手写字体可视化数据集,已按照train、test文件夹进行存储,下载地址为:minist手写数字可视化数据集-深度学习文档类资源-CSDN下载。
将下载的数据集文件夹名称重名为Minist,并且mmclassification工程目录下新建data文件夹,将数据集放到data文件夹下即可。数据集的存储路径不限,需要在下方3.3节中配置相应的路径即可。
3 自定义数据集
3.1 新建MyDataset
在mmclassification工程目录下的mmcls/datasets/新建mydataset.py文件,自定义数据加载类MyDataset,文件名称mydataset和类名称MyDataset可以自行更改。mydataset.py文件中的内容如下:
# -*- coding: utf-8 -*-
"""
乐乐感知学堂公众号
@author: https://blog.csdn.net/suiyingy
"""
import numpy as np
from .builder import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class MyDataset(BaseDataset):
def load_annotations(self):
assert isinstance(self.ann_file, str)
data_infos = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
3.2 将MyDataset注册到mmclassification框架
在mmcls/datasets/__init__.py文件中增加上面定义的类MyDataset,如下图所示:
3.3 新建数据集配置文件
在mmclassification工程目录configs/_base_/datasets/文件夹下,新建mydataset.py文件,主要用于设置数据集类型、数据增强方式、batch size (samples_per_gpu)、数据集路径和标注文件路径、模型保存周期(interval)。文件内容如下所示:
# -*- coding: utf-8 -*-
"""
乐乐感知学堂公众号
@author: https://blog.csdn.net/suiyingy
"""
dataset_type = 'MyDataset'
classes = ['cat', 'bird', 'dog'] # The category names of your dataset
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(256, -1)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
train=dict(
type=dataset_type,
data_prefix='data/Minist/train',
ann_file='data/Minist/meta/train.txt',
classes=classes,
pipeline=train_pipeline
),
val=dict(
type=dataset_type,
data_prefix='data/Minist/test',
ann_file='data/Minist/meta/test.txt',
classes=classes,
pipeline=test_pipeline
),
test=dict(
type=dataset_type,
data_prefix='data/Minist/test',
ann_file='data/Minist/meta/test.txt',
classes=classes,
pipeline=test_pipeline
)
)
evaluation = dict(interval=1, metric='accuracy')
4 修改configs模型配置文件
以configs/resnet/resnet18_8xb16_cifar10.py配置文件为例,mmclassification的配置文件通常包含以下4个部分:
_base_ = [
'../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
../_base_/models/resnet18_cifar.py:定义模型参数,主要包括主干网络、neck、head和类别数量。
../_base_/datasets/cifar10_bs16.py:定义数据集增强方式和路径,也就是3.3节的配置文件,bs16表示batch size为16,即samples_per_gpu=16。
../_base_/schedules/cifar10_bs128.py:定义训练参数,主要包括优化器、学习率、训练总epoch数量。
../_base_/default_runtime.py:定义运行参数,主要包括模型保存周期、日志输出周期等。
configs主要修改的地方为数据配置文件,即把 '../_base_/datasets/cifar10_bs16.py'更换成3.3节中的配置文件'../_base_/datasets/mydataset.py'。即:
5 运行训练程序
mmcls基本的训练命令为:
python tools/train.py 模型配置文件
示例:
python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py
这里已经把resnet18_8xb16_cifar10.py文件按照第4节进行了修改。
6 运行结果
7 【python三维深度学习】python三维点云从基础到深度学习_Coding的叶子的博客-CSDN博客_python 三维点云
更多三维、二维感知算法和金融量化分析算法请关注“乐乐感知学堂”微信公众号,并将持续进行更新。
边栏推荐
- POJ 3176 cow bowling (DP | memory search)
- ibatis的动态sql
- [cloud native | kubernetes] actual battle of ingress case (13)
- Yolov5 target detection neural network -- calculation principle of loss function
- 1. Laravel creation project of PHP
- 阻止瀏覽器後退操作
- pytorch-softmax回归
- Shell script file traversal STR to array string splicing
- 汉诺塔问题思路的证明
- Multi table operation - sub query
猜你喜欢
1个插件搞定网页中的广告
Redirection of redis cluster
MySQL 巨坑:update 更新慎用影响行数做判断!!!
全网最全的新型数据库、多维表格平台盘点 Notion、FlowUs、Airtable、SeaTable、维格表 Vika、飞书多维表格、黑帕云、织信 Informat、语雀
Splunk configuration 163 mailbox alarm
Reading notes of growth hacker
Redis cluster (master-slave) brain crack and solution
【PyTorch预训练模型修改、增删特定层】
Riddle 1
【使用TensorRT通过ONNX部署Pytorch项目】
随机推荐
Open3D 欧式聚类
多表操作-子查询
中非 钻石副石怎么镶嵌,才能既安全又好看?
CDGA|数据治理不得不坚持的六个原则
Codeworks 5 questions per day (1700 average) - day 5
The ninth Operation Committee meeting of dragon lizard community was successfully held
【PyTorch预训练模型修改、增删特定层】
[yolov5.yaml parsing]
查看rancher中debug端口信息,并做IDEA Remote Jvm Debug
Troubleshooting of high memory usage of redis in a production environment
COMSOL -- establishment of geometric model -- establishment of two-dimensional graphics
跨平台(32bit和64bit)的 printf 格式符 %lld 输出64位的解决方式
Implementation of array hash function in PHP
简单解决redis cluster中从节点读取不了数据(error) MOVED
[configuration method of win11 multi-user simultaneous login remote desktop]
【yolov3损失函数】
Redis cluster (master-slave) brain crack and solution
[calculation of loss in yolov3]
[singleshotmultiboxdetector (SSD, single step multi frame target detection)]
COMSOL -- establishment of 3D graphics