Current location: Home > t-SNE data visualization of selected network parameters
t-SNE data visualization of selected network parameters
2022-07-31 15:38:00 【FakeOccupational】
Note: This code visualizes an intermediate feature of a network, or network parameters computed from it. If you only want to visualize a single parameter of the network, you can use `model.weight` to get the matrix and then feed it to the dimensionality-reduction visualization part of the code. (See the TSNE parameter reference.)
Complete code
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
class Network(nn.Module):
    """Small LeNet-style CNN that produces 10 output scores per sample.

    The hard-coded flatten size ``12 * 4 * 4`` is consistent with a
    single-channel 28x28 input: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Registration order is significant: it fixes the
        # state_dict key order and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.batchN1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=(5, 5))
        # Fully connected head.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.batchN2 = nn.BatchNorm1d(num_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Map a batch of images ``t`` to unnormalised class scores."""
        # Stage 1: conv -> 2x2 max-pool -> ReLU -> batch norm.
        h = self.batchN1(F.relu(F.max_pool2d(self.conv1(t), kernel_size=2, stride=2)))
        # Stage 2: conv -> 2x2 max-pool -> ReLU.
        h = F.relu(F.max_pool2d(self.conv2(h), kernel_size=2, stride=2))
        # Flatten to rows of 192 features for the dense layers.
        h = h.reshape(-1, 12 * 4 * 4)
        h = self.batchN2(F.relu(self.fc1(h)))
        h = F.relu(self.fc2(h))
        return self.out(h)
# Source network whose parameters will be partially copied (randomly
# initialised here; a trained model would be loaded in practice).
cnn_model = Network()
pretrained_dict = cnn_model.state_dict()
class Identity(nn.Module):
    """Pass-through module that owns only a ``conv1`` layer.

    Used as a target into which the ``conv1.*`` entries of another network's
    state dict can be loaded; ``forward`` does not use the layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))

    def forward(self, x):
        """Return the input unchanged."""
        return x
model = Identity()
model_dict = model.state_dict()
# Loading the source dict directly fails because it has extra keys:
# model.load_state_dict(pretrained_dict)  # RuntimeError: Error(s) in loading state_dict for Identity: Unexpected key(s) in state_dict: "batchN1.weight", "batchN1.bias", "batchN1.running_mean", ...
# 1. Keep only the entries whose keys also exist in the target model.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. Overwrite the matching entries of the target's own state dict.
model_dict.update(pretrained_dict)
# 3. Load the merged dict (not the filtered one): parameters the target has
#    but the source lacks then keep their current values instead of causing a
#    "Missing key(s)" error under strict loading.
model.load_state_dict(model_dict)

# Take the first input-channel slice of the first conv1 filter (a 5x5 kernel);
# each of its rows is treated as one t-SNE sample.
vector = model.conv1.weight.detach().numpy()[0, 0, :, :]
# t-SNE requires perplexity < n_samples. The kernel has only 5 rows, so the
# usual default of 30 must be capped or scikit-learn raises ValueError.
perplexity = min(30, vector.shape[0] - 1)
digits_final = TSNE(perplexity=perplexity).fit_transform(vector)
plt.scatter(digits_final[:, 0], digits_final[:, 1])
plt.show()
Code split by function
Model-processing part
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
    """Small LeNet-style CNN that produces 10 output scores per sample.

    The hard-coded flatten size ``12 * 4 * 4`` is consistent with a
    single-channel 28x28 input: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Registration order is significant: it fixes the
        # state_dict key order and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.batchN1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=(5, 5))
        # Fully connected head.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.batchN2 = nn.BatchNorm1d(num_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Map a batch of images ``t`` to unnormalised class scores."""
        # Stage 1: conv -> 2x2 max-pool -> ReLU -> batch norm.
        h = self.batchN1(F.relu(F.max_pool2d(self.conv1(t), kernel_size=2, stride=2)))
        # Stage 2: conv -> 2x2 max-pool -> ReLU.
        h = F.relu(F.max_pool2d(self.conv2(h), kernel_size=2, stride=2))
        # Flatten to rows of 192 features for the dense layers.
        h = h.reshape(-1, 12 * 4 * 4)
        h = self.batchN2(F.relu(self.fc1(h)))
        h = F.relu(self.fc2(h))
        return self.out(h)
