当前位置:网站首页>【图像分类】如何使用 mmclassification 训练自己的分类模型
【图像分类】如何使用 mmclassification 训练自己的分类模型
2022-07-29 05:21:00 【呆呆的猫】

MMclassification 是一个分类工具库,这篇文章是简单记录一下如何用该工具库来训练自己的分类模型,包括数据准备,模型修改,模型训练,模型测试等等。
MMclassification链接:https://github.com/open-mmlab/mmclassification
安装:https://mmclassification.readthedocs.io/en/latest/install.html
训练:https://mmclassification.readthedocs.io/en/latest/getting_started.html
一、数据准备
MMclassification 支持 ImageNet 和 cifar 两种数据格式,我们以 ImageNet 为例来看看数据结构:
|- imagenet
| |- classmap.txt
| |- train
| | |- cls1
| | |- cls2
| | |- cls3
| | |- ...
| |- train.txt
| |- val
| | |- images
| |- val.txt
假设我们要训练一个猫狗二分类模型,则需要组织的形式如下:
|- dog_cat_dataset
| |- classmap.txt
| |- train
| | |- dog
| | |- cat
| |- train.txt
| |- val
| | |- images
| |- val.txt
其中,classmap.txt 中的内容如下:
dog 0
cat 1
二、模型修改
假设使用 resnet18 来训练,则我们需要修改的内容主要集中在 config 文件里边,修改后的config文件 resnet18_b32x8_dog_cat_cls.py
如下:
- 修改类别:将 1000 类改为 2 类
- 修改数据路径:data
- 如果数据前处理需要修改的话,也可以在config里边修改
- 因为config是最高级的,所以在这里修改后会覆盖模型从mmcls库中读出来的参数
_base_ = [
'../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(
head=dict(
type='LinearClsHead',
num_classes=2,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, ),
))
data = dict(
samples_per_gpu=32,
workers_per_gpu=1,
train=dict(
data_prefix='data/dog_cat_dataset/train',
ann_file='data/dog_cat_dataset/train.txt',
classes='data/dog_cat_dataset/classmap.txt'),
val=dict(
data_prefix='data/dog_cat_dataset/val',
ann_file='data/dog_cat_dataset/val.txt',
classes='data/dog_cat_dataset/classmap.txt'),
test=dict(
# replace `data/val` with `data/test` for standard test
data_prefix='data/dog_cat_dataset/val',
ann_file='data/dog_cat_dataset/val.txt',
classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(interval=1, metric='accuracy', metric_options={
'topk': (1, )})
三、模型训练
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py
四、模型效果可视化
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls
使用 gradcam 可视化:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py ./models/epoch_99.pth --s
ave-path visual_path/4.jpg
五、如何分别计算每个类别的精确率和召回率
先进行测试,得到 result.pkl
文件,然后运行下面的程序即可:
python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py
import mmcv
import argparse
from mmcls.datasets import build_dataset
from mmcls.core.evaluation import calculate_confusion_matrix
from sklearn.metrics import confusion_matrix
def parse_args():
parser = argparse.ArgumentParser(description='calculate precision and recall for each class')
parser.add_argument('config', help='test config file path')
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
dataset = build_dataset(cfg.data.test)
pred = mmcv.load("./result.pkl")['pred_label']
matrix = confusion_matrix(pred, dataset.get_gt_labels())
print('confusion_matrix:', matrix)
cat_recall = matrix[0,0]/(matrix[0,0]+matrix[1,0])
dog_recall = matrix[1,1]/(matrix[0,1]+matrix[1,1])
cat_precision = matrix[0,0]/sum(matrix[0])
dog_precision = matrix[1,1]/sum(matrix[1])
print(' cat_precision:{} \n dog_precison:{} \n cat_recall:{} \n dog_recall:{}'.format(cat_precision, dog_precison, cat_recall, dog_recall))
if __name__ == '__main__':
main()
边栏推荐
- Research and implementation of flash loan DAPP
- 【DL】搭建卷积神经网络用于回归预测(数据+代码详细教程)
- nacos外置数据库的配置与使用
- 【数据库】数据库课程设计一一疫苗接种数据库
- Ribbon learning notes 1
- Spring, summer, autumn and winter with Miss Zhang (4)
- Markdown语法
- Realize the scheduled backup of MySQL database in Linux environment through simple script (mysqldump command backup)
- 【bug】XLRDError: Excel xlsx file; not supported
- 主流实时流处理计算框架Flink初体验。
猜你喜欢
重庆大道云行作为软件产业代表受邀参加渝中区重点项目签约仪式
mysql插入百万数据(使用函数和存储过程)
与张小姐的春夏秋冬(2)
Spring, summer, autumn and winter with Miss Zhang (2)
这些你一定要知道的进程知识
Semaphore (semaphore) for learning notes of concurrent programming
day02 作业之文件权限
Tear the ORM framework by hand (generic + annotation + reflection)
C# 判断用户是手机访问还是电脑访问
My ideal job, the absolute freedom of coder farmers is the most important - the pursuit of entrepreneurship in the future
随机推荐
day02作业之进程管理
简单聊聊 PendingIntent 与 Intent 的区别
【数据库】数据库课程设计一一疫苗接种数据库
[DL] introduction and understanding of tensor
并发编程学习笔记 之 工具类Semaphore(信号量)
数组的基础使用--遍历循环数组求出数组最大值,最小值以及最大值下标,最小值下标
Research and implementation of flash loan DAPP
Research on the implementation principle of reentrantlock in concurrent programming learning notes
与张小姐的春夏秋冬(5)
浅谈分布式全闪存储自动化测试平台设计
手撕ORM 框架(泛型+注解+反射)
Intelligent security of the fifth space ⼤ real competition problem ----------- PNG diagram ⽚ converter
How to PR an open source composer project
微信小程序源码获取(附工具的下载)
C # judge whether the user accesses by mobile phone or computer
C# 连接 SharepointOnline WebService
【目标检测】Generalized Focal Loss V1
Show profiles of MySQL is used.
Ribbon学习笔记二
与张小姐的春夏秋冬(1)