PyTorch code for visualizing feature maps after training your own network
2022-07-01 22:41:00 【zouxiaolv】
Most feature-map visualization posts on the Internet only work with off-the-shelf torchvision models and cannot visualize the feature maps of a network we trained ourselves. The script below loads a trained pose-estimation model, registers a forward hook on every Conv2d layer, and saves each layer's output feature maps to disk as images.
Reference blog: "Visualizing PyTorch network feature maps" by hello_dear_you on CSDN.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torchvision.utils import make_grid, save_image
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
# from core.function import validate,validate_six_scale
from core.function import validate
from utils.utils import create_logger
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
# import torchvision.models as models
import dataset
import models
from PIL import Image
def parse_args():
    parser = argparse.ArgumentParser(description='Train keypoints network')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        type=str,
                        default='/media/zxl/E/zxl/code/experiments/mpii/hgcpef/hg8_256x256_d256x3_adam_lr2.5e-4.yaml')
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    args = parser.parse_args()
    return args
# model
args = parse_args()
update_config(cfg, args)
logger, final_output_dir, tb_log_dir = create_logger(
    cfg, args.cfg, 'valid')
logger.info(pprint.pformat(args))
logger.info(cfg)
# cudnn related setting
cudnn.benchmark = cfg.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
# get devices
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))
# create model
model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
    cfg, is_train=False
)
# print('*******************',model)
if cfg.TEST.MODEL_FILE:
    logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
    model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
else:
    model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> loading model from {}'.format(model_state_file))
    model.load_state_dict(torch.load(model_state_file))
net = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
net.eval()  # inference mode so BatchNorm/Dropout behave deterministically during visualization
# net = models.vgg16_bn(pretrained=True).cuda()
# image pre-process
transforms_input = transforms.Compose([transforms.Resize((256, 256)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
fImg = Image.open("/media/zxl/E/zxl/000003072.jpg").convert('RGB')
data = transforms_input(fImg).unsqueeze(0).cuda()
# feature image save path
FEATURE_FOLDER = "/media/zxl/E/zxl/code/ffffff/"
if not os.path.exists(FEATURE_FOLDER):
    os.mkdir(FEATURE_FOLDER)
# three global variables used to build the feature-map image names
feature_list = list()
count = 0
idx = 0
def get_image_path_for_hook(module):
    global count
    image_name = feature_list[count] + ".png"
    count += 1
    image_path = os.path.join(FEATURE_FOLDER, image_name)
    return image_path

def hook_func(module, input, output):
    image_path = get_image_path_for_hook(module)
    data = output.clone().detach()
    global idx
    print(idx, "->", data.shape)
    idx += 1
    data = data.permute(1, 0, 2, 3)  # (N, C, H, W) -> (C, N, H, W): each channel becomes one grid tile
    save_image(data, image_path, normalize=False)

# register a forward hook on every Conv2d layer of the trained network
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name)
        feature_list.append(name)
        module.register_forward_hook(hook_func)

# a single forward pass triggers all hooks and writes the feature-map images
out = net(data)
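The same hook-based recipe works for any model, not just this pose network. Below is a minimal, self-contained sketch for trying it without the repo above; it is an illustration under stated assumptions, not part of the original post: a pretrained torchvision resnet18 stands in for the trained network, and "test.jpg" and the "features/" folder are placeholder paths.

# Minimal sketch: dump every Conv2d layer's output feature maps of a torchvision model.
# Assumptions (not from the original post): resnet18 replaces the trained pose network,
# "test.jpg" and "features/" are placeholder paths.
import os
import torch
import torchvision.models as tv_models
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image

# pretrained=True may emit a deprecation warning on newer torchvision versions
model = tv_models.resnet18(pretrained=True).eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = preprocess(Image.open("test.jpg").convert("RGB")).unsqueeze(0)

os.makedirs("features", exist_ok=True)

def make_hook(layer_name):
    # the closure captures the layer name, so no global counter is needed
    def hook(module, inputs, output):
        fmap = output.detach().permute(1, 0, 2, 3)  # (N, C, H, W) -> (C, N, H, W)
        save_image(fmap,
                   os.path.join("features", layer_name.replace('.', '_') + ".png"),
                   normalize=True)
    return hook

for layer_name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        module.register_forward_hook(make_hook(layer_name))

with torch.no_grad():
    model(img)

Two small differences from the script above: the hook factory captures the layer name directly, so the file name does not depend on a global counter staying in sync with the registration order, and the forward pass runs under eval() and torch.no_grad(), which is usually what you want when you only need the feature maps.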