当前位置:网站首页>PyTorch 提取中间层特征?
PyTorch 提取中间层特征?
2022-07-06 09:38:00 【小白学视觉】
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
来源:机器学习算法与自然语言处理
编辑:忆臻
https://www.zhihu.com/question/68384370
本文仅作为学术分享,如果侵权,会删文处理
PyTorch提取中间层特征?
作者:涩醉
https://www.zhihu.com/question/68384370/answer/751212803
通过pytorch的hook机制简单实现了一下,只输出conv层的特征图。
import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
def viz(module, input):
x = input[0][0]
#最多显示4张图
min_num = np.minimum(4, x.size()[0])
for i in range(min_num):
plt.subplot(1, 4, i+1)
plt.imshow(x[i])
plt.show()
import cv2
import numpy as np
def main():
t = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet18(pretrained=True).to(device)
for name, m in model.named_modules():
# if not isinstance(m, torch.nn.ModuleList) and
# not isinstance(m, torch.nn.Sequential) and
# type(m) in torch.nn.__dict__.values():
# 这里只对卷积层的feature map进行显示
if isinstance(m, torch.nn.Conv2d):
m.register_forward_pre_hook(viz)
img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')
img = t(img).unsqueeze(0).to(device)
with torch.no_grad():
model(img)
if __name__ == '__main__':
main()
打印的特征图大概是这个样子,取了第一层以及第四层的特征图。
作者:袁坤
https://www.zhihu.com/question/68384370/answer/419741762
建议使用hook,在不改变网络forward函数的基础上提取所需的特征或者梯度,在调用阶段对module使用即可获得所需梯度或者特征。
inter_feature = {}
inter_gradient = {}
def make_hook(name, flag):
if flag == 'forward':
def hook(m, input, output):
inter_feature[name] = input
return hook
elif flag == 'backward':
def hook(m, input, output):
inter_gradient[name] = output
return hook
else:
assert False
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))
在前向计算和反向计算的时候即可达到类似钩子的作用,中间变量已经被放置于inter_feature 和 inter_gradient。
output = model(input) # achieve intermediate feature
loss = criterion(output, target)
loss.backward() # achieve backward intermediate gradients
最后可根据需求是否释放hook。
hook.remove()
作者:罗一成
https://www.zhihu.com/question/68384370/answer/263120790
提取中间特征是指把中间的weights给提出来吗?这样不是直接访问那个矩阵不就好了吗? pytorch在存参数的时候, 其实就是给所有的weights bias之类的起个名字然后存在了一个字典里面. 不然你看看state_dict.keys(), 找到相对应的key拿出来就好了.
然后你说的慎用也是一个很奇怪的问题啊..
就算用modules下面的class, 你存模型的时候因为你的activation function上面本身没有参数, 所以也不会被存进去. 不然你可以试试在Sequential里面把relu换成sigmoid, 你还是可以把之前存的state_dict给load回去.
不能说是慎用functional吧, 我觉得其他的设置是应该分开也存一份的(假设你把这些当做超参的话)
利益相关: 给pytorch提过PR
好消息!
小白学视觉知识星球
开始面向外开放啦
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
边栏推荐
- Xin'an Second Edition: Chapter 12 network security audit technology principle and application learning notes
- 遠程代碼執行滲透測試——B模塊測試
- Jetpack compose 1.1 release, based on kotlin's Android UI Toolkit
- Flink analysis (II): analysis of backpressure mechanism
- The NTFS format converter (convert.exe) is missing from the current system
- MySQL error reporting solution
- Automatic operation and maintenance sharp weapon ansible Playbook
- 【MySQL入门】第三话 · MySQL中常见的数据类型
- 05个人研发的产品及推广-数据同步工具
- List set data removal (list.sublist.clear)
猜你喜欢
pip install pyodbc : ERROR: Command errored out with exit status 1
Learn the wisdom of investment Masters
C# WinForm中DataGridView单元格显示图片
How does wechat prevent withdrawal come true?
Final review of information and network security (based on the key points given by the teacher)
05 personal R & D products and promotion - data synchronization tool
06个人研发的产品及推广-代码统计工具
03个人研发的产品及推广-计划服务配置器V3.0
[reverse intermediate] eager to try
The most complete tcpdump and Wireshark packet capturing practice in the whole network
随机推荐
Concept and basic knowledge of network layering
Shawshank's sense of redemption
mysql高级(索引,视图,存储过程,函数,修改密码)
04个人研发的产品及推广-数据推送工具
Kali2021 installation and basic configuration
Flink源码解读(二):JobGraph源码解读
Akamai talking about risk control principles and Solutions
EasyRE WriteUp
Flexible report v1.0 (simple version)
轻量级计划服务工具研发与实践
Application service configurator (regular, database backup, file backup, remote backup)
MySQL error reporting solution
复盘网鼎杯Re-Signal Writeup
Huawei certified cloud computing hica
C#版Selenium操作Chrome全屏模式显示(F11)
【逆向初级】独树一帜
TCP connection is more than communicating with TCP protocol
Flink源码解读(一):StreamGraph源码解读
Error: Publish of Process project to Orchestrator failed. The operation has timed out.
Interpretation of Flink source code (II): Interpretation of jobgraph source code