当前位置:网站首页>PyTorch 提取中间层特征?
PyTorch 提取中间层特征?
2022-07-06 09:38:00 【小白学视觉】
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
来源:机器学习算法与自然语言处理
编辑:忆臻
https://www.zhihu.com/question/68384370
本文仅作为学术分享,如果侵权,会删文处理
PyTorch提取中间层特征?
作者:涩醉
https://www.zhihu.com/question/68384370/answer/751212803
通过pytorch的hook机制简单实现了一下,只输出conv层的特征图。
import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
def viz(module, input):
x = input[0][0]
#最多显示4张图
min_num = np.minimum(4, x.size()[0])
for i in range(min_num):
plt.subplot(1, 4, i+1)
plt.imshow(x[i])
plt.show()
import cv2
import numpy as np
def main():
t = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet18(pretrained=True).to(device)
for name, m in model.named_modules():
# if not isinstance(m, torch.nn.ModuleList) and
# not isinstance(m, torch.nn.Sequential) and
# type(m) in torch.nn.__dict__.values():
# 这里只对卷积层的feature map进行显示
if isinstance(m, torch.nn.Conv2d):
m.register_forward_pre_hook(viz)
img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')
img = t(img).unsqueeze(0).to(device)
with torch.no_grad():
model(img)
if __name__ == '__main__':
main()
打印的特征图大概是这个样子,取了第一层以及第四层的特征图。
作者:袁坤
https://www.zhihu.com/question/68384370/answer/419741762
建议使用hook,在不改变网络forward函数的基础上提取所需的特征或者梯度,在调用阶段对module使用即可获得所需梯度或者特征。
inter_feature = {}
inter_gradient = {}
def make_hook(name, flag):
if flag == 'forward':
def hook(m, input, output):
inter_feature[name] = input
return hook
elif flag == 'backward':
def hook(m, input, output):
inter_gradient[name] = output
return hook
else:
assert False
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))
在前向计算和反向计算的时候即可达到类似钩子的作用,中间变量已经被放置于inter_feature 和 inter_gradient。
output = model(input) # achieve intermediate feature
loss = criterion(output, target)
loss.backward() # achieve backward intermediate gradients
最后可根据需求是否释放hook。
hook.remove()
作者:罗一成
https://www.zhihu.com/question/68384370/answer/263120790
提取中间特征是指把中间的weights给提出来吗?这样不是直接访问那个矩阵不就好了吗? pytorch在存参数的时候, 其实就是给所有的weights bias之类的起个名字然后存在了一个字典里面. 不然你看看state_dict.keys(), 找到相对应的key拿出来就好了.
然后你说的慎用也是一个很奇怪的问题啊..
就算用modules下面的class, 你存模型的时候因为你的activation function上面本身没有参数, 所以也不会被存进去. 不然你可以试试在Sequential里面把relu换成sigmoid, 你还是可以把之前存的state_dict给load回去.
不能说是慎用functional吧, 我觉得其他的设置是应该分开也存一份的(假设你把这些当做超参的话)
利益相关: 给pytorch提过PR
好消息!
小白学视觉知识星球
开始面向外开放啦
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
边栏推荐
- 04个人研发的产品及推广-数据推送工具
- Flink源码解读(一):StreamGraph源码解读
- [CISCN 2021 华南赛区]rsa Writeup
- C#WinForm中的dataGridView滚动条定位
- Jetpack compose 1.1 release, based on kotlin's Android UI Toolkit
- 关于Selenium启动Chrome浏览器闪退问题
- Automatic operation and maintenance sharp weapon ansible Playbook
- 04 products and promotion developed by individuals - data push tool
- 信息与网络安全期末复习(完整版)
- Xin'an Second Edition: Chapter 25 mobile application security requirements analysis and security protection engineering learning notes
猜你喜欢
案例:检查空字段【注解+反射+自定义异常】
Akamai talking about risk control principles and Solutions
Take you hand-in-hand to do intensive learning experiments -- knock the level in detail
关于Selenium启动Chrome浏览器闪退问题
数据仓库建模使用的模型以及分层介绍
JUnit unit test
Flink 解析(三):内存管理
Flink 解析(一):基础概念解析
Automatic operation and maintenance sharp weapon ansible Foundation
pip install pyodbc : ERROR: Command errored out with exit status 1
随机推荐
MySQL advanced (index, view, stored procedure, function, password modification)
List集合数据移除(List.subList.clear)
MySQL error reporting solution
Akamai浅谈风控原理与解决方案
[VNCTF 2022]ezmath wp
02 personal developed products and promotion - SMS platform
远程代码执行渗透测试——B模块测试
Xin'an Second Edition: Chapter 12 network security audit technology principle and application learning notes
mysql 基本增删改查SQL语句
Start job: operation returned an invalid status code 'badrequst' or 'forbidden‘
Serial serialold parnew of JVM garbage collector
Wordcloud colormap color set and custom colors
connection reset by peer
Akamai anti confusion
【逆向初级】独树一帜
Flink源码解读(二):JobGraph源码解读
Case: check the empty field [annotation + reflection + custom exception]
Xin'an Second Edition: Chapter 24 industrial control safety demand analysis and safety protection engineering learning notes
connection reset by peer
Vscode replaces commas, or specific characters with newlines