当前位置:网站首页>PyTorch code for visualizing the feature maps of your own trained network
PyTorch code for visualizing the feature maps of your own trained network
2022-07-01 22:41:00 【zouxiaolv】
Other blog posts online about visualizing feature maps do not work for a network that you have trained yourself; the script below does.
Reference blog: "Visualizing PyTorch network feature maps" by hello_dear_you (CSDN blog).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torchvision.utils import make_grid, save_image
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
# from core.function import validate,validate_six_scale
from core.function import validate
from utils.utils import create_logger
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
# import torchvision.models as models
import dataset
import models
from PIL import Image
def parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``cfg`` (path of the
        experiment YAML file), ``opts`` (every remaining command-line token,
        intended as config overrides), and the string directories
        ``modelDir``, ``logDir``, ``dataDir`` and ``prevModelDir``
        (each defaulting to the empty string).
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    parser.add_argument(
        '--cfg',
        type=str,
        default='/media/zxl/E/zxl/code/experiments/mpii/hgcpef/hg8_256x256_d256x3_adam_lr2.5e-4.yaml',
        help='experiment configure file name',
    )
    # REMAINDER: everything not claimed by an option is collected here.
    parser.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    # The remaining options are all plain string directories defaulting to ''.
    directory_options = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_options:
        parser.add_argument(flag, help=description, type=str, default='')
    return parser.parse_args()
# Parse the command line and merge it into the global experiment config.
args = parse_args()
update_config(cfg, args)

# Logger plus output/tensorboard directories for a 'valid' run; record the
# parsed arguments and the final config.
logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, 'valid')
for message in (pprint.pformat(args), cfg):
    logger.info(message)

# cudnn related setting: copied from the CUDNN section of the config, in the
# same order (benchmark, deterministic, enabled).
for flag, value in (('benchmark', cfg.CUDNN.BENCHMARK),
                    ('deterministic', cfg.CUDNN.DETERMINISTIC),
                    ('enabled', cfg.CUDNN.ENABLED)):
    setattr(cudnn, flag, value)

# get devices: first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("using {} device.".format(device))
# create model
# Resolve the factory with getattr instead of eval(): the module name comes
# from the config file, and attribute lookup cannot execute arbitrary code.
model = getattr(models, cfg.MODEL.NAME).get_pose_net(
    cfg, is_train=False
)
# print('*******************',model)
if cfg.TEST.MODEL_FILE:
    logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
    # map_location lets a checkpoint saved from GPU tensors load on `device`.
    load_result = model.load_state_dict(
        torch.load(cfg.TEST.MODEL_FILE, map_location=device), strict=False)
    # strict=False skips non-matching keys without complaint (e.g. a
    # "module." prefix left by a DataParallel checkpoint), which would leave
    # those layers at their initial weights and visualize an untrained net.
    # Report them so such a mismatch is noticed.
    missing = getattr(load_result, 'missing_keys', [])
    unexpected = getattr(load_result, 'unexpected_keys', [])
    if missing or unexpected:
        logger.warning(
            '=> %d missing / %d unexpected keys while loading %s; '
            'missing: %s; unexpected: %s',
            len(missing), len(unexpected), cfg.TEST.MODEL_FILE,
            missing, unexpected)
else:
    model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> loading model from {}'.format(model_state_file))
    model.load_state_dict(torch.load(model_state_file, map_location=device))
# Place the wrapped model on the device selected above rather than
# hard-coding .cuda().
net = torch.nn.DataParallel(model, device_ids=cfg.GPUS).to(device)
# net = models.vgg16_bn(pretrained=True).cuda()
# image pre-process: resize to 256x256, convert to a tensor and normalize with
# the mean/std below (the common ImageNet statistics -- presumably what the
# network was trained with; confirm against the training pipeline).
transforms_input = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Open inside a with-block so the file handle is released; convert() returns
# a new, fully loaded image, so it stays usable after the file is closed.
with Image.open("/media/zxl/E/zxl/000003072.jpg") as raw_image:
    fImg = raw_image.convert('RGB')
