当前位置:网站首页>mmclassification 训练自定义数据
mmclassification 训练自定义数据
2022-07-05 11:49:00 【Coding的叶子】
1 mmclassification 安装
如果环境已安装mmclassification,请跳过该步骤。mmclassification框架安装与调试验证请参考博客:mmclassification安装与调试_Coding的叶子的博客-CSDN博客_mmclassification 安装。
2 数据集准备
mmclassification 的数据集目录主要由标注文件和图片样本组成,其中标注文件存储在meta文件夹中,图片样本存在train、val、test文件夹下,即分别是用于训练、验证和测试的图片样本。图片样本文件按照类别存储在train、val、test文件夹下,同一类别图片存储在同一个子文件夹中,子文件夹的名称为图片所属类别名称。
meta文件夹中主要包含了train.txt、val.txt和test.txt文件。txt文件中的每一行分别存储了图片样本路径和类别id,如下图所示。
如果没有meta标注文件,请参考博客:mmclassification 标注文件生成_Coding的叶子的博客-CSDN博客,生成meta文件夹及其文件夹下的txt文件。
本文示例数据来源于minist手写字体可视化数据集,已按照train、test文件夹进行存储,下载地址为:minist手写数字可视化数据集-深度学习文档类资源-CSDN下载。
将下载的数据集文件夹名称重名为Minist,并且mmclassification工程目录下新建data文件夹,将数据集放到data文件夹下即可。数据集的存储路径不限,需要在下方3.3节中配置相应的路径即可。
3 自定义数据集
3.1 新建MyDataset
在mmclassification工程目录下的mmcls/datasets/新建mydataset.py文件,自定义数据加载类MyDataset,文件名称mydataset和类名称MyDataset可以自行更改。mydataset.py文件中的内容如下:
# -*- coding: utf-8 -*-
"""
乐乐感知学堂公众号
@author: https://blog.csdn.net/suiyingy
"""
import numpy as np
from .builder import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class MyDataset(BaseDataset):
def load_annotations(self):
assert isinstance(self.ann_file, str)
data_infos = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
3.2 将MyDataset注册到mmclassification框架
在mmcls/datasets/__init__.py文件中增加上面定义的类MyDataset,如下图所示:
3.3 新建数据集配置文件
在mmclassification工程目录configs/_base_/datasets/文件夹下,新建mydataset.py文件,主要用于设置数据集类型、数据增强方式、batch size (samples_per_gpu)、数据集路径和标注文件路径、模型保存周期(interval)。文件内容如下所示:
# -*- coding: utf-8 -*-
"""
乐乐感知学堂公众号
@author: https://blog.csdn.net/suiyingy
"""
dataset_type = 'MyDataset'
classes = ['cat', 'bird', 'dog'] # The category names of your dataset
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(256, -1)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
train=dict(
type=dataset_type,
data_prefix='data/Minist/train',
ann_file='data/Minist/meta/train.txt',
classes=classes,
pipeline=train_pipeline
),
val=dict(
type=dataset_type,
data_prefix='data/Minist/test',
ann_file='data/Minist/meta/test.txt',
classes=classes,
pipeline=test_pipeline
),
test=dict(
type=dataset_type,
data_prefix='data/Minist/test',
ann_file='data/Minist/meta/test.txt',
classes=classes,
pipeline=test_pipeline
)
)
evaluation = dict(interval=1, metric='accuracy')
4 修改configs模型配置文件
以configs/resnet/resnet18_8xb16_cifar10.py配置文件为例,mmclassification的配置文件通常包含以下4个部分:
_base_ = [
'../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
../_base_/models/resnet18_cifar.py:定义模型参数,主要包括主干网络、neck、head和类别数量。
../_base_/datasets/cifar10_bs16.py:定义数据集增强方式和路径,也就是3.3节的配置文件,bs16表示batch size为16,即samples_per_gpu=16。
../_base_/schedules/cifar10_bs128.py:定义训练参数,主要包括优化器、学习率、训练总epoch数量。
../_base_/default_runtime.py:定义运行参数,主要包括模型保存周期、日志输出周期等。
configs主要修改的地方为数据配置文件,即把 '../_base_/datasets/cifar10_bs16.py'更换成3.3节中的配置文件'../_base_/datasets/mydataset.py'。即:
5 运行训练程序
mmcls基本的训练命令为:
python tools/train.py 模型配置文件
示例:
python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py
这里已经把resnet18_8xb16_cifar10.py文件按照第4节进行了修改。
6 运行结果
7 【python三维深度学习】python三维点云从基础到深度学习_Coding的叶子的博客-CSDN博客_python 三维点云
更多三维、二维感知算法和金融量化分析算法请关注“乐乐感知学堂”微信公众号,并将持续进行更新。
边栏推荐
- Principle of redis cluster mode
- Redirection of redis cluster
- The ninth Operation Committee meeting of dragon lizard community was successfully held
- Use and install RkNN toolkit Lite2 on itop-3568 development board NPU
- [calculation of loss in yolov3]
- Codeforces Round #804 (Div. 2)
- Open3D 欧式聚类
- Redis cluster (master-slave) brain crack and solution
- MySQL giant pit: update updates should be judged with caution by affecting the number of rows!!!
- Empêcher le navigateur de reculer
猜你喜欢
Harbor image warehouse construction
Yolov5 target detection neural network -- calculation principle of loss function
简单解决redis cluster中从节点读取不了数据(error) MOVED
Cdga | six principles that data governance has to adhere to
yolov5目标检测神经网络——损失函数计算原理
【pytorch 修改预训练模型:实测加载预训练模型与模型随机初始化差别不大】
【 YOLOv3中Loss部分计算】
Redis cluster (master-slave) brain crack and solution
XML parsing
COMSOL -- establishment of geometric model -- establishment of two-dimensional graphics
随机推荐
项目总结笔记系列 wsTax KT Session2 代码分析
Idea set the number of open file windows
C#实现WinForm DataGridView控件支持叠加数据绑定
【TFLite, ONNX, CoreML, TensorRT Export】
【云原生 | Kubernetes篇】Ingress案例实战(十三)
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
7 themes and 9 technology masters! Dragon Dragon lecture hall hard core live broadcast preview in July, see you tomorrow
[upsampling method opencv interpolation]
【load dataset】
How can China Africa diamond accessory stones be inlaid to be safe and beautiful?
Principle of persistence mechanism of redis
Network five whip
中非 钻石副石怎么镶嵌,才能既安全又好看?
Pytorch weight decay and dropout
Install esxi 6.0 interactively
Riddle 1
Hash tag usage in redis cluster
[calculation of loss in yolov3]
【主流Nivida显卡深度学习/强化学习/AI算力汇总】
redis的持久化机制原理