当前位置:网站首页>pytorch训练自己网络后可视化特征图谱的代码
pytorch训练自己网络后可视化特征图谱的代码
2022-07-01 21:47:00 【zouxiaolv】
网上其他博客介绍的特征图可视化方法，无法用于可视化自己训练的网络的特征图。
参考博客：《可视化pytorch网络特征图》（hello_dear_you，CSDN 博客，关键词：pytorch特征图可视化）
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torchvision.utils import make_grid, save_image
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
# from core.function import validate,validate_six_scale
from core.function import validate
from utils.utils import create_logger
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
# import torchvision.models as models
import dataset
import models
from PIL import Image
def parse_args(argv=None):
    """Parse the command-line options of the feature-map visualisation script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what happens when
            the script is run directly.

    Returns:
        argparse.Namespace with the attributes ``cfg``, ``opts``,
        ``modelDir``, ``logDir``, ``dataDir`` and ``prevModelDir``.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        type=str,
                        default='/media/zxl/E/zxl/code/experiments/mpii/hgcpef/hg8_256x256_d256x3_adam_lr2.5e-4.yaml')
    # REMAINDER: every token from the first positional onwards is collected
    # verbatim into `opts` (config options overridden from the command line).
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    return parser.parse_args(argv)
# ---- Command line, experiment config, logging and cuDNN switches ----------
args = parse_args()
update_config(cfg, args)

# create_logger also returns the output directory used below to locate
# 'final_state.pth' when no TEST.MODEL_FILE is configured.
logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, 'valid')
logger.info(pprint.pformat(args))
logger.info(cfg)

# Copy the cuDNN flags from the experiment config (`cudnn` is the
# torch.backends.cudnn module imported at the top of the file).
cudnn.benchmark = cfg.CUDNN.BENCHMARK
cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
cudnn.enabled = cfg.CUDNN.ENABLED
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# Build the network with the `get_pose_net` factory of the model module named
# in the config. Resolving the (possibly dotted) name with getattr instead of
# eval() on a string built from the config means the config can only select
# an attribute of `models`; it can never execute arbitrary code.
model_module = models
for attr_name in cfg.MODEL.NAME.split('.'):
    model_module = getattr(model_module, attr_name)
model = model_module.get_pose_net(cfg, is_train=False)

if cfg.TEST.MODEL_FILE:
    logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
    # map_location='cpu' lets a checkpoint saved on a GPU be read on a
    # CPU-only machine; the weights are moved to `device` further below.
    state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')
    # strict=False tolerates key mismatches between checkpoint and model, but
    # any parameter that is not loaded keeps its initial value. Report the
    # mismatch so the feature maps are not silently produced by random weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.warning('=> keys missing from checkpoint: {}'.format(
            incompatible.missing_keys))
    if incompatible.unexpected_keys:
        logger.warning('=> unexpected keys in checkpoint: {}'.format(
            incompatible.unexpected_keys))
else:
    model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info('=> loading model from {}'.format(model_state_file))
    model.load_state_dict(torch.load(model_state_file, map_location='cpu'))

# Keep the DataParallel wrapper on every device so module names (and thus the
# names of the saved feature images) are the same on GPU and CPU; .to(device)
# honours the CPU fallback chosen above instead of forcing .cuda().
net = torch.nn.DataParallel(model, device_ids=cfg.GPUS).to(device)
# Image pre-processing: resize to 256x256, convert to a tensor and normalise
# each channel with the commonly used ImageNet mean/std values.
# NOTE(review): assumes this matches the input size and normalisation the
# network was trained with — confirm against the experiment config.
transforms_input = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Open the image inside `with` so the file handle is released as soon as the
# RGB copy has been made.
with Image.open("/media/zxl/E/zxl/000003072.jpg") as img:
    fImg = img.convert('RGB')

# Add a batch dimension and place the input on the same device as the network.
data = transforms_input(fImg).unsqueeze(0).to(device)
# Directory that receives one PNG file per hooked convolution output.
FEATURE_FOLDER = "/media/zxl/E/zxl/code/ffffff/"
# makedirs(exist_ok=True) also creates missing parent directories and cannot
# race with another process creating the folder between check and create.
os.makedirs(FEATURE_FOLDER, exist_ok=True)

# Three global variables shared by the forward hooks to name feature images.
feature_list = []  # names of the hooked Conv2d modules, in registration order
count = 0  # number of image paths handed out by get_image_path_for_hook
idx = 0  # running index printed by hook_func for every saved feature map
# Lazily built map from every module of `net` to its qualified name.
_module_names = {}
# How many image paths have been handed out for each module name so far.
_name_use_counts = {}


def get_image_path_for_hook(module):
    """Return the PNG path under FEATURE_FOLDER for the output of ``module``.

    The file is named after the module's qualified name inside ``net`` (the
    name ``net.named_modules()`` yields for it), found by object identity.
    Looking the name up, instead of taking the next entry of ``feature_list``,
    keeps the label correct when the hooked modules execute in a different
    order than ``named_modules()`` lists them. A module that runs more than
    once gets ``_1``, ``_2``, ... suffixes instead of overwriting its first
    image (and no longer runs past the end of ``feature_list``).

    Args:
        module: the module whose forward hook is firing.

    Returns:
        Path of the PNG file to write.
    """
    global count
    count += 1
    if not _module_names:
        _module_names.update((m, n) for n, m in net.named_modules())
    name = _module_names.get(module)
    if name is None:
        # NOTE(review): `module` is not one of net's own modules (presumably a
        # DataParallel replica on multi-GPU runs); fall back to the
        # registration-order name, or a plain counter once those run out.
        if count <= len(feature_list):
            name = feature_list[count - 1]
        else:
            name = 'module_{}'.format(count)
    uses = _name_use_counts.get(name, 0)
    _name_use_counts[name] = uses + 1
    image_name = name + ".png" if uses == 0 else '{}_{}.png'.format(name, uses)
    return os.path.join(FEATURE_FOLDER, image_name)
def hook_func(module, input, output):
    """Forward hook that saves the output feature maps of ``module`` as a PNG.

    Registered via ``register_forward_hook``, so PyTorch calls it with the
    module, its inputs and its output after each forward call of the module.
    The channels of the first sample in the batch are written as a grid of
    single-channel images to the path from ``get_image_path_for_hook``.

    Args:
        module: the module whose forward just ran; selects the file name.
        input: the module's positional inputs (unused).
        output: the module's output; assumed to be a 4-D (N, C, H, W) tensor
            because only Conv2d modules are hooked — TODO confirm for other
            uses.
    """
    global idx
    image_path = get_image_path_for_hook(module)
    print(idx, "->", output.shape)
    idx += 1
    # (C, 1, H, W): one grey-scale image per channel of the first sample, so
    # batches larger than one do not turn into invalid N-channel images.
    feature_maps = output[0].detach().clone().unsqueeze(1)
    # Raw convolution outputs are not confined to [0, 1]; with
    # normalize=False save_image clamps them, blacking out negative values and
    # saturating large ones. Rescale every channel to its own min-max range.
    save_image(feature_maps, image_path, normalize=True, scale_each=True)
# Register hook_func on every Conv2d so each convolution's output is saved
# during the forward pass; feature_list records the names in
# named_modules() order.
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name)
        feature_list.append(name)
        module.register_forward_hook(hook_func)

# Inference mode: eval() makes any BatchNorm/Dropout layers behave as at test
# time (running statistics instead of statistics of this single image), so
# the saved maps reflect the trained network; no_grad() skips building the
# autograd graph, which visualisation does not need.
net.eval()
with torch.no_grad():
    out = net(data)
边栏推荐
- 高攀不起的希尔排序,直接插入排序
- 园区全光技术选型-中篇
- 多种智能指针
- Which securities company should we choose to open an account for flush stock? Is it safe to open an account with a mobile phone?
- [commercial terminal simulation solution] Shanghai daoning brings you Georgia introduction, trial and tutorial
- Spark interview questions
- 牛客月赛-分组求对数和
- Delete AWS bound credit card account
- 深度学习--数据操作
- 配置筛选机
猜你喜欢
随机推荐
【c语言】malloc函数详解[通俗易懂]
A debugging to understand the slot mechanism of redis cluster
Separate the letters and numbers in the string so that the letters come first and the array comes last
Mysql——》MyISAM存储引擎的索引
RestTemplate 远程调用工具类
Recent public ancestor offline practice (tarjan)
20220701
Redis configuration and optimization
【QT小作】封装一个简单的线程管理类
Slope compensation
灵动微 MM32 多路ADC-DMA配置
awoo‘s Favorite Problem(优先队列)
【juc学习之路第8天】Condition
并发编程系列之FutureTask源码学习笔记
20220701
Mask wearing detection method based on yolov5
Ida dynamic debugging apk
Little p weekly Vol.11
Basic knowledge of ngnix
FFMpeg学习笔记