[Image classification] How to use MMClassification to train your own classification model
2022-07-29 06:03:00 【Dull cat】
MMClassification is an image classification toolbox. This article briefly records how to use it to train your own classification model, covering data preparation, model modification, model training and model testing.
MMClassification repository: https://github.com/open-mmlab/mmclassification
Installation: https://mmclassification.readthedocs.io/en/latest/install.html
Training: https://mmclassification.readthedocs.io/en/latest/getting_started.html
One 、 Data preparation
MMClassification supports two dataset formats, ImageNet and CIFAR. We use the ImageNet directory structure as an example:
|- imagenet
| |- classmap.txt
| |- train
| | |- cls1
| | |- cls2
| | |- cls3
| | |- ...
| |- train.txt
| |- val
| | |- images
| |- val.txt
Suppose we want to train a binary cat-vs-dog classifier. The dataset is organized as follows:
|- dog_cat_dataset
| |- classmap.txt
| |- train
| | |- dog
| | |- cat
| |- train.txt
| |- val
| | |- images
| |- val.txt
Here, classmap.txt contains:
dog 0
cat 1
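The tree above also contains train.txt and val.txt, whose contents the article does not show. As far as I know, an mmcls 0.x ImageNet-style annotation file has one sample per line: the image path relative to data_prefix, a space, and the integer label (for example dog/xxx.jpg 0). The sketch below builds such lines from the class subfolders and classmap.txt; read_classmap and build_ann_lines are hypothetical helpers, not mmcls APIs. It only suits train/, where images are grouped into class folders; the flat val/images folder needs its labels from some other source.

```python
import os
from typing import Dict, List


def read_classmap(text: str) -> Dict[str, int]:
    """Parse classmap lines such as 'dog 0' into {'dog': 0}."""
    mapping = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        name, label = line.rsplit(' ', 1)
        mapping[name] = int(label)
    return mapping


def build_ann_lines(root: str, classmap: Dict[str, int],
                    exts=('.jpg', '.jpeg', '.png')) -> List[str]:
    """Return '<class_dir>/<file> <label>' lines for images under root/<class>/.

    Paths are relative to root, which should match data_prefix in the config.
    """
    lines = []
    # Walk classes in label order so the file is grouped by label
    for name, label in sorted(classmap.items(), key=lambda kv: kv[1]):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Skip non-image files such as notes or hidden files
            if fname.lower().endswith(exts):
                lines.append(f'{name}/{fname} {label}')
    return lines
```

For example, read classmap.txt with read_classmap, call build_ann_lines('data/dog_cat_dataset/train', classmap), and write the returned lines, one per line, to data/dog_cat_dataset/train.txt.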
Two 、 Model modification
Suppose we train with ResNet-18. The changes are then concentrated in the config file; the modified config, resnet18_b32x8_dog_cat_cls.py (placed under configs/resnet/), is as follows:
- Change the number of classes from 1000 to 2 (num_classes in the head); topk is also reduced to (1, ), since top-5 accuracy cannot be computed with only two classes
- Change the data paths (the data field)
- Data preprocessing can also be changed in this config if needed
- Because this config sits at the top of the inheritance chain, the values set here override the parameters inherited from the mmcls base configs listed in _base_
_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(
    head=dict(
        type='LinearClsHead',
        num_classes=2,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, ),
    ))
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=1,
    train=dict(
        data_prefix='data/dog_cat_dataset/train',
        ann_file='data/dog_cat_dataset/train.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    val=dict(
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(interval=1, metric='accuracy', metric_options={
    'topk': (1, )})
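The last bullet above describes config inheritance: the keys in this file are merged into the dicts loaded from the _base_ files, and where both define the same key, this file wins. Nested dicts are merged key by key, so model.head can change num_classes and topk while keeping type and in_channels from the base. The sketch below imitates that rule with plain dicts; merge_cfg and the simplified base dict are hypothetical illustrations, not mmcv's implementation (mmcv's Config also supports extras such as _delete_=True, which this ignores).

```python
import copy
from typing import Any, Dict


def merge_cfg(base: Dict[str, Any], child: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge child into a copy of base; child values win."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge field by field
            merged[key] = merge_cfg(merged[key], value)
        else:
            # Scalars, tuples and lists from the child replace the base value
            merged[key] = copy.deepcopy(value)
    return merged


# Simplified stand-in for what the _base_ model file defines (assumed values)
base = dict(model=dict(
    backbone=dict(type='ResNet', depth=18),
    head=dict(type='LinearClsHead', num_classes=1000,
              in_channels=512, topk=(1, 5))))
# The part of our dog/cat config that touches the head
child = dict(model=dict(head=dict(num_classes=2, topk=(1, ))))

cfg = merge_cfg(base, child)
# head keeps type/in_channels from base, num_classes/topk from child
print(cfg['model']['head'])
```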
Three 、 Model training
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py
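During training, evaluation = dict(interval=1, metric='accuracy', ...) in the config evaluates the val split after every epoch and reports top-1 accuracy. Top-k accuracy counts a sample as correct when its true label is among the k highest-scoring classes; with topk=(1, ) this reduces to "arg-max equals label". The plain-Python sketch below shows the metric itself, reported as a percentage; it is an illustration, not mmcls's own code.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]],
                  labels: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest scores."""
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += label in top
    return 100.0 * correct / len(labels)


# Four samples, two classes (0 = dog, 1 = cat)
scores = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]
labels = [0, 1, 1, 1]
print(topk_accuracy(scores, labels, k=1))  # 75.0
```

With only two classes, k=2 would always give 100%, which is another reason the config keeps topk=(1, ).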

Four 、 Testing and visualizing the model
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls
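Section Five reads the pred_label field from result.pkl. To my understanding, in mmcls 0.x tools/test.py with --out stores the per-class scores of every test image together with fields derived from them (pred_score, pred_label, pred_class); the exact keys depend on the version, so treat this as an assumption. The sketch shows how such fields follow from the scores: the label is the arg-max class and the score is that class's value. summarize_scores is a hypothetical helper.

```python
from typing import Dict, Sequence


def summarize_scores(class_scores: Sequence[Sequence[float]],
                     classes: Sequence[str]) -> Dict[str, list]:
    """Derive per-image predicted label, score and class name from class scores."""
    pred_label, pred_score, pred_class = [], [], []
    for row in class_scores:
        label = max(range(len(row)), key=lambda i: row[i])  # arg-max
        pred_label.append(label)
        pred_score.append(row[label])
        pred_class.append(classes[label])
    return dict(pred_label=pred_label, pred_score=pred_score,
                pred_class=pred_class)


# Class order follows classmap.txt: 0 = dog, 1 = cat
out = summarize_scores([[0.8, 0.2], [0.35, 0.65]], classes=['dog', 'cat'])
print(out['pred_class'])  # ['dog', 'cat']
```

To see what your own file contains, load it with mmcv.load('result.pkl') and print its keys().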
Visualize with Grad-CAM:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --save-path visual_path/4.jpg
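vis_cam.py overlays a class activation map on the image; to my knowledge it wraps the external pytorch-grad-cam package rather than implementing the maps itself. For the Grad-CAM method, each channel k of a chosen feature map A is weighted by the spatial mean of the gradient of the class score with respect to A_k, and the heatmap is ReLU(sum_k w_k * A_k), normalized for display. The pure-Python sketch below only illustrates that formula on tiny nested lists; the activations and gradients are assumed to be given (in practice they come from hooks on the target layer).

```python
from typing import List

Matrix = List[List[float]]


def grad_cam(activations: List[Matrix], gradients: List[Matrix]) -> Matrix:
    """Grad-CAM heatmap from K feature maps and their gradients (each H x W)."""
    h, w = len(activations[0]), len(activations[0][0])
    cam = [[0.0] * w for _ in range(h)]
    for act, grad in zip(activations, gradients):
        # Channel weight = global average of the gradient over H x W
        weight = sum(sum(r) for r in grad) / (h * w)
        for y in range(h):
            for x in range(w):
                cam[y][x] += weight * act[y][x]
    # ReLU keeps only regions that increase the class score
    cam = [[max(v, 0.0) for v in row] for row in cam]
    peak = max(max(row) for row in cam)
    # Scale to [0, 1] for display; an all-zero map is left unchanged
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


acts = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]]
grads = [[[0.5, 0.5], [0.5, 0.5]], [[-0.2, -0.2], [-0.2, -0.2]]]
heatmap = grad_cam(acts, grads)
print(heatmap)
```

Here the first channel has a positive average gradient and lights up the top-left cell, while the second channel's negative weight is removed by the ReLU.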
Five 、 Computing precision and recall for each class
First run the test command above to obtain result.pkl, then run the following script (saved here as tools/cal_precision.py):
python tools/cal_precision.py configs/resnet/resnet18_b32x8_dog_cat_cls.py
import argparse

import mmcv
from mmcls.datasets import build_dataset
from sklearn.metrics import confusion_matrix


def parse_args():
    parser = argparse.ArgumentParser(description='calculate precision and recall for each class')
    parser.add_argument('config', help='test config file path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    dataset = build_dataset(cfg.data.test)
    # Predicted labels saved by tools/test.py --out result.pkl
    pred = mmcv.load("./result.pkl")['pred_label']
    # sklearn expects (y_true, y_pred): rows = ground truth, columns = prediction
    matrix = confusion_matrix(dataset.get_gt_labels(), pred)
    print('confusion_matrix:', matrix)
    # Label order follows classmap.txt: 0 = dog, 1 = cat
    dog_recall = matrix[0, 0] / matrix[0].sum()
    cat_recall = matrix[1, 1] / matrix[1].sum()
    dog_precision = matrix[0, 0] / matrix[:, 0].sum()
    cat_precision = matrix[1, 1] / matrix[:, 1].sum()
    print(' dog_precision:{} \n cat_precision:{} \n dog_recall:{} \n cat_recall:{}'.format(
        dog_precision, cat_precision, dog_recall, cat_recall))


if __name__ == '__main__':
    main()
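The script above hard-codes two classes. The sketch below generalizes the same computation to any number of classes without sklearn: it builds the confusion matrix with rows = ground truth and columns = prediction (the orientation sklearn uses), then takes precision from the column sums and recall from the row sums. confusion and per_class_precision_recall are hypothetical helpers; the toy labels stand in for dataset.get_gt_labels() and the pred_label array.

```python
from typing import List, Sequence, Tuple


def confusion(gt: Sequence[int], pred: Sequence[int],
              num_classes: int) -> List[List[int]]:
    """matrix[i][j] = number of samples with true label i predicted as j."""
    m = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(gt, pred):
        m[t][p] += 1
    return m


def per_class_precision_recall(m: List[List[int]]) -> List[Tuple[float, float]]:
    """(precision, recall) per class; 0.0 when a denominator is zero."""
    results = []
    for c in range(len(m)):
        tp = m[c][c]
        predicted = sum(row[c] for row in m)  # column sum: predicted as c
        actual = sum(m[c])                    # row sum: truly c
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        results.append((precision, recall))
    return results


classes = ['dog', 'cat']  # label order from classmap.txt
gt = [0, 0, 0, 1, 1]
pred = [0, 0, 1, 1, 0]
matrix = confusion(gt, pred, len(classes))
for name, (p, r) in zip(classes, per_class_precision_recall(matrix)):
    print(f'{name}: precision={p:.3f} recall={r:.3f}')
```

The zero-denominator guard matters on small validation sets, where a class may never be predicted at all.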