pytorch_grad_cam: a library for visualizing Class Activation Mapping (CAM) in PyTorch
2022-06-27 02:16:00 【Ten thousand miles' journey to】
Deep learning models are "black box" systems: they are trained end to end, and what happens in between is opaque. Visualizing intermediate features can explain a model's behavior to some extent. The earliest feature visualization method, CAM, applies global average pooling (GAP) to the output of the model's last conv layer and uses a single fully connected layer as the classifier. The FC weights of a class, which act on the pooled features, give the weight of each feature map, and the weighted sum of the feature maps is the visualization. Later came a series of methods that backpropagate from a specific class label to obtain gradients, such as Grad-CAM. For a more detailed history, see the Zhihu article "A 10,000-word survey: feature visualization techniques (CAM)".
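To make the original CAM recipe concrete: if the network ends with conv feature maps `f_k(x, y)`, GAP, and one FC layer with weights `w_k^c`, the map for class `c` is `M_c(x, y) = sum_k w_k^c * f_k(x, y)`. The NumPy sketch below assumes the features and FC weights have already been extracted from the model; the function name `class_activation_map` and the toy arrays are hypothetical.

```python
import numpy as np


def class_activation_map(features, fc_weight, class_idx):
    """Original CAM: weight each feature map by the FC weight of the target class.

    features:  (K, H, W) activations of the last conv layer
    fc_weight: (num_classes, K) weights of the single FC layer that follows GAP
    """
    weights = fc_weight[class_idx]                  # (K,) one weight per feature map
    cam = np.tensordot(weights, features, axes=1)  # sum_k w_k^c * f_k -> (H, W)
    cam = cam - cam.min()                          # min-max normalize for display
    return cam / (cam.max() + 1e-8)


# Toy example: 2 feature maps of size 2x2 and 3 classes
features = np.array([[[1.0, 0.0], [0.0, 0.0]],
                     [[0.0, 0.0], [0.0, 2.0]]])
fc_weight = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
print(class_activation_map(features, fc_weight, class_idx=1))
```

Because GAP and the FC layer are both linear, the class score (minus the bias) equals the spatial mean of the unnormalized map, which is why this weighted sum shows where the score comes from.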
Here I just want to share a simple CAM visualization library for PyTorch. With it, CAM visualization takes only a few lines of code. I have also written my own GradCAM based on hooks; that code is at the end of this article, and its two parts (3.1 and 3.2) should be copied together to run. Running CAM visualization on misclassified or mislocalized samples helps us quickly find what is wrong with the model and adjust the data selectively, which improves its prediction accuracy. The library is jacobgil/pytorch-grad-cam on GitHub ("Many Class Activation Map methods implemented in Pytorch for CNNs and Vision Transformers. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM"). I have not read the library's source code, but I expect it is implemented with hooks (in PyTorch, a hook can be registered on any layer to pull out its data during the model's forward and backward passes).
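The hook mechanism mentioned above can be seen on a toy network. The sketch below uses a hypothetical three-stage `nn.Sequential` (not the library's code): a forward hook on the conv layer stores its output and attaches a tensor hook to that output, which captures the gradient when `backward()` runs. Activations plus gradients are exactly what gradient-based CAM methods need.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy classifier: conv -> global average pooling -> linear
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 10),
)
captured = {}


def save_activation(module, inputs, output):
    """Forward hook: store the layer output and hook its gradient."""
    captured["activation"] = output.detach()

    def save_grad(grad):
        captured["gradient"] = grad.detach()

    output.register_hook(save_grad)  # tensor hook, runs during backward()


handle = model[0].register_forward_hook(save_activation)
x = torch.randn(1, 3, 8, 8)
score = model(x)[0, 5]  # raw score of an arbitrary class (5)
score.backward()
handle.remove()  # hooks stay active until removed
print(captured["activation"].shape, captured["gradient"].shape)
```

Here the gradient is easy to check by hand: after average pooling over 8x8 positions, every position of channel `k` receives `W[5, k] / 64`.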
The library supports the CAM methods below, and it also supports test-time image augmentation to make the CAM results smoother.
| Method | What it does |
|---|---|
| GradCAM | Weight the 2D activations by the average gradient |
| GradCAM++ | Like GradCAM but uses second order gradients |
| XGradCAM | Like GradCAM but scale the gradients by the normalized activations |
| AblationCAM | Zero out activations and measure how the output drops (this repository includes a fast batched implementation) |
| ScoreCAM | Perturb the image by the scaled activations and measure how the output drops |
| EigenCAM | Takes the first principal component of the 2D activations (no class discrimination, but seems to give great results) |
| EigenGradCAM | Like EigenCAM but with class discrimination: first principal component of Activations*Grad. Looks like GradCAM, but cleaner |
| LayerCAM | Spatially weight the activations by positive gradients. Works better especially in lower layers |
| FullGrad | Computes the gradients of the biases from all over the network, and then sums them |
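The GradCAM and LayerCAM rows of the table differ mainly in how the gradients become weights. Given activations `A` of shape (K, H, W) and the gradients of the class score with respect to `A` (same shape), a NumPy sketch of both rules could look like this. The function names are hypothetical, and the library itself adds batching, resizing and normalization on top.

```python
import numpy as np


def grad_cam(activations, grads):
    """GradCAM: weight each channel by its spatially averaged gradient, then ReLU."""
    weights = grads.mean(axis=(1, 2))                 # (K,) one weight per channel
    cam = np.tensordot(weights, activations, axes=1)  # (H, W) weighted channel sum
    return np.maximum(cam, 0)


def layer_cam(activations, grads):
    """LayerCAM: weight every spatial position by its positive gradient, then ReLU."""
    cam = (np.maximum(grads, 0) * activations).sum(axis=0)  # (H, W)
    return np.maximum(cam, 0)


# Toy activations and gradients: 2 channels of size 2x2
acts = np.array([[[1.0, 2.0], [3.0, 4.0]],
                 [[4.0, 3.0], [2.0, 1.0]]])
grads = np.array([[[1.0, -1.0], [1.0, 1.0]],
                  [[0.0, 0.0], [0.0, -2.0]]])
print(grad_cam(acts, grads))
print(layer_cam(acts, grads))
```

Averaging the gradient gives one scalar per channel, so GradCAM cannot tell apart positions inside a channel; LayerCAM keeps a per-position weight, which is why it tends to work better on lower layers.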
1. Installation
```shell
pip install grad-cam
```
The package is published on PyPI as `grad-cam` and imported as `pytorch_grad_cam`.
2. Usage
```python
# Any of these CAM classes can be swapped in for GradCAM below
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as np


def myimshows(imgs, titles=None, fname="test.jpg", size=6):
    """Show the images side by side in one row and save the figure to fname."""
    n = len(imgs)
    plt.figure(figsize=(size * n, size))
    if not titles:
        titles = [str(i) for i in range(n)]
    for i, (im, title) in enumerate(zip(imgs, titles), start=1):
        plt.subplot(1, n, i)
        # 2-D arrays (raw CAMs) are drawn with a red colormap
        plt.imshow(im, cmap="Reds" if im.ndim == 2 else None)
        plt.title(title)
        plt.xticks(())
        plt.yticks(())
    plt.savefig(fname, bbox_inches="tight")
    plt.show()


def tensor2img(tensor):
    """Convert a CHW tensor to an HWC float32 array in [0, 1] for display."""
    np_arr = tensor.detach().cpu().numpy()
    if np_arr.max() > 1 or np_arr.min() < 0:  # min-max normalize into [0, 1]
        np_arr = np_arr - np_arr.min()
        np_arr = np_arr / np_arr.max()
    if np_arr.shape[0] == 1:  # grayscale -> 3 channels
        np_arr = np.concatenate([np_arr] * 3, axis=0)
    return np_arr.transpose((1, 2, 0))


path = r"D:\daxiang.jpg"  # path to your test image
bin_data = torchvision.io.read_file(path)  # load the raw bytes
# Decode to a CHW uint8 RGB tensor and scale to [0, 1]
img = torchvision.io.decode_image(bin_data, mode=torchvision.io.ImageReadMode.RGB) / 255
img = img.unsqueeze(0)  # BCHW with B == 1
input_tensor = torchvision.transforms.functional.resize(img, [224, 224])
# Append a horizontally flipped copy, giving a batch of two images
input_tensors = torch.cat([input_tensor, input_tensor.flip(dims=(3,))], dim=0)
# The pretrained weights expect ImageNet mean/std normalization
model_inputs = torchvision.transforms.functional.normalize(
    input_tensors, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

model = resnet50(pretrained=True)  # on torchvision >= 0.13 prefer weights="IMAGENET1K_V1"
model.eval()
target_layers = [model.layer4[-1]]  # with several layers, the CAMs are averaged
# If your grad-cam version rejects use_cuda, remove it (the model's device is used)
with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:
    # One target per image: class 386 (African elephant in ImageNet-1k)
    targets = [ClassifierOutputTarget(386), ClassifierOutputTarget(386)]
    # aug_smooth=True (test-time augmentation) and eigen_smooth=True (first principal
    # component) make the heat map smoother; targets=None uses each image's top class
    grayscale_cams = cam(input_tensor=model_inputs, targets=targets)

for grayscale_cam, tensor in zip(grayscale_cams, input_tensors):
    rgb_img = tensor2img(tensor)  # un-normalized image for display
    # Blend the heat map with the original image
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])
```
The output of this code is shown in Figure 1.
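Because the batch above holds the image and its horizontal mirror, the two CAMs should roughly mirror each other. One simple way to use that is to flip the second map back and average the two. This is a hand-rolled relative of the library's `aug_smooth` option, which applies its own set of augmentations; `flip_average` is a hypothetical helper that takes two (H, W) maps such as `grayscale_cams[0]` and `grayscale_cams[1]`.

```python
import numpy as np


def flip_average(cam, cam_of_flipped):
    """Average a CAM with the mirrored-back CAM of the horizontally flipped image."""
    # The flipped input was flipped along its width axis, so flipping its CAM
    # along the last (width) axis realigns it with the original image
    return (cam + np.flip(cam_of_flipped, axis=-1)) / 2


cam_a = np.array([[0.0, 1.0], [0.5, 0.0]])
cam_b = np.array([[0.8, 0.0], [0.0, 0.5]])  # roughly the mirror image of cam_a
print(flip_average(cam_a, cam_b))
```

With the script above, `flip_average(grayscale_cams[0], grayscale_cams[1])` would give one merged map for the original image; large disagreements between the two maps are themselves a hint that the explanation is unstable.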

3. Implementing GradCAM yourself
3.1 Imports and helper functions
```python
import torch
import torchvision
from torchvision import models
from matplotlib import pyplot as plt
import numpy as np
import cv2


def myimshows(imgs, titles=None, fname="test.jpg", size=6):
    """Show the images side by side in one row and save the figure to fname."""
    n = len(imgs)
    plt.figure(figsize=(size * n, size))
    if not titles:
        titles = [str(i) for i in range(n)]
    for i, (im, title) in enumerate(zip(imgs, titles), start=1):
        plt.subplot(1, n, i)
        # 2-D arrays are drawn with a red colormap
        plt.imshow(im, cmap="Reds" if im.ndim == 2 else None)
        plt.title(title)
        plt.xticks(())
        plt.yticks(())
    plt.savefig(fname, bbox_inches="tight")
    plt.show()


def tensor2img(tensor, heatmap=False, shape=(224, 224)):
    """Convert the first sample of a BCHW tensor to an HWC RGB float array in [0, 1].

    heatmap=True: the first channel is min-max normalized, resized to `shape`
    and colorized with the JET colormap (red = large values).
    """
    np_arr = tensor.detach().cpu().numpy()[0]
    if heatmap or np_arr.max() > 1 or np_arr.min() < 0:  # min-max normalize
        np_arr = np_arr - np_arr.min()
        np_arr = np_arr / (np_arr.max() + 1e-8)
    np_arr = (np_arr * 255).astype(np.uint8)
    if heatmap:
        gray = cv2.resize(np_arr[0], shape)                # (H, W) uint8 map
        color = cv2.applyColorMap(gray, cv2.COLORMAP_JET)  # OpenCV returns BGR
        return cv2.cvtColor(color, cv2.COLOR_BGR2RGB) / 255  # matplotlib expects RGB
    if np_arr.shape[0] == 1:  # grayscale -> 3 channels
        np_arr = np.concatenate([np_arr] * 3, axis=0)
    return np_arr.transpose((1, 2, 0)) / 255


def backward_hook(module, grad_in, grad_out):
    # grad_out[0]: gradient of the backpropagated score w.r.t. the module output
    grad_block.append(grad_out[0].detach())
    print("backward_hook:", grad_out[0].shape)


def forward_hook(module, input, output):
    fmap_block.append(output.detach())
    print("forward_hook:", input[0].shape, output.shape)
```
3.2 Implementing GradCAM
```python
# Load the model
model = models.resnet18(pretrained=True)  # on torchvision >= 0.13 prefer weights="IMAGENET1K_V1"
model.eval()  # evaluation mode

# Lists that receive the feature maps and gradients captured by the hooks
fmap_block = list()
grad_block = list()

# Register hooks on layer4 (the last convolutional stage)
fh = model.layer4.register_forward_hook(forward_hook)
bh = model.layer4.register_full_backward_hook(backward_hook)

# Load the image and predict
path = r"D:\daxiang.jpg"
bin_data = torchvision.io.read_file(path)  # load the raw bytes
img = torchvision.io.decode_image(bin_data, mode=torchvision.io.ImageReadMode.RGB) / 255  # CHW in [0, 1]
img = img.unsqueeze(0)  # BCHW with B == 1
img = torchvision.transforms.functional.resize(img, [224, 224])
# The pretrained weights expect ImageNet mean/std normalization
x = torchvision.transforms.functional.normalize(
    img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preds = model(x)
print("pred type:", preds.argmax(1))

# Backpropagate the raw score of the target class. Grad-CAM is defined on y^c;
# backpropagating CrossEntropyLoss instead would flip the gradient's sign.
clas = 386  # ImageNet class index to explain
model.zero_grad()
preds[0, clas].backward()

# Remove the hooks
fh.remove()
bh.remove()

# Captured gradient and feature map, both [1, 512, 7, 7] for resnet18 at 224x224
grad = grad_block[-1]
fmap = fmap_block[-1]

# Grad-CAM: channel weight = spatial mean of the gradient,
# then a ReLU over the weighted sum of the feature maps
weights = grad.mean(dim=(2, 3), keepdim=True)               # [1, 512, 1, 1]
cam = torch.relu((weights * fmap).sum(dim=1, keepdim=True))  # [1, 1, 7, 7]

# Channel sums of the gradient and feature map, for display with tensor2img
grad_sum = grad.sum(1, keepdim=True)
fmap_sum = fmap.sum(1, keepdim=True)

# Visualize
img_np = tensor2img(img)
grad_np = tensor2img(grad_sum, heatmap=True, shape=(224, 224))
fmap_np = tensor2img(fmap_sum, heatmap=True, shape=(224, 224))
cam_np = tensor2img(cam, heatmap=True, shape=(224, 224))
print("The deeper the color (red), the larger the value in that region")
myimshows([img_np, cam_np, cam_np * 0.4 + img_np * 0.6], ["image", "cam", "cam + image"])
```
The output of this code is shown in Figure 2.
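What gets backpropagated matters, because Grad-CAM is defined on the gradient of the raw class score `y^c`. If a cross-entropy loss is backpropagated instead, the gradient reaching the logits is `softmax(z) - onehot(c)`: on the target logit it equals `p_c - 1 <= 0`, the opposite sign of `d y^c / d y^c = 1`, and it shrinks toward zero as the model grows confident. The small NumPy check below uses made-up logits and a hypothetical helper `ce_grad_wrt_logits`.

```python
import numpy as np


def ce_grad_wrt_logits(logits, target):
    """Gradient of the cross-entropy loss -log softmax(logits)[target] w.r.t. the logits."""
    p = np.exp(logits - logits.max())
    p /= p.sum()  # softmax probabilities
    onehot = np.zeros_like(p)
    onehot[target] = 1.0
    return p - onehot


logits = np.array([2.0, 0.5, -1.0])  # made-up logits; class 0 is the prediction
g = ce_grad_wrt_logits(logits, target=0)
print(g)  # the target entry is p_0 - 1, i.e. negative
```

Min-max normalization of the final map hides the scale of the gradient but not its sign, so a sign flip turns the most relevant regions into the coldest ones in the heat map.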
