当前位置:网站首页>pytorch训练自己网络后可视化特征图谱的代码
pytorch训练自己网络后可视化特征图谱的代码
2022-07-01 21:47:00 【zouxiaolv】
网上其他博客给出的特征图可视化方法无法直接用于自己训练的网络；下面的代码在加载自己训练好的网络后，为其中每个卷积层的输出特征图生成可视化图片。
参考博客:可视化pytorch网络特征图_hello_dear_you的博客-CSDN博客_pytorch特征图可视化
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torchvision.utils import make_grid, save_image
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
# from core.function import validate,validate_six_scale
from core.function import validate
from utils.utils import create_logger
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
# import torchvision.models as models
import dataset
import models
from PIL import Image
def parse_args(argv=None):
    """Parse command-line arguments for the feature-map visualisation script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before; passing a list lets
            the function be called from tests or from other scripts.

    Returns:
        argparse.Namespace with the attributes ``cfg``, ``opts``,
        ``modelDir``, ``logDir``, ``dataDir`` and ``prevModelDir``.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        type=str,
                        default='/media/zxl/E/zxl/code/experiments/mpii/hgcpef/hg8_256x256_d256x3_adam_lr2.5e-4.yaml')
    # REMAINDER positional: every argument left after the recognised options
    # is collected verbatim into ``opts``, so such overrides must come last.
    # Per the help text they modify config options (presumably applied by
    # update_config(cfg, args) -- confirm there).
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    return parser.parse_args(argv)
# ---------------------------------------------------------------------------
# Model setup: parse the command line / config, configure cuDNN, build the
# network and load the trained weights.
# ---------------------------------------------------------------------------
args = parse_args()
update_config(cfg, args)
logger, final_output_dir, tb_log_dir = create_logger(
    cfg, args.cfg, 'valid')
logger.info(pprint.pformat(args))
logger.info(cfg)

# cudnn related setting
cudnn.benchmark = cfg.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

# get devices
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# create model
# NOTE(review): eval() on a name taken from the config file -- only run this
# with trusted configs.
model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
    cfg, is_train=False
)

# Load the trained weights onto the CPU first (map_location='cpu'), so a
# checkpoint saved on a GPU also loads on a CPU-only machine; the model is
# moved to `device` below.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
if cfg.TEST.MODEL_FILE:
    logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
    load_result = model.load_state_dict(
        torch.load(cfg.TEST.MODEL_FILE, map_location='cpu'), strict=False)
    # strict=False silently skips keys that do not match (for example a
    # 'module.' prefix left by DataParallel). Report them, otherwise an
    # effectively untrained network could be visualised without any warning.
    if load_result.missing_keys:
        logger.warning('=> keys missing from checkpoint: {}'.format(
            load_result.missing_keys))
    if load_result.unexpected_keys:
        logger.warning('=> unexpected keys in checkpoint: {}'.format(
            load_result.unexpected_keys))
else:
    model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> loading model from {}'.format(model_state_file))
    model.load_state_dict(torch.load(model_state_file, map_location='cpu'))

# Move to the device selected above instead of calling .cuda()
# unconditionally, so the script also runs on hosts without CUDA.
net = torch.nn.DataParallel(model, device_ids=cfg.GPUS).to(device)
# Image pre-processing: resize to the 256x256 network input, convert to a
# tensor and normalise with the usual ImageNet channel mean/std.
transforms_input = transforms.Compose([transforms.Resize((256, 256)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                            std=[0.229, 0.224, 0.225])])
# convert('RGB') returns a new image, so the file handle can be closed here.
with Image.open("/media/zxl/E/zxl/000003072.jpg") as img:
    fImg = img.convert('RGB')
# Add a batch dimension and put the input on the same device as the model.
data = transforms_input(fImg).unsqueeze(0).to(device)

# Feature image save path; created (including missing parents) if needed.
FEATURE_FOLDER = "/media/zxl/E/zxl/code/ffffff/"
os.makedirs(FEATURE_FOLDER, exist_ok=True)
# Bookkeeping shared by the forward hooks:
#   feature_list   -- qualified names of the hooked Conv2d layers, in
#                     named_modules() order
#   count          -- number of feature images written so far
#   idx            -- running index printed for every hook call
#   saved_per_name -- layer name -> images already saved for it, so that a
#                     layer run more than once gets distinct file names
feature_list = list()
count = 0
idx = 0
saved_per_name = {}


def get_image_path_for_hook(module, name=None):
    """Return the PNG path for the current output of a hooked layer.

    Args:
        module: the layer whose forward hook fired (kept for compatibility).
        name: qualified layer name bound to the hook at registration time.
            If None, fall back to the old behaviour of taking the next entry
            of ``feature_list``, which assumes hooks fire in registration
            order.

    Returns:
        ``FEATURE_FOLDER/<name>.png`` on the first call for a layer, and
        ``FEATURE_FOLDER/<name>_<k>.png`` on its k-th repeated call.
    """
    global count
    if name is None:
        name = feature_list[count]
    repeat = saved_per_name.get(name, 0)
    saved_per_name[name] = repeat + 1
    count += 1
    stem = name if repeat == 0 else "{}_{}".format(name, repeat)
    return os.path.join(FEATURE_FOLDER, stem + ".png")


def hook_func(module, input, output, name=None):
    """Forward hook: save every output channel of a layer as one grid image.

    Args:
        module: the layer whose forward just ran.
        input: the layer's inputs (unused).
        output: the layer's output; assumed 4-D (N, C, H, W), as Conv2d
            produces for the batched input used here.
        name: qualified layer name used for the file name (see
            get_image_path_for_hook).

    Returns None, so the layer's output is left unchanged.
    """
    global idx
    image_path = get_image_path_for_hook(module, name)
    # Private copy, so writing the image can never touch the activation that
    # flows on through the network.
    data = output.detach().clone()
    print(idx, "->", data.shape)
    idx += 1
    # Swap batch and channel axes: each channel becomes its own
    # single-channel tile in the grid written by save_image.
    data = data.permute(1, 0, 2, 3)
    # Raw conv activations are not limited to [0, 1]. Without normalisation
    # save_image clamps them, so most of each map saturates to black or
    # white. scale_each rescales every channel to [0, 1] on its own.
    save_image(data, image_path, normalize=True, scale_each=True)


def _make_named_hook(name):
    """Return a forward hook that saves its layer's output under ``name``."""
    def _hook(module, input, output):
        # Must return None: a non-None return would replace the output.
        hook_func(module, input, output, name)
    return _hook