# Add a batch dimension and put the input on the same device as the model.
data = transforms_input(fImg).unsqueeze(0).to(device)
# feature image save path
FEATURE_FOLDER = "/media/zxl/E/zxl/code/ffffff/"
# makedirs(exist_ok=True) also creates missing parent directories and is a
# no-op when the folder already exists (os.mkdir failed on missing parents).
os.makedirs(FEATURE_FOLDER, exist_ok=True)
# Three global variables used to name and number the saved feature images:
#   feature_list -- qualified names of the hooked layers, in registration order
#   count        -- how many image paths have been handed out so far
#                   (also the next positional index into feature_list)
#   idx          -- running number printed for each hook invocation
feature_list = []
count = idx = 0
def get_image_path_for_hook(module):
    """Return the PNG path under ``FEATURE_FOLDER`` for ``module``'s output.

    The file is named after ``module``'s qualified name in ``net``, found by
    object identity among ``net.named_modules()``. The name format is the same
    one the hook-registration loop stores in ``feature_list``. Taking
    ``feature_list[count]`` positionally assumed that hooks fire in
    registration order. They actually fire in forward-execution order, which
    can differ, and then images get the wrong layer names.

    If ``module`` is not found in ``net``, the old positional name
    ``feature_list[count]`` is used. This can happen when DataParallel runs
    the forward on replicated copies of the modules.

    Increments the global ``count`` on every call.

    Args:
        module: the module whose forward hook is running.

    Returns:
        ``os.path.join(FEATURE_FOLDER, <layer name> + ".png")``.
    """
    global count
    # Linear scan over net's modules per call; cheap for a single forward pass.
    layer_name = next(
        (qualname for qualname, candidate in net.named_modules()
         if candidate is module),
        None,
    )
    if layer_name is None:
        layer_name = feature_list[count]
    count += 1
    return os.path.join(FEATURE_FOLDER, layer_name + ".png")
def hook_func(module, input, output):
    """Forward hook: print the output shape and save the feature map as a PNG grid.

    The path comes from ``get_image_path_for_hook(module)``. The output is
    copied and detached, its shape is printed next to the running ``idx``
    counter (which is then incremented), and dims 0 and 1 are swapped so
    ``save_image`` lays out one single-channel tile per channel.

    Assumes ``output`` is a 4-D (N, C, H, W) tensor, as produced by the
    Conv2d layers this hook is registered on.

    NOTE(review): with ``normalize=False``, ``save_image`` clamps values to
    [0, 1], so negative or large activations saturate.
    """
    global idx
    target_path = get_image_path_for_hook(module)
    feature = output.clone().detach()
    print(idx, "->", feature.shape)
    idx = idx + 1
    # (N, C, H, W) -> (C, N, H, W): every channel becomes its own grid tile.
    per_channel = feature.data.permute(1, 0, 2, 3)
    save_image(per_channel, target_path, normalize=False)
# Register a forward hook on every Conv2d layer and remember each layer's
# qualified name (used to name its saved feature image).
hook_handles = []
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name)
        feature_list.append(name)
        hook_handles.append(module.register_forward_hook(hook_func))

# Run a single inference pass. Modules default to training mode, so without
# eval() BatchNorm would normalize with the statistics of this one image (and
# update its running stats) and dropout would be active. no_grad() skips
# building an autograd graph the hooks never use.
net.eval()
try:
    with torch.no_grad():
        out = net(data)
finally:
    # Detach the hooks so any later forward pass does not write more images.
    for handle in hook_handles:
        handle.remove()
边栏推荐
- 详解LockSupport的使用
- Copy ‘XXXX’ to effectively final temp variable
- Measurement of reference loop gain and phase margin
- PyTorch磨刀篇|argmax和argmin函数
- Why must digital transformation strategies include continuous testing?
- [jetcache] how to use jetcache
- Learning notes on futuretask source code of concurrent programming series
- 利用SecureCRTPortable远程连接虚拟机
- Operation category read is not supported in state standby
- Mysql——》MyISAM存储引擎的索引
猜你喜欢
随机推荐
【日常训练】66. 加一
JVM有哪些类加载机制?
旅游管理系统
【MySQL】数据库优化方法
Learning notes on futuretask source code of concurrent programming series
比较版本号[双指针截取自己想要的字串]
Measurement of reference loop gain and phase margin
分享一个一年经历两次裁员的程序员的一些感触
Relationship and difference between enterprise architecture and project management
台积电全球员工薪酬中位数约46万,CEO约899万;苹果上调日本的 iPhone 售价 ;Vim 9.0 发布|极客头条
基准环路增益与相位裕度的测量
H5 model trained by keras to tflite
并发编程系列之FutureTask源码学习笔记
Slope compensation
MySQL的视图练习题
陈天奇的机器学习编译课(免费)
深度学习--数据操作
三翼鸟两周年:羽翼渐丰,腾飞指日可待
详解Volatile关键字
cvpr2022 human pose estimation









