当前位置:网站首页>SOM网络2: 代码的实现
SOM网络2: 代码的实现
2022-08-01 21:37:00 【@BangBang】
SOM自组织映射神经网络的原理,详见博客:SOM网络1:原理讲解
训练的主函数
train_SO代码如下:
def train_SOM(X, # 输出节点行数
Y, # 输出节点列数
N_epoch, # epoch
datas, # 训练数据(N x D) N个D维样本
init_lr=0.5, # 初始化学习率 lr
sigma = 0.5, # 初始化 sigma 用来更新领域节点权重
dis_func = euclidean_distance, # 距离公式 默认欧拉距离
neighborhood_func = gaussion_neighborhood, # 邻域节点权重公式g 默认高斯函数
init_weight_fun=None, #初始化权重函数
seed=10):
# 获取输入的特征维度
N,D =np.shape(datas)
# 训练的步数
N_steps =N_epoch*N
#对权重进行初始化
rng = np.random.RandomState(seed)
if init_weight_fun is None:
weights =rng.rand(X,Y,D)*2-1 #随机初始化
weights /=np.linalg.norm(weights,axis=-1,keepdims=True) #标准化
else:
weights = init_weight_fun(X,Y,datas) # 一般使用PCA初始化
PCA 初始化权重
def weights_PCA(X,Y,data):
N,D=np.shape(data)
weights=np.zeros([X,Y,D])
pc_value,pc=np.linalg.eig(np.conv(np.transpose(data))) # pc_vale为特征值,pc 为特征向量 DXD维
pc_order=np.argsort(-pc_value) # 特征值从大到小排序,并返回Index
# 对W:[X,Y,D]进行初始化
for i,c1 in enumerate(np.linspace(-1,1,X)):
for j,c2 in enumerate(np.linsapce(-1,1,Y)):
weights[i,j]=c1*pc[pc_order[0]]+c2*pc[pc_order[1]] #利用最大的2个特征值对应的特征向量加权组合成i,j位置的D维表征向量
完整的训练代码
def train_SOM(X, # 输出节点行数
Y, # 输出节点列数
N_epoch, # epoch
datas, # 训练数据(N x D) N个D维样本
init_lr=0.5, # 初始化学习率 lr
sigma = 0.5, # 初始化 sigma 用来更新领域节点权重
dis_func = euclidean_distance, # 距离公式 默认欧拉距离
neighborhood_func = gaussion_neighborhood, # 邻域节点权重公式g 默认高斯函数
init_weight_func=weights_PCA, #初始化权重函数
seed=10):
# 获取输入的特征维度
N,D =np.shape(datas)
# 训练的步数
N_steps =N_epoch*N
#对权重进行初始化
rng = np.random.RandomState(seed)
if init_weight_func is None:
weights =rng.rand(X,Y,D)*2-1 #随机初始化
weights /=np.linalg.norm(weights,axis=-1,keepdims=True) #标准化
else:
weights = init_weight_fun(X,Y,datas) # 一般使用PCA初始化
for n_epoch in range(N_epoch):
print("Epoch %d" %(n_epoch+1))
#打乱样本次序
index=rng.permulation(np.arange(N))
for n_step,_id in enumerate(index):
# 取一个样本
x=datas[_id]
#计算learning rate (eta)
t=N*n_epoch + n_step
eta=get_learning_rate(init_lr,t,N_steps)
#计算样本距离输出的每个节点的距离,并获取激活点的位置
winner=get_winner_index(x,weights,dis_func)
#根据激活点的位置计算临近点的权重 随着迭代的进行sigma也需要不断减少
new_sigma=get_learning_rate(sigma,t,N_steps) # sigma 更新的方式和学习率一样
g=neighborhood_fun(X,Y,winner,new_sigma)
g=g*eta
#进行权重的更新
weights = weights + np.expand_dims(g,-1)*(x-weights)
# 打印量化误差
print("quantization_error=%.4f" %(get_quantization_error(data,weights)))
return weights
#计算学习率
def get_learning_rate(lr,t,max_steps): # t当前的steps max_steps=N x epoch (N样本数)
return lr/(1+t/(max_steps/2))
# 获取激活(获胜点)节点的位置,与x距离最小的输出节点位置
def get_winner_index(x,w,dis_func=euclidean_distance):
# 计算输入样本和各个节点的距离
dis = dis_func(x,w)
#找到距离最小的位置
index=np.where(dis ==np.min(dis))
return (index[0][0],index[1][0])
#利用高斯距离法计算临近点的权重
# X,Y模板大小,c中心点的位置
def gaussion_neighborhood(X,Y,c,sigma)
xx,yy=np.meshgrid(np.arange(X),np.arange(Y))
d=2*sigma*sigma
ax=np.exp(-np.power(xx-xx.T[c],2)/d)
ay=np.exp(-np.power(yy-yy.T[c],2)/d)
return (ax*ay).T
# 计算欧式距离
def euclidean_distance(x,w):
dis=np.expand_dims(x,axis=(0,1))-w # x:D w:[X,Y,D] 因此需要增加两维 x:D->x:[1,1,D]
return np.linalg.norm(dis,axis=-1) # 输出[X,Y] 二范数 即为欧拉距离
# 特征标准化 (x-mu)/std
def feature_normalization(data):
mu=np.mean(data,axis=0,keepdims=True)
sigma=np.std(data,axis=0,keepdims=True)
return (data-mu)/sigma
def get_U_Matrix(weights):
X,Y,D=np.shape(weights)
um=na.nan * np.zeros((X,Y,8)) #8 领域
ii=[0 ,-1,-1,-1,0,1,1, 1]
jj=[-1,-1, 0, 1,1,1,0,-1]
for x in range(X):
for y in range(Y):
w_2=weights[x,y]
for k,(i,j) in enumerate(zip(ii,jj)):
if(x+i >=0 and x+i<X and y+j>=0 and y+j <Y):
w_1=weights[x+i,y+j]
um[x,y,k]=np.linalg.norm(w_1-w_2)
um=np.nansum(um,axis=2)
return um/um.max()
#计算量化误差 计算每个样本点和映射点之间的平均距离
def get_quantization_error(data,weights):
w_x,w_y=zip(*[get_winner_index(d,weights) for d in datas])
error=datas-weights[w_x,w_y] # 数据域聚类中心的距离
error=np.linalg.norm(error,axis=-1)
return np.mean(error)
训练完成后,返回输出节点的weights,维度为 [ X , Y , D ] [X,Y,D] [X,Y,D], 相当于固化了模型的权重weights, weights表征了当前的训练样本。
测试
if __name__ == "__main__":
# seed 数据展示
columns=['area','perimeter','compactness','length_kernel','width_kernel',
'asymmetry_coefficient','length_kernel_groove','target']
data = pd.read_csv('seeds_dataset.txt',names=columns,sep='\t+',engine='python')
labs=data['target'].values
lab_names={
1:'Kama',2:'Rosa',3:'Canadian'}
datas=data[data.columns[:-1]].values
N,D=np.shape(datas)
print(N,D)
# 对训练数据进行标准化
datas = feature_normalization(datas)
#SOM的训练
weights=train_SOM()X=9,Y=9,N_epoch=2,datas=datas,sigma=1.5,init_weight_func=weights_PCA)
# 获取UMAP 用于可视化
UM=get_U_Matrix(weights)
plt.figure(figure=(9,9))
plt.pcolor(UM.T,cmap='bone_r') #plotting the distance map as background
plt.colorbar()
测试数据
U_Matrix
- 颜色越深说明与邻近点的关系越强烈,颜色越强说明与邻近点的关系越不强烈。
测试分类的效果
```python
if __name__ == "__main__":
# seed 数据展示
columns=['area','perimeter','compactness','length_kernel','width_kernel',
'asymmetry_coefficient','length_kernel_groove','target']
data = pd.read_csv('seeds_dataset.txt',names=columns,sep='\t+',engine='python')
labs=data['target'].values
lab_names={
1:'Kama',2:'Rosa',3:'Canadian'}
datas=data[data.columns[:-1]].values
N,D=np.shape(datas)
print(N,D)
# 对训练数据进行标准化
datas = feature_normalization(datas)
#SOM的训练
weights=train_SOM()X=9,Y=9,N_epoch=2,datas=datas,sigma=1.5,init_weight_func=weights_PCA)
# 获取UMAP 用于可视化
UM=get_U_Matrix(weights)
plt.figure(figure=(9,9))
plt.pcolor(UM.T,cmap='bone_r') #plotting the distance map as background
plt.colorbar()
# 查看分类的效果
markers=['o','s','D']
colors =['C0','C1','C2']
for i in range(N):
x =datas[i]
w=get_winner_index(x,weights)
i_lab=labs[i]-1
plt.plot(w[0]+.5,w[1]+.5,markers[i_lab],markerfacecolor='None'
markeredgecolor=colors[i_lab],markersize=12,markeredgewidth=2)
plt.show()

边栏推荐
- ARFoundation Getting Started Tutorial U2-AR Scene Screenshot Screenshot
- shell specification and variables
- ARFoundation入门教程U2-AR场景截图截屏
- 【力扣】字符串相乘
- 365天挑战LeetCode1000题——Day 046 生成每种字符都是奇数个的字符串 + 两数相加 + 有效的括号
- ORI-GB-NP半乳糖介导冬凌草甲素/姜黄素牛血清白蛋白纳米粒的研究制备方法
- Small program -- subcontracting
- Chapter 12, target recognition of digital image processing
- 虚拟内存与物理内存之间的关系
- JVM内存结构详解
猜你喜欢
随机推荐
基于php旅游网站管理系统获取(php毕业设计)
shell programming conventions and variables
KMP 字符串匹配问题
The thing about npm
0DFS中等 LeetCode6134. 找到离给定两个节点最近的节点
Shell编程之条件语句
Chapter 12, target recognition of digital image processing
ARFoundation Getting Started Tutorial U2-AR Scene Screenshot Screenshot
Unity Shader general lighting model code finishing
方舟生存进化是什么游戏?好不好玩
教你VSCode如何快速对齐代码、格式化代码
软考 ----- UML设计与分析(上)
网络水军第一课:手写自动弹幕
scikit-learn no moudule named six
P7215 [JOISC2020] 首都 题解
File operations of WEB penetration
方舟:生存进化官服和私服区别
shell编程规范与变量
小程序--分包
上传markdown文档到博客园