# Bind each layer's name into its own hook. File names then stay correct even
# when the forward pass runs layers in a different order than named_modules()
# lists them, or calls a layer more than once.
hook_handles = []
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name)
        feature_list.append(name)
        hook_handles.append(module.register_forward_hook(_make_named_hook(name)))

# Inference mode: modules start in training mode, so without eval()
# BatchNorm/Dropout would behave as in training (BatchNorm using this single
# image's statistics). no_grad() skips building an unused autograd graph.
net.eval()
with torch.no_grad():
    out = net(data)

# The hooks were only needed for this single pass.
for handle in hook_handles:
    handle.remove()
边栏推荐
猜你喜欢
【图像分割】2021-SegFormer NeurIPS
MySQL MHA high availability configuration and failover
EasyExcel 复杂数据导出
三翼鸟两周年:羽翼渐丰,腾飞指日可待
GenICam GenTL 标准 ver1.5(4)第五章 采集引擎
Learning notes on futuretask source code of concurrent programming series
比较版本号[双指针截取自己想要的字串]
In the past 100 years, only 6 products have been approved, which is the "adjuvant" behind the vaccine competition
Sonic cloud real machine learning summary 6 - 1.4.1 server and agent deployment
leetcode - 287. 寻找重复数
随机推荐
【MySQL】索引的创建、查看和删除
【日常训练】326. 3 的幂
使用 Three.js 实现'雪糕'地球,让地球也凉爽一夏
【图像分割】2021-SegFormer NeurIPS
C#/VB.NET 给PDF文档添加文本/图像水印
[ecological partner] Kunpeng system engineer training
flink sql-client 使用 对照并熟悉官方文档
[jetcache] how to use jetcache
[intelligent QBD risk assessment tool] Shanghai daoning brings you leanqbd introduction, trial and tutorial
首席信息官对高绩效IT团队定义的探讨和分析
完全注解的ssm框架搭建
详解Volatile关键字
Copy ‘XXXX‘ to effectively final temp variable
【MySQL】索引的分类
Flume interview questions
JVM有哪些类加载机制?
2020-ViT ICLR
并发编程系列之FutureTask源码学习笔记
LC669. 修剪二叉搜索树
Indicator trap: seven KPI mistakes that it leaders are prone to make