[Image Classification] How to train your own classification model with mmclassification
2022-07-29 05:21:00 【呆呆的猫】

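MMclassification is a toolbox for image classification. This post briefly records how to use it to train your own classification model, covering data preparation, model modification, training, and testing.
MMclassification repository: https://github.com/open-mmlab/mmclassification
Installation: https://mmclassification.readthedocs.io/en/latest/install.html
Training: https://mmclassification.readthedocs.io/en/latest/getting_started.html
1. Data preparation
MMclassification supports two dataset formats, ImageNet and cifar. Taking ImageNet as the example, the directory structure looks like this: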
|- imagenet
| |- classmap.txt
| |- train
| | |- cls1
| | |- cls2
| | |- cls3
| | |- ...
| |- train.txt
| |- val
| | |- images
| |- val.txt
Suppose we want to train a binary dog/cat classifier. The data then needs to be organized as follows:
|- dog_cat_dataset
| |- classmap.txt
| |- train
| | |- dog
| | |- cat
| |- train.txt
| |- val
| | |- images
| |- val.txt
Here classmap.txt contains:
dog 0
cat 1
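The layout above lists train.txt but the post does not show what it contains. The following is a minimal sketch of how such a file could be generated from the train directory and classmap.txt. It assumes each annotation line has the form `<path relative to data_prefix> <label index>`, which is my understanding of the ImageNet-style loader and should be checked against the installed mmcls version. The function names `read_classmap`, `build_annotations`, and `write_annotation_file` are hypothetical helpers, not part of mmcls.

```python
import os
import tempfile


def read_classmap(path):
    """Parse lines of the form '<class_name> <index>' into a dict."""
    mapping = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) == 2:
                mapping[parts[0]] = int(parts[1])
    return mapping


def build_annotations(image_root, classmap):
    """Return sorted '<class_dir>/<file> <label>' lines for one split.

    Assumption: every class has its own sub-directory under image_root,
    as in the train/ folder of the layout described above.
    """
    lines = []
    for cls_name, label in classmap.items():
        cls_dir = os.path.join(image_root, cls_name)
        if not os.path.isdir(cls_dir):
            continue  # skip classes that have no folder in this split
        for fname in sorted(os.listdir(cls_dir)):
            lines.append(f"{cls_name}/{fname} {label}")
    return sorted(lines)


def write_annotation_file(dataset_root, split="train"):
    """Write <dataset_root>/<split>.txt from <dataset_root>/<split>/."""
    classmap = read_classmap(os.path.join(dataset_root, "classmap.txt"))
    lines = build_annotations(os.path.join(dataset_root, split), classmap)
    out_path = os.path.join(dataset_root, f"{split}.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    return out_path
```

Calling `write_annotation_file("data/dog_cat_dataset")` would write train.txt next to classmap.txt. The val split stores all images in a single images folder, so its labels have to come from another source and this helper does not cover it.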
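2. Model modification
Assuming we train with resnet18, the changes are concentrated in the config file. The modified config file, resnet18_b32x8_dog_cat_cls.py, follows these notes:
- Change the number of classes from 1000 to 2
- Change the data paths in data
- If the data preprocessing needs to change, that can also be done in the config
- Because this config sits at the top of the inheritance chain, values set here override the parameters loaded from the mmcls base configs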
_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(
    head=dict(
        type='LinearClsHead',
        num_classes=2,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, ),
    ))
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=1,
    train=dict(
        data_prefix='data/dog_cat_dataset/train',
        ann_file='data/dog_cat_dataset/train.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    val=dict(
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(
    interval=1, metric='accuracy', metric_options={'topk': (1, )})
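To illustrate why setting only a few keys in the top-level config is enough, here is a simplified sketch of a recursive dictionary merge. It is an illustration of the override idea only, not mmcv's actual implementation. The function name `merge_config` and the values in `base` are stand-ins chosen for the example.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value in
    `override` replaces the value from `base`.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Stand-in for what a base model config might provide (illustrative values).
base = dict(model=dict(
    backbone=dict(type="ResNet", depth=18),
    head=dict(type="LinearClsHead", num_classes=1000, in_channels=512)))

# The child config only names the keys it wants to change.
override = dict(model=dict(head=dict(num_classes=2)))

merged = merge_config(base, override)
print(merged["model"]["head"])
print(merged["model"]["backbone"])
```

In this sketch the head keeps its inherited `type` and `in_channels` while `num_classes` becomes 2, and the backbone is left untouched.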
3. Model training
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py

4. Visualizing model results
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls
Visualization with gradcam:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py ./models/epoch_99.pth --save-path visual_path/4.jpg
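5. How to compute precision and recall for each class
First run the test step to produce result.pkl, then run the script below (saved here as tools/cal_precision.py):
python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py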
import argparse

import mmcv
from mmcls.datasets import build_dataset
from sklearn.metrics import confusion_matrix


def parse_args():
    parser = argparse.ArgumentParser(
        description='calculate precision and recall for each class')
    parser.add_argument('config', help='test config file path')
    return parser.parse_args()


def main():
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    dataset = build_dataset(cfg.data.test)
    # result.pkl is written by tools/test.py with --out result.pkl
    pred = mmcv.load('./result.pkl')['pred_label']
    # sklearn expects (y_true, y_pred): rows are true labels, columns are predictions
    matrix = confusion_matrix(dataset.get_gt_labels(), pred)
    print('confusion_matrix:', matrix)
    # Class indices follow classmap.txt: dog = 0, cat = 1
    dog_recall = matrix[0, 0] / matrix[0].sum()
    cat_recall = matrix[1, 1] / matrix[1].sum()
    dog_precision = matrix[0, 0] / matrix[:, 0].sum()
    cat_precision = matrix[1, 1] / matrix[:, 1].sum()
    print(' dog_precision:{} \n cat_precision:{} \n dog_recall:{} \n cat_recall:{}'.format(
        dog_precision, cat_precision, dog_recall, cat_recall))


if __name__ == '__main__':
    main()
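The script above hard-codes two classes. As a hedged sketch of how the same calculation could be generalized to any number of classes, the helper below computes precision and recall per class from a confusion matrix. It assumes the matrix follows sklearn's `confusion_matrix(y_true, y_pred)` convention, where rows are true labels and columns are predictions. The function name `per_class_precision_recall` and the sample counts are made up for illustration.

```python
def per_class_precision_recall(matrix):
    """Compute per-class precision and recall.

    Assumes matrix[i][j] counts samples whose true class is i
    and whose predicted class is j.
    """
    n = len(matrix)
    precision, recall = [], []
    for k in range(n):
        true_pos = matrix[k][k]
        predicted_as_k = sum(matrix[i][k] for i in range(n))  # column sum
        actually_k = sum(matrix[k])                           # row sum
        # Guard against classes that were never predicted or never present.
        precision.append(true_pos / predicted_as_k if predicted_as_k else 0.0)
        recall.append(true_pos / actually_k if actually_k else 0.0)
    return precision, recall


# Illustrative counts only: 10 dogs (class 0) and 10 cats (class 1).
sample = [[8, 2],
          [1, 9]]
precision, recall = per_class_precision_recall(sample)
for idx, name in enumerate(["dog", "cat"]):
    print(f"{name}: precision={precision[idx]:.3f} recall={recall[idx]:.3f}")
```

The helper also accepts the NumPy array returned by sklearn, since it only relies on indexing and row iteration.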