当前位置:网站首页>PyTorch 提取中间层特征?
PyTorch 提取中间层特征?
2022-07-06 09:38:00 【小白学视觉】
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
来源:机器学习算法与自然语言处理
编辑:忆臻
https://www.zhihu.com/question/68384370
本文仅作为学术分享,如果侵权,会删文处理
PyTorch提取中间层特征?
作者:涩醉
https://www.zhihu.com/question/68384370/answer/751212803
通过pytorch的hook机制简单实现了一下,只输出conv层的特征图。
import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
def viz(module, input):
x = input[0][0]
#最多显示4张图
min_num = np.minimum(4, x.size()[0])
for i in range(min_num):
plt.subplot(1, 4, i+1)
plt.imshow(x[i])
plt.show()
import cv2
import numpy as np
def main():
t = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet18(pretrained=True).to(device)
for name, m in model.named_modules():
# if not isinstance(m, torch.nn.ModuleList) and
# not isinstance(m, torch.nn.Sequential) and
# type(m) in torch.nn.__dict__.values():
# 这里只对卷积层的feature map进行显示
if isinstance(m, torch.nn.Conv2d):
m.register_forward_pre_hook(viz)
img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')
img = t(img).unsqueeze(0).to(device)
with torch.no_grad():
model(img)
if __name__ == '__main__':
main()
打印的特征图大概是这个样子,取了第一层以及第四层的特征图。


作者:袁坤
https://www.zhihu.com/question/68384370/answer/419741762
建议使用hook,在不改变网络forward函数的基础上提取所需的特征或者梯度,在调用阶段对module使用即可获得所需梯度或者特征。
inter_feature = {}
inter_gradient = {}
def make_hook(name, flag):
if flag == 'forward':
def hook(m, input, output):
inter_feature[name] = input
return hook
elif flag == 'backward':
def hook(m, input, output):
inter_gradient[name] = output
return hook
else:
assert False
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))
在前向计算和反向计算的时候即可达到类似钩子的作用,中间变量已经被放置于inter_feature 和 inter_gradient。
output = model(input) # achieve intermediate feature
loss = criterion(output, target)
loss.backward() # achieve backward intermediate gradients
最后可根据需求是否释放hook。
hook.remove()
作者:罗一成
https://www.zhihu.com/question/68384370/answer/263120790
提取中间特征是指把中间的weights给提出来吗?这样不是直接访问那个矩阵不就好了吗? pytorch在存参数的时候, 其实就是给所有的weights bias之类的起个名字然后存在了一个字典里面. 不然你看看state_dict.keys(), 找到相对应的key拿出来就好了.
然后你说的慎用也是一个很奇怪的问题啊..
就算用modules下面的class, 你存模型的时候因为你的activation function上面本身没有参数, 所以也不会被存进去. 不然你可以试试在Sequential里面把relu换成sigmoid, 你还是可以把之前存的state_dict给load回去.
不能说是慎用functional吧, 我觉得其他的设置是应该分开也存一份的(假设你把这些当做超参的话)
利益相关: 给pytorch提过PR
好消息!
小白学视觉知识星球
开始面向外开放啦
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
边栏推荐
- How uipath determines that an object is null
- Openharmony developer documentation open source project
- Xin'an Second Edition: Chapter 12 network security audit technology principle and application learning notes
- 案例:检查空字段【注解+反射+自定义异常】
- Error: Publish of Process project to Orchestrator failed. The operation has timed out.
- 【逆向初级】独树一帜
- Program counter of JVM runtime data area
- 遠程代碼執行滲透測試——B模塊測試
- February database ranking: how long can Oracle remain the first?
- [CISCN 2021 华南赛区]rsa Writeup
猜你喜欢
Uipath browser performs actions in the new tab
Solr appears write Lock, solrexception: could not get leader props in the log
Akamai浅谈风控原理与解决方案
CTF逆向入门题——掷骰子
06 products and promotion developed by individuals - code statistical tools
【逆向初级】独树一帜
Automatic operation and maintenance sharp weapon ansible Foundation
JVM garbage collector part 2
Chrome prompts the solution of "your company management" (the startup page is bound to the company's official website and cannot be modified)
JUnit unit test
随机推荐
【Elastic】Elastic缺少xpack无法创建模板 unknown setting index.lifecycle.name index.lifecycle.rollover_alias
Xin'an Second Edition: Chapter 23 cloud computing security requirements analysis and security protection engineering learning notes
Based on infragistics Document. Excel export table class
Automatic operation and maintenance sharp weapon ansible Foundation
【MySQL入门】第一话 · 初入“数据库”大陆
Wu Jun's trilogy insight (V) refusing fake workers
[reverse primary] Unique
Redis quick start
集成开发管理平台
【逆向中级】跃跃欲试
关于Selenium启动Chrome浏览器闪退问题
04 products and promotion developed by individuals - data push tool
[mmdetection] solves the installation problem
Program counter of JVM runtime data area
Flink parsing (III): memory management
Selenium test of automatic answer runs directly in the browser, just like real users.
Redis快速入门
Integrated development management platform
【逆向初级】独树一帜
信息与网络安全期末复习(基于老师给的重点)