当前位置:网站首页>Pytorch's code for visualizing feature maps after training its own network
PyTorch code for visualizing the feature maps of a network you trained yourself
2022-07-01 22:41:00 【zouxiaolv】
Other blog posts about visualizing feature maps do not show how to do it for a network you trained yourself; this script does.
Reference: "Visualizing PyTorch network feature maps" on hello_dear_you's CSDN blog.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torchvision.utils import make_grid, save_image
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
# from core.function import validate,validate_six_scale
from core.function import validate
from utils.utils import create_logger
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
# import torchvision.models as models
import dataset
import models
from PIL import Image
def parse_args():
    """Read the command line for the feature-map visualization script.

    Options are registered in a fixed order: ``--cfg``, the trailing
    ``opts`` list, then the four directory overrides.

    Returns:
        argparse.Namespace with attributes ``cfg``, ``opts``, ``modelDir``,
        ``logDir``, ``dataDir`` and ``prevModelDir``.
    """
    cli = argparse.ArgumentParser(description='Train keypoints network')
    cli.add_argument(
        '--cfg',
        type=str,
        default='/media/zxl/E/zxl/code/experiments/mpii/hgcpef/hg8_256x256_d256x3_adam_lr2.5e-4.yaml',
        help='experiment configure file name',
    )
    # Everything left over on the command line is collected verbatim.
    cli.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    directory_flags = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_flags:
        cli.add_argument(flag, type=str, default='', help=description)
    return cli.parse_args()
# model
args = parse_args()
update_config(cfg, args)
logger, final_output_dir, tb_log_dir = create_logger(
    cfg, args.cfg, 'valid')
logger.info(pprint.pformat(args))
logger.info(cfg)

# cudnn related setting
cudnn.benchmark = cfg.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

# get devices: the model and the input tensor are both moved to this device
# (instead of a hard-coded .cuda()), so the script also runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# create model
# getattr performs the same lookup as eval('models.<NAME>.get_pose_net') for a
# plain module name, without executing an arbitrary string from the config.
model = getattr(models, cfg.MODEL.NAME).get_pose_net(
    cfg, is_train=False
)

if cfg.TEST.MODEL_FILE:
    model_file, strict = cfg.TEST.MODEL_FILE, False
else:
    model_file, strict = os.path.join(final_output_dir, 'final_state.pth'), True
logger.info('=> loading model from {}'.format(model_file))
# map_location lets a checkpoint saved on a GPU be loaded on the chosen device.
state_dict = torch.load(model_file, map_location=device)
incompatible = model.load_state_dict(state_dict, strict=strict)
# With strict=False mismatched keys are skipped silently (for example a
# "module." prefix left by DataParallel), which would visualize untrained
# weights; report them so the problem is visible.
if incompatible is not None and (incompatible.missing_keys or incompatible.unexpected_keys):
    logger.warning('=> checkpoint keys did not match the model: missing={}, unexpected={}'.format(
        incompatible.missing_keys, incompatible.unexpected_keys))
net = torch.nn.DataParallel(model, device_ids=cfg.GPUS).to(device)

# image pre-process
# Mean/std are the commonly used ImageNet statistics -- presumably what the
# network was trained with; confirm against the training pipeline.
transforms_input = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# The context manager closes the underlying file once the RGB copy is made.
with Image.open("/media/zxl/E/zxl/000003072.jpg") as raw_img:
    fImg = raw_img.convert('RGB')
# Add a batch dimension of 1: (C, H, W) -> (1, C, H, W).
data = transforms_input(fImg).unsqueeze(0).to(device)

# feature image save path
FEATURE_FOLDER = "/media/zxl/E/zxl/code/ffffff/"
# makedirs also creates missing parent folders and is a no-op if it exists.
os.makedirs(FEATURE_FOLDER, exist_ok=True)

# Three global variables used by the hooks to name the feature images:
#   feature_list - layer names, read in order to build the image file names
#   count        - index of the next name to take from feature_list
#   idx          - running counter printed for every captured feature map
feature_list = list()
count = 0
idx = 0
def get_image_path_for_hook(module):
    """Return the PNG path under FEATURE_FOLDER for the next feature map.

    The file name is the entry of the global ``feature_list`` at position
    ``count``; the global ``count`` is then advanced by one. ``module`` is
    accepted to match the hook's call site but is not used.
    """
    global count
    layer_name = feature_list[count]
    count = count + 1
    return os.path.join(FEATURE_FOLDER, "{}.png".format(layer_name))
