当前位置:网站首页>PyTorch 提取中间层特征?
PyTorch 提取中间层特征?
2022-07-06 09:38:00 【小白学视觉】
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达来源:机器学习算法与自然语言处理
编辑:忆臻
https://www.zhihu.com/question/68384370
本文仅作为学术分享,如果侵权,会删文处理
PyTorch提取中间层特征?
作者:涩醉
https://www.zhihu.com/question/68384370/answer/751212803
通过pytorch的hook机制简单实现了一下,只输出conv层的特征图。
import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
def viz(module, input):
x = input[0][0]
#最多显示4张图
min_num = np.minimum(4, x.size()[0])
for i in range(min_num):
plt.subplot(1, 4, i+1)
plt.imshow(x[i])
plt.show()
import cv2
import numpy as np
def main():
t = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet18(pretrained=True).to(device)
for name, m in model.named_modules():
# if not isinstance(m, torch.nn.ModuleList) and
# not isinstance(m, torch.nn.Sequential) and
# type(m) in torch.nn.__dict__.values():
# 这里只对卷积层的feature map进行显示
if isinstance(m, torch.nn.Conv2d):
m.register_forward_pre_hook(viz)
img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')
img = t(img).unsqueeze(0).to(device)
with torch.no_grad():
model(img)
if __name__ == '__main__':
main()打印的特征图大概是这个样子,取了第一层以及第四层的特征图。


作者:袁坤
https://www.zhihu.com/question/68384370/answer/419741762
建议使用hook,在不改变网络forward函数的基础上提取所需的特征或者梯度,在调用阶段对module使用即可获得所需梯度或者特征。
inter_feature = {}
inter_gradient = {}
def make_hook(name, flag):
if flag == 'forward':
def hook(m, input, output):
inter_feature[name] = input
return hook
elif flag == 'backward':
def hook(m, input, output):
inter_gradient[name] = output
return hook
else:
assert False
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))在前向计算和反向计算的时候即可达到类似钩子的作用,中间变量已经被放置于inter_feature 和 inter_gradient。
output = model(input) # achieve intermediate feature
loss = criterion(output, target)
loss.backward() # achieve backward intermediate gradients最后可根据需求是否释放hook。
hook.remove()作者:罗一成
https://www.zhihu.com/question/68384370/answer/263120790
提取中间特征是指把中间的weights给提出来吗?这样不是直接访问那个矩阵不就好了吗? pytorch在存参数的时候, 其实就是给所有的weights bias之类的起个名字然后存在了一个字典里面. 不然你看看state_dict.keys(), 找到相对应的key拿出来就好了.
然后你说的慎用也是一个很奇怪的问题啊..
就算用modules下面的class, 你存模型的时候因为你的activation function上面本身没有参数, 所以也不会被存进去. 不然你可以试试在Sequential里面把relu换成sigmoid, 你还是可以把之前存的state_dict给load回去.
不能说是慎用functional吧, 我觉得其他的设置是应该分开也存一份的(假设你把这些当做超参的话)
利益相关: 给pytorch提过PR
好消息!
小白学视觉知识星球
开始面向外开放啦

下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~边栏推荐
- Junit单元测试
- JVM 垃圾回收器之Garbage First
- Flink 解析(一):基础概念解析
- C# WinForm系列-Button简单使用
- Pyspark operator processing spatial data full parsing (4): let's talk about spatial operations first
- Precipitated database operation class - version C (SQL Server)
- Flink源码解读(二):JobGraph源码解读
- Flink analysis (II): analysis of backpressure mechanism
- Serial serialold parnew of JVM garbage collector
- [CISCN 2021 华南赛区]rsa Writeup
猜你喜欢

微信防撤回是怎么实现的?

Vscode matches and replaces the brackets

【MySQL入门】第四话 · 和kiko一起探索MySQL中的运算符

Akamai anti confusion

Deploy flask project based on LNMP

Chrome prompts the solution of "your company management" (the startup page is bound to the company's official website and cannot be modified)

Virtual machine startup prompt probing EDD (edd=off to disable) error

02个人研发的产品及推广-短信平台

Shawshank's sense of redemption

05 personal R & D products and promotion - data synchronization tool
随机推荐
PostgreSQL 14.2, 13.6, 12.10, 11.15 and 10.20 releases
Garbage first of JVM garbage collector
Yarn: unable to load file d:\programfiles\nodejs\yarn PS1, because running scripts is prohibited on this system
yarn : 无法加载文件 D:\ProgramFiles\nodejs\yarn.ps1,因为在此系统上禁止运行脚本
[reverse] repair IAT and close ASLR after shelling
Automatic operation and maintenance sharp weapon ansible Foundation
06个人研发的产品及推广-代码统计工具
Flink 解析(三):内存管理
Flink parsing (III): memory management
05个人研发的产品及推广-数据同步工具
[reverse primary] Unique
BearPi-HM_ Nano development board "flower protector" case
Flink 解析(二):反压机制解析
EasyRE WriteUp
Flink 解析(一):基础概念解析
[ciscn 2021 South China]rsa writeup
04个人研发的产品及推广-数据推送工具
Take you hand-in-hand to do intensive learning experiments -- knock the level in detail
mysql的列的数据类型详解
Pyspark operator processing spatial data full parsing (4): let's talk about spatial operations first