[Image classification] How to use MMClassification to train your classification model
2022-07-29 06:03:00 【Dull cat】

MMClassification is a toolbox for image classification. This article briefly records how to use it to train your own classification model, covering data preparation, model modification, training, and testing.
MMClassification repository: https://github.com/open-mmlab/mmclassification
Installation: https://mmclassification.readthedocs.io/en/latest/install.html
Training: https://mmclassification.readthedocs.io/en/latest/getting_started.html
1. Data preparation
MMClassification supports two dataset formats, ImageNet and CIFAR. We use the ImageNet directory structure as the example:
|- imagenet
| |- classmap.txt
| |- train
| | |- cls1
| | |- cls2
| | |- cls3
| | |- ...
| |- train.txt
| |- val
| | |- images
| |- val.txt
Suppose we want to train a binary cat/dog classifier. The dataset is then organized as follows:
|- dog_cat_dataset
| |- classmap.txt
| |- train
| | |- dog
| | |- cat
| |- train.txt
| |- val
| | |- images
| |- val.txt
Here, classmap.txt contains:
dog 0
cat 1
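The layout above references train.txt and val.txt, but the article does not show their contents. In the ImageNet-style format, each annotation line is usually an image path relative to data_prefix followed by an integer label. The sketch below builds such lines for the train folder under that assumption; the helper name build_annotation_lines and the list of image extensions are hypothetical, not part of MMClassification. Because val/images is a flat folder, val.txt cannot be derived from folder names and has to come from your own labels.

```python
import os


def build_annotation_lines(train_dir, class_to_idx, exts=(".jpg", ".jpeg", ".png")):
    """Return 'relative/path label' lines for every image under train_dir/<class>/."""
    lines = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip classes that have no folder yet
        for file_name in sorted(os.listdir(class_dir)):
            if file_name.lower().endswith(exts):
                # Paths are relative to train_dir, which the config passes as data_prefix.
                lines.append(f"{class_name}/{file_name} {class_to_idx[class_name]}")
    return lines


if __name__ == "__main__":
    # Label indices follow classmap.txt from the article: dog -> 0, cat -> 1.
    root = "data/dog_cat_dataset"
    lines = build_annotation_lines(os.path.join(root, "train"), {"dog": 0, "cat": 1})
    if lines:
        with open(os.path.join(root, "train.txt"), "w") as f:
            f.write("\n".join(lines) + "\n")
```

Keeping the class-to-index mapping in one dictionary makes it easier to keep train.txt consistent with classmap.txt.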
2. Model modification
Suppose we train with ResNet-18. Most of the required changes are then in the config file. The modified config, resnet18_b32x8_dog_cat_cls.py, makes the following changes:
- Number of classes: change from 1000 to 2.
- Data paths: set in the data field.
- Data preprocessing: if it needs to change, it can also be modified in the config.
- Because the config file has the highest priority, the values set in it override the parameters read from the mmcls library.
_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(
    head=dict(
        type='LinearClsHead',
        num_classes=2,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, ),
    ))
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=1,
    train=dict(
        data_prefix='data/dog_cat_dataset/train',
        ann_file='data/dog_cat_dataset/train.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    val=dict(
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(
    interval=1, metric='accuracy', metric_options={'topk': (1, )})
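The override behaviour described above, where values in the derived config win over those inherited through _base_, can be pictured as a recursive dictionary merge. The sketch below is only an illustration of that idea in plain Python: merge_cfg and the base_model dictionary are hypothetical stand-ins, not the loader that mmcv actually uses, and the real one has additional features.

```python
import copy


def merge_cfg(base, override):
    """Return a new dict where values in `override` take precedence over `base`.

    Nested dicts are merged key by key; any other value is replaced outright.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical stand-in for the inherited ResNet-18 base model config.
base_model = dict(
    type="ImageClassifier",
    backbone=dict(type="ResNet", depth=18),
    head=dict(type="LinearClsHead", num_classes=1000, in_channels=512),
)
# Only the fields that differ need to appear in the derived config.
override = dict(head=dict(num_classes=2))

model = merge_cfg(base_model, override)
print(model["head"])
```

In this picture, the derived config only has to state what differs from the base (here num_classes), and everything else is inherited unchanged.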
3. Model training
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py

4. Visualizing model results
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls
Visualize with Grad-CAM:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py ./models/epoch_99.pth --save-path visual_path/4.jpg
5. Computing precision and recall for each class
Run the test first to obtain the result.pkl file, then run the script below (saved here as tools/cal_precision.py):
python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py
# tools/cal_precision.py
import argparse

import mmcv
from mmcls.datasets import build_dataset
from sklearn.metrics import confusion_matrix


def parse_args():
    parser = argparse.ArgumentParser(
        description='calculate precision and recall for each class')
    parser.add_argument('config', help='test config file path')
    return parser.parse_args()


def main():
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    dataset = build_dataset(cfg.data.test)
    # result.pkl is the file written by tools/test.py --out result.pkl
    pred = mmcv.load('./result.pkl')['pred_label']
    # sklearn expects (y_true, y_pred): rows are ground truth, columns are predictions
    matrix = confusion_matrix(dataset.get_gt_labels(), pred)
    print('confusion_matrix:\n', matrix)
    # Class indices follow classmap.txt: 0 = dog, 1 = cat
    dog_recall = matrix[0, 0] / matrix[0].sum()
    cat_recall = matrix[1, 1] / matrix[1].sum()
    dog_precision = matrix[0, 0] / matrix[:, 0].sum()
    cat_precision = matrix[1, 1] / matrix[:, 1].sum()
    print(' dog_precision:{} \n cat_precision:{} \n dog_recall:{} \n cat_recall:{}'
          .format(dog_precision, cat_precision, dog_recall, cat_recall))


if __name__ == '__main__':
    main()
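The script above hard-codes two classes. The same calculation generalizes to any number of classes: precision for a class divides its true positives by everything predicted as that class, and recall divides them by everything that truly belongs to it. The following is a minimal sketch in plain Python with no mmcls dependency; the function name per_class_precision_recall and the toy labels are invented for illustration.

```python
def per_class_precision_recall(gt_labels, pred_labels, num_classes):
    """Return (precision, recall) lists indexed by class id."""
    # matrix[t][p] counts samples whose true label is t and predicted label is p.
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(gt_labels, pred_labels):
        matrix[t][p] += 1

    precision, recall = [], []
    for c in range(num_classes):
        true_pos = matrix[c][c]
        predicted_as_c = sum(matrix[t][c] for t in range(num_classes))  # column sum
        actually_c = sum(matrix[c])  # row sum
        # Report 0.0 when a class is never predicted or never present.
        precision.append(true_pos / predicted_as_c if predicted_as_c else 0.0)
        recall.append(true_pos / actually_c if actually_c else 0.0)
    return precision, recall


# Toy labels, invented for illustration only: 0 = dog, 1 = cat.
gt = [0, 0, 0, 1, 1]
pred = [0, 0, 1, 1, 0]
precision, recall = per_class_precision_recall(gt, pred, num_classes=2)
for name, p, r in zip(["dog", "cat"], precision, recall):
    print(f"{name}: precision={p:.3f} recall={r:.3f}")
```

In practice, gt and pred would be replaced by dataset.get_gt_labels() and the pred_label entry of result.pkl.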