def hook_func(module, input, output):
    """Forward hook: save every channel of a layer's output as one PNG grid.

    Registered with ``register_forward_hook``, so it is called as
    ``hook(module, input, output)``. It returns None, leaving the layer's
    output unchanged.

    Args:
        module: the layer that just ran; used to obtain the output path.
        input: the layer's inputs (unused).
        output: the layer's output tensor; assumed (N, C, H, W) with N == 1,
            since the hooks are attached to Conv2d layers and the script feeds
            a single image -- TODO confirm if reused elsewhere.
    """
    global idx
    image_path = get_image_path_for_hook(module)
    feature = output.detach().clone()
    print(idx, "->", feature.shape)
    idx += 1
    # (N, C, H, W) -> (C, N, H, W): with N == 1, every channel becomes its own
    # single-channel tile in the saved grid.
    feature = feature.permute(1, 0, 2, 3)
    # Raw activations are signed and unbounded; with normalize=False save_image
    # clamps them to [0, 1] and most of each map saturates to black/white.
    # scale_each min-max rescales every channel tile independently.
    save_image(feature, image_path, normalize=True, scale_each=True)
def _make_name_recorder(layer_name):
    """Return a forward pre-hook that appends ``layer_name`` to feature_list."""
    def _record(module, inputs):
        feature_list.append(layer_name)
    return _record


# Attach hooks to every Conv2d. The pre-hook appends the layer's name to
# feature_list right before that layer runs, so feature_list is filled in
# execution order. The forward hook's file naming reads feature_list[count]
# in call order, so it always gets the name of the layer whose output is being
# saved -- even if layers execute in a different order than named_modules()
# lists them, or a layer is called more than once.
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name)
        module.register_forward_pre_hook(_make_name_recorder(name))
        module.register_forward_hook(hook_func)

# Inference mode: BatchNorm uses its running statistics instead of this single
# image's, and Dropout is disabled. no_grad skips building the autograd graph.
net.eval()
with torch.no_grad():
    out = net(data)
边栏推荐
- 记录一次spark on yarn 任务报错 Operation category READ is not supported in state standby
- Mysql——》Innodb存储引擎的索引
- 陈天奇的机器学习编译课(免费)
- SAP UI5 应用开发教程之一百零四 - SAP UI5 表格控件的支持复选(Multi-Select)以及如何用代码一次选中多个表格行项目
- PyTorch磨刀篇|argmax和argmin函数
- Awoo's favorite problem (priority queue)
- Basic knowledge of nginx
- QStringList 的常规使用
- 切面条 C语言
- MySQL中对于事务的理解
猜你喜欢
91.(cesium篇)cesium火箭发射模拟
2020-ViT ICLR
功能测试报告的编写
2020-ViT ICLR
Sonic cloud real machine learning summary 6 - 1.4.1 server and agent deployment
[intelligent QBD risk assessment tool] Shanghai daoning brings you leanqbd introduction, trial and tutorial
【JetCache】JetCache的使用方法与步骤
Is PMP certificate really useful?
Slope compensation
EasyExcel 复杂数据导出
随机推荐
447-哔哩哔哩面经1
固定资产管理子系统报表分为什么大类,包括哪些科目
There is no signal in HDMI in computer games caused by memory, so it crashes
LC501. 二叉搜索树中的众数
Can you get a raise? Analysis on gold content of PMP certificate
Clean up system cache and free memory under Linux
基准环路增益与相位裕度的测量
Spark interview questions
Flume interview questions
2020-ViT ICLR
awoo‘s Favorite Problem(优先队列)
3DE 资源没东西或不对
地图其他篇总目录
# CutefishOS系统~
利用SecureCRTPortable远程连接虚拟机
Measurement of reference loop gain and phase margin
How to write a performance test plan
100年仅6款产品获批,疫苗竞争背后的“佐剂”江湖
Slope compensation
多种智能指针