# Source network whose parameters will be partially copied (randomly
# initialised here; a trained model would be loaded in practice).
cnn_model = Network()
pretrained_dict = cnn_model.state_dict()
class Identity(nn.Module):
    """Pass-through module that owns only a ``conv1`` layer.

    Used as a target into which the ``conv1.*`` entries of another network's
    state dict can be loaded; ``forward`` does not use the layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))

    def forward(self, x):
        """Return the input unchanged."""
        return x
model = Identity()
model_dict = model.state_dict()
# Loading the source dict directly fails because it has extra keys:
# model.load_state_dict(pretrained_dict)  # RuntimeError: Error(s) in loading state_dict for Identity: Unexpected key(s) in state_dict: "batchN1.weight", "batchN1.bias", "batchN1.running_mean", ...
# 1. Keep only the entries whose keys also exist in the target model.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. Overwrite the matching entries of the target's own state dict.
model_dict.update(pretrained_dict)
# 3. Load the merged dict (not the filtered one): parameters the target has
#    but the source lacks then keep their current values instead of causing a
#    "Missing key(s)" error under strict loading.
model.load_state_dict(model_dict)
Dimensionality-reduction visualization part
import numpy as np
import sklearn #Import scikitlearn for machine learning functionalities
from sklearn.manifold import TSNE
from sklearn.datasets import load_digits # For the UCI ML handwritten digits dataset
import matplotlib # Matplotlib 是 Python 中的一个库,它是 NumPy Numerical math extensions to the library
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sb
digits = load_digits()
# 10 classes (digits 0-9) with roughly 180 images each; every image is 8x8,
# i.e. 64 pixel features per sample.
print(digits.data.shape)

# Show the first ten sample images in grayscale.
plt.gray()
for image in digits.images[:10]:
    plt.matshow(image)
plt.show()

# Stack the samples grouped by class: all 0s first, then all 1s, and so on.
X = np.vstack([digits.data[digits.target == label] for label in range(10)])
# X = np.random.random([1797, 64])

# Embed the 64-dimensional samples in 2-D with t-SNE so they can be plotted.
# Varying perplexity / random_state gives different layouts.
digits_final = TSNE(perplexity=30).fit_transform(X)
def plot(x, colors, *, annotate=False):
    """Scatter-plot a 2-D embedding, coloured by integer class label.

    Args:
        x: array whose first two columns are the point coordinates
            (e.g. a t-SNE output of shape (n_samples, 2)).
        colors: integer class labels, one per row of ``x``. They index a
            10-colour "hls" palette, so labels are expected to be in 0..9.
        annotate: if True, draw each class label at the median position of
            that class's points (white-outlined text).

    Returns:
        ``(figure, axes, texts)``. ``texts`` holds the label artists and is
        empty unless ``annotate`` is True.
    """
    colors = np.asarray(colors)
    palette = np.array(sb.color_palette("hls", 10))  # one RGB row per class
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is equivalent.
    ax.scatter(x[:, 0], x[:, 1], lw=0, s=40, c=palette[colors.astype(int)])
    txts = []
    if annotate:
        # Only label classes that actually occur, to avoid medians of empty slices.
        for label in np.unique(colors):
            xtext, ytext = np.median(x[colors == label, :], axis=0)
            txt = ax.text(xtext, ytext, str(label), fontsize=24)
            txt.set_path_effects([pe.Stroke(linewidth=5, foreground="w"), pe.Normal()])
            txts.append(txt)
    return f, ax, txts
# Class labels in the same class-grouped order as the rows of X / digits_final.
Y = np.concatenate([digits.target[digits.target == label] for label in range(10)])
plot(digits_final, Y)
plt.show()
t-SNE: mathematical interpretation for data visualization
References and further reading
PyTorch official tutorial on saving and loading models (includes exporting/loading models in TorchScript format)
Modifying the model's state dict (loading part of a pre-trained model): https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2
scikit-learn.org
https://github.com/shivanichander/tSNE/blob/master/Code/tSNE%20Code.ipynb
t-SNE: one of the best dimensionality-reduction methods - Zhihu (zhihu.com)
https://discuss.pytorch.org/t/changing-state-dict-value-is-not-changing-model/88695/2
# Read a layer's parameter tensor directly through its ``weight`` attribute.
model = nn.Linear(in_features=1, out_features=1)
print(model.weight)
# ISOMAP https://scikit-learn.org.cn/view/452.html
# Alternative manifold learner to t-SNE for the same 2-D visualisation.
from sklearn.manifold import Isomap

# Embed the class-grouped digits feature matrix ``X`` from the visualisation
# part (``res`` was never defined); any (n_samples, n_features) array works.
digits_final = Isomap(n_components=2).fit_transform(X)
Sidebar recommendations
- Snake Project (Simple)
- Why don't you make a confession during the graduation season?
- org.apache.jasperException(could not initialize class org)
- 基于ABP实现DDD
- Linux check redis version (check mongodb version)
- 女性服务社群产品设计
- Oracle dynamically registers non-1521 ports
- "Autumn Recruitment Series" MySQL Interview Core 25 Questions (with answers)
- 多主复制下处理写冲突(4)-多主复制拓扑
- what exactly is json (c# json)
You may also like
Random recommendations
WeChat chat record search in a red envelope
双边滤波加速「建议收藏」
Oracle dynamically registers non-1521 ports
苹果官网样式调整 结账时产品图片“巨大化”
Dialogue with Zhuang Biaowei: The first lesson of open source
TRACE32 - Common Operations
Use of radiobutton
ES6 类
Why don't you make a confession during the graduation season?
Why is the field of hacking almost filled with boys?
多主复制下处理写冲突(4)-多主复制拓扑
The normal form of the database (first normal form, second normal form, third normal form, BCNF normal form) "recommended collection"
Emmet 语法
[Meetup Preview] OpenMLDB+OneFlow: Link feature engineering to model training to accelerate machine learning model development
MySQL数据库操作
org.apache.jasperException(could not initialize class org)
Replication Latency Case (3) - Monotonic Read
mongo进入报错
Bilateral filtering acceleration "recommended collection"
TRACE32——常用操作