[Image Classification] How to train your own classification model with mmclassification
2022-07-29 05:21:00 【呆呆的猫】

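MMclassification is a toolbox for image classification. This post is a short record of how to use it to train your own classification model, covering data preparation, model modification, training, and testing.
MMclassification repository: https://github.com/open-mmlab/mmclassification
Installation: https://mmclassification.readthedocs.io/en/latest/install.html
Training: https://mmclassification.readthedocs.io/en/latest/getting_started.html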
1. Data preparation
MMclassification supports two dataset formats, ImageNet and CIFAR. Taking ImageNet as the example, the directory structure looks like this:
|- imagenet
| |- classmap.txt
| |- train
| | |- cls1
| | |- cls2
| | |- cls3
| | |- ...
| |- train.txt
| |- val
| | |- images
| |- val.txt
Suppose we want to train a binary dog/cat classifier. The data then needs to be organized as follows:
|- dog_cat_dataset
| |- classmap.txt
| |- train
| | |- dog
| | |- cat
| |- train.txt
| |- val
| | |- images
| |- val.txt
The content of classmap.txt is:
dog 0
cat 1
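The post does not show how train.txt is produced, so here is a minimal sketch of one way to generate it. It assumes the ImageNet-style annotation format in which each line holds an image path relative to the data prefix, a space, and the integer label. The helper names `build_annotation_lines` and `write_annotation_file` and the list of image suffixes are hypothetical and are not part of mmclassification.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your data.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.bmp'}


def build_annotation_lines(train_dir, class_to_idx):
    """Return 'relative/path label' lines for every image in train_dir/<class>/."""
    train_dir = Path(train_dir)
    lines = []
    # Walk the classes in label order so the output is deterministic.
    for class_name, label in sorted(class_to_idx.items(), key=lambda kv: kv[1]):
        class_dir = train_dir / class_name
        if not class_dir.is_dir():
            continue
        for image_path in sorted(class_dir.iterdir()):
            if image_path.suffix.lower() in IMAGE_SUFFIXES:
                rel = image_path.relative_to(train_dir).as_posix()
                lines.append(f'{rel} {label}')
    return lines


def write_annotation_file(train_dir, class_to_idx, out_file):
    """Write the annotation lines to out_file and return how many were written."""
    lines = build_annotation_lines(train_dir, class_to_idx)
    Path(out_file).write_text('\n'.join(lines) + '\n', encoding='utf-8')
    return len(lines)
```

For the layout above, a call such as `write_annotation_file('data/dog_cat_dataset/train', {'dog': 0, 'cat': 1}, 'data/dog_cat_dataset/train.txt')` would produce the training list. The validation images sit in a single `images` folder, so their labels cannot be derived from the directory names and val.txt has to be built from whatever label source you have.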
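2. Model modification
Suppose we train with resnet18. The required changes are mostly confined to the config file. The modified config, resnet18_b32x8_dog_cat_cls.py, makes the following changes and is listed in full after them:
- Change the number of classes from 1000 to 2
- Change the data paths in `data`
- If the data preprocessing needs to change, that can also be done in the config
- The config file has the highest priority, so values set here override the parameters that the model would otherwise read from the mmcls library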
_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(
    head=dict(
        type='LinearClsHead',
        num_classes=2,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, ),
    ))
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=1,
    train=dict(
        data_prefix='data/dog_cat_dataset/train',
        ann_file='data/dog_cat_dataset/train.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    val=dict(
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(interval=1, metric='accuracy', metric_options={
    'topk': (1, )})
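The override behaviour can be illustrated with a small recursive dictionary merge. This is only a sketch of the idea and not mmcv's actual implementation. The function `merge_config` and the contents of `base_model` are hypothetical stand-ins for what the base resnet18 model file might define.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`; override values win."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge key by key instead of replacing.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical stand-in for the model defined in the base config.
base_model = dict(
    backbone=dict(type='ResNet', depth=18),
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=512,
              topk=(1, 5)))

# Only the fields that differ need to appear in the child config.
child_model = dict(head=dict(num_classes=2, topk=(1, )))

merged_model = merge_config(base_model, child_model)
print(merged_model['head'])
```

In the merged result the head keeps `type` and `in_channels` from the base while `num_classes` and `topk` take the child values, which is why the config above only has to spell out the fields that change.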
3. Model training
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py
4. Visualizing model results
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls
Visualizing with Grad-CAM:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --save-path visual_path/4.jpg
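5. How to compute precision and recall for each class
First run the test to obtain result.pkl, then run the script below (saved here as tools/cal_precision.py) with the same config that was used for testing:
python tools/cal_precision.py configs/resnet/resnet18_b32x8_dog_cat_cls.py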
import argparse

import mmcv
from mmcls.datasets import build_dataset
from sklearn.metrics import confusion_matrix


def parse_args():
    parser = argparse.ArgumentParser(
        description='calculate precision and recall for each class')
    parser.add_argument('config', help='test config file path')
    return parser.parse_args()


def main():
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    dataset = build_dataset(cfg.data.test)
    # Predicted labels saved by tools/test.py via --out result.pkl
    pred = mmcv.load('./result.pkl')['pred_label']
    # sklearn expects (y_true, y_pred): rows are ground truth, columns are predictions
    matrix = confusion_matrix(dataset.get_gt_labels(), pred)
    print('confusion_matrix:', matrix)
    # Label indices follow classmap.txt: dog = 0, cat = 1
    dog_recall = matrix[0, 0] / matrix[0].sum()
    cat_recall = matrix[1, 1] / matrix[1].sum()
    dog_precision = matrix[0, 0] / matrix[:, 0].sum()
    cat_precision = matrix[1, 1] / matrix[:, 1].sum()
    print(' cat_precision:{} \n dog_precision:{} \n cat_recall:{} \n dog_recall:{}'.format(
        cat_precision, dog_precision, cat_recall, dog_recall))


if __name__ == '__main__':
    main()
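The script above is written for two classes. As a rough sketch, the same computation can be generalized to any number of classes. It assumes the confusion matrix has ground-truth classes on the rows and predicted classes on the columns, which is the layout `sklearn.metrics.confusion_matrix(y_true, y_pred)` returns. The function name `per_class_precision_recall` is hypothetical.

```python
import numpy as np


def per_class_precision_recall(matrix):
    """Per-class precision and recall from a confusion matrix.

    matrix[i, j] is the number of samples whose true class is i and
    whose predicted class is j.
    """
    matrix = np.asarray(matrix, dtype=float)
    true_positives = np.diag(matrix)
    predicted_totals = matrix.sum(axis=0)  # column sums: times each class was predicted
    actual_totals = matrix.sum(axis=1)     # row sums: samples that truly belong to each class
    # Classes that were never predicted (or never occur) get 0.0 instead of NaN.
    precision = np.divide(true_positives, predicted_totals,
                          out=np.zeros_like(true_positives),
                          where=predicted_totals > 0)
    recall = np.divide(true_positives, actual_totals,
                       out=np.zeros_like(true_positives),
                       where=actual_totals > 0)
    return precision, recall
```

With a made-up matrix `[[8, 2], [1, 9]]` (class 0 = dog, class 1 = cat), dog recall is 8 / 10 and dog precision is 8 / 9, because 10 images are truly dogs and 9 images were predicted as dogs.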