t-SNE: visualizing selected parameters in a network
2022-07-31 15:29:00 【FakeOccupational】
Note: this code mainly shows how to visualize an intermediate feature of a network, or a parameter computed from the network. If you only want to visualize a single simple parameter, you can obtain the matrix with model.weight and feed it directly into the dimensionality-reduction/visualization part of the sectioned code below. See also the TSNE parameter description.
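For instance, a minimal sketch of that shortcut (the nn.Linear layer below is an illustrative stand-in, not a layer from this post's network):

import torch.nn as nn
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

layer = nn.Linear(64, 32)                     # hypothetical layer whose weight we want to look at
weights = layer.weight.detach().numpy()       # shape (32, 64): 32 row vectors of dimension 64
# perplexity must be smaller than the number of rows being embedded
embedded = TSNE(n_components=2, perplexity=5).fit_transform(weights)
plt.scatter(embedded[:, 0], embedded[:, 1])
plt.show()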
Complete code
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
class Network(nn.Module):  # extend the nn.Module class
    def __init__(self):
        super().__init__()  # super class constructor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.batchN1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=(5, 5))
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.batchN2 = nn.BatchNorm1d(num_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):  # implements the forward method (flow of tensors)
        # hidden conv layer
        t = self.conv1(t)
        t = F.max_pool2d(input=t, kernel_size=2, stride=2)
        t = F.relu(t)
        t = self.batchN1(t)
        # hidden conv layer
        t = self.conv2(t)
        t = F.max_pool2d(input=t, kernel_size=2, stride=2)
        t = F.relu(t)
        # flatten
        t = t.reshape(-1, 12 * 4 * 4)
        t = self.fc1(t)
        t = F.relu(t)
        t = self.batchN2(t)
        t = self.fc2(t)
        t = F.relu(t)
        # output
        t = self.out(t)
        return t
cnn_model = Network() # init model
pretrained_dict = cnn_model.state_dict()
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))

    def forward(self, x):
        return x
model = Identity()
model_dict = model.state_dict()
# model.load_state_dict(pretrained_dict) # RuntimeError: Error(s) in loading state_dict for Identity: Unexpected key(s) in state_dict: "batchN1.weight", "batchN1.bias", "batchN1.running_mean",
# 1. filter out unnecessary keys
pretrained_dict = {
    k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)  # load the merged dict (pretrained conv1 entries + the model's own remaining entries)
vector = model.conv1.weight.detach().numpy()[0, 0, :, :]  # one 5x5 kernel -> 5 rows of dimension 5
# perplexity must be smaller than the number of samples (only 5 rows here)
digits_final = TSNE(n_components=2, perplexity=3).fit_transform(vector)
plt.scatter(digits_final[:,0], digits_final[:,1])
plt.show()
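As an alternative to the manual key filtering above, PyTorch's load_state_dict also accepts strict=False, which simply ignores checkpoint keys the target model does not have (a sketch, reusing the cnn_model and Identity classes defined above):

model2 = Identity()
# strict=False skips keys such as batchN1.*, conv2.*, fc1.* that Identity does not define
incompatible = model2.load_state_dict(cnn_model.state_dict(), strict=False)
print(incompatible.unexpected_keys)  # the checkpoint keys that were ignored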
Code split by function
Model handling part
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):  # extend the nn.Module class
    def __init__(self):
        super().__init__()  # super class constructor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.batchN1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=(5, 5))
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.batchN2 = nn.BatchNorm1d(num_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):  # implements the forward method (flow of tensors)
        # hidden conv layer
        t = self.conv1(t)
        t = F.max_pool2d(input=t, kernel_size=2, stride=2)
        t = F.relu(t)
        t = self.batchN1(t)
        # hidden conv layer
        t = self.conv2(t)
        t = F.max_pool2d(input=t, kernel_size=2, stride=2)
        t = F.relu(t)
        # flatten
        t = t.reshape(-1, 12 * 4 * 4)
        t = self.fc1(t)
        t = F.relu(t)
        t = self.batchN2(t)
        t = self.fc2(t)
        t = F.relu(t)
        # output
        t = self.out(t)
        return t
cnn_model = Network() # init model
pretrained_dict = cnn_model.state_dict()
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))

    def forward(self, x):
        return x
model = Identity()
model_dict = model.state_dict()
# model.load_state_dict(pretrained_dict) # RuntimeError: Error(s) in loading state_dict for Identity: Unexpected key(s) in state_dict: "batchN1.weight", "batchN1.bias", "batchN1.running_mean",
# 1. filter out unnecessary keys
pretrained_dict = {
    k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)  # load the merged dict (pretrained conv1 entries + the model's own remaining entries)
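A quick sanity check (not in the original post) to confirm that conv1 really received the pretrained weights:

import torch
print(torch.equal(model.conv1.weight, cnn_model.conv1.weight))  # expected: True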
Dimensionality-reduction and visualization part
import numpy as np
import sklearn  # Import scikit-learn for machine learning functionality
from sklearn.manifold import TSNE
from sklearn.datasets import load_digits # For the UCI ML handwritten digits dataset
import matplotlib  # Matplotlib is a Python plotting library, a numerical-mathematics extension of NumPy
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sb
digits = load_digits()
print(digits.data.shape)  # There are 10 classes (0 to 9) with almost 180 images in each class
# The images are 8x8 and hence 64 pixels (dimensions)
plt.gray();
#Displaying what the standard images look like
for i in range(0, 10):
    plt.matshow(digits.images[i])
plt.show()
X = np.vstack([digits.data[digits.target==i] for i in range(10)]) # Place the arrays of data of each digit on top of each other and store in X
# X = np.random.random([1797, 64])
#Implementing the TSNE Function - ah Scikit learn makes it so easy!
digits_final = TSNE(perplexity=30).fit_transform(X)  # plt.scatter(digits_final[:, 0], digits_final[:, 1])
#Play around with varying the parameters like perplexity, random_state to get different plots
# With the above line, our job is done. But why did we even reduce the dimensions in the first place?
# To visualise it on a graph.
# So, here is a utility function that helps to do a scatter plot of the transformed data
def plot(x, colors):
    palette = np.array(sb.color_palette("hls", 10))  # choose a color palette
    # Create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=40, c=palette[colors.astype(int)])  # use int; np.int is removed in recent NumPy
    # Add text labels
    txts = []
    # for i in range(10):
    #     # Position of each label: the median of the points belonging to class i
    #     xtext, ytext = np.median(x[colors == i, :], axis=0)
    #     txt = ax.text(xtext, ytext, str(i), fontsize=24)
    #     txt.set_path_effects([pe.Stroke(linewidth=5, foreground="w"), pe.Normal()])  # outline the text in white
    #     txts.append(txt)
    return f, ax, txts
Y = np.hstack([digits.target[digits.target==i] for i in range(10)])  # Place the arrays of targets of each digit side by side, continuously, and store in Y
plot(digits_final, Y)
plt.show()
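The comment above suggests playing with perplexity and random_state; a small sketch (reusing X and Y from above) that puts a few perplexity values side by side for comparison:

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, perp in zip(axes, [5, 30, 50]):
    emb = TSNE(n_components=2, perplexity=perp, random_state=0).fit_transform(X)
    ax.scatter(emb[:, 0], emb[:, 1], c=Y, s=10, cmap='tab10')
    ax.set_title('perplexity=%d' % perp)
plt.show()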
Mathematical explanation of t-SNE
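A brief sketch of the standard formulation (van der Maaten & Hinton, 2008): t-SNE turns pairwise distances between high-dimensional points $x_i$ into probabilities $p_{ij}$ using a Gaussian kernel, models pairwise similarities $q_{ij}$ of the low-dimensional embedding points $y_i$ with a Student-t kernel, and moves the $y_i$ to minimize the KL divergence between the two distributions:

$$p_{j|i} = \frac{\exp(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2)}{\sum_{k \neq i} \exp(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2)}, \qquad p_{ij} = \frac{p_{j|i} + p_{i|j}}{2n}$$

$$q_{ij} = \frac{(1 + \lVert y_i - y_j \rVert^2)^{-1}}{\sum_{k \neq l} (1 + \lVert y_k - y_l \rVert^2)^{-1}}, \qquad C = \mathrm{KL}(P \parallel Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}$$

Each bandwidth $\sigma_i$ is tuned so that the perplexity of the conditional distribution $p_{\cdot \mid i}$ matches the user-chosen perplexity parameter, which is why perplexity roughly controls how many neighbors each point pays attention to.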
References and more
Official PyTorch tutorial on saving and loading models (including exporting/loading models in TorchScript format)
Modifying the model state dict: https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2
scikit-learn.org
https://github.com/shivanichander/tSNE/blob/master/Code/tSNE%20Code.ipynb
t-SNE: one of the best dimensionality reduction methods - Zhihu (zhihu.com)
https://discuss.pytorch.org/t/changing-state-dict-value-is-not-changing-model/88695/2
model = nn.Linear(1, 1)
print(model.weight)
# ISOMAP https://scikit-learn.org.cn/view/452.html
from sklearn.manifold import Isomap
digits_final = Isomap(n_components=2).fit_transform(res)  # res: the feature/weight matrix you want to embed
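A self-contained version of that Isomap snippet, using the digits data from the visualization part as a stand-in for res:

from sklearn.datasets import load_digits
from sklearn.manifold import Isomap
import matplotlib.pyplot as plt

digits = load_digits()
emb = Isomap(n_components=2).fit_transform(digits.data)  # (1797, 64) -> (1797, 2)
plt.scatter(emb[:, 0], emb[:, 1], c=digits.target, s=10, cmap='tab10')
plt.show()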