当前位置:网站首页>label propagation 标签传播
label propagation 标签传播
2022-07-02 06:26:00 【想搞钱的小陈】
最近在研究时间序列的半监督算法,看到这个算法,就记录了下来。
转载自:标签传播算法(Label Propagation)及Python实现_zouxy09的专栏-CSDN博客_标签传播算法
半监督学习(Semi-supervised learning)发挥作用的场合是:你的数据有一些有label,一些没有。而且一般是绝大部分都没有,只有少许几个有label。半监督学习算法会充分的利用unlabeled数据来捕捉我们整个数据的潜在分布。它基于三大假设:
1)Smoothness平滑假设:相似的数据具有相同的label。
2)Cluster聚类假设:处于同一个聚类下的数据具有相同label。
3)Manifold流形假设:处于同一流形结构下的数据具有相同label。
标签传播算法(label propagation)的核心思想非常简单:相似的数据应该具有相同的label。LP算法包括两大步骤:1)构造相似矩阵(affinity matrix);2)勇敢的传播吧。
label propagation是一种基于图的算法。图是基于顶点和边组成的,每个顶点是一个样本,所有的顶点包括了有标签样本和无标签样本;边代表了顶点i到顶点j的概率,换句话说就是顶点i到顶点j的相似度。
这里,α是超参。
还有个非常常用的图构建方法是knn图,也就是只保留每个节点的k近邻权重,其他的为0,也就是不存在边,因此是稀疏的相似矩阵。
标签传播算法非常简单:通过节点之间的边传播label。边的权重越大,表示两个节点越相似,那么label越容易传播过去。我们定义一个NxN的概率转移矩阵P:
Pij表示从节点i转移到节点j的概率。假设有C个类和L个labeled样本,我们定义一个LxC的label矩阵YL,第i行表示第i个样本的标签指示向量,即如果第i个样本的类别是j,那么该行的第j个元素为1,其他为0。同样,我们也给U个unlabeled样本一个UxC的label矩阵YU。把他们合并,我们得到一个NxC的soft label矩阵F=[YL;YU]。soft label的意思是,我们保留样本i属于每个类别的概率,而不是互斥性的,这个样本以概率1只属于一个类。当然了,最后确定这个样本i的类别的时候,是取max也就是概率最大的那个类作为它的类别的。那F里面有个YU,它一开始是不知道的,那最开始的值是多少?无所谓,随便设置一个值就可以了。
千呼万唤始出来,简单的LP算法如下:
1)执行传播:F=PF
2)重置F中labeled样本的标签:FL=YL
3)重复步骤1)和2)直到F收敛。
步骤1)就是将矩阵P和矩阵F相乘,这一步,每个节点都将自己的label以P确定的概率传播给其他节点。如果两个节点越相似(在欧式空间中距离越近),那么对方的label就越容易被自己的label赋予,就是更容易拉帮结派。步骤2)非常关键,因为labeled数据的label是事先确定的,它不能被带跑,所以每次传播完,它都得回归它本来的label。随着labeled数据不断的将自己的label传播出去,最后的类边界会穿越高密度区域,而停留在低密度的间隔中。相当于每个不同类别的labeled样本划分了势力范围。
2.3、变身的LP算法
我们知道,我们每次迭代都是计算一个soft label矩阵F=[YL;YU],但是YL是已知的,计算它没有什么用,在步骤2)的时候,还得把它弄回来。我们关心的只是YU,那我们能不能只计算YU呢?Yes。我们将矩阵P做以下划分:
这时候,我们的算法就一个运算:
迭代上面这个步骤直到收敛就ok了,是不是很cool。可以看到FU不但取决于labeled数据的标签及其转移概率,还取决了unlabeled数据的当前label和转移概率。因此LP算法能额外运用unlabeled数据的分布特点。
这个算法的收敛性也非常容易证明,具体见参考文献[1]。实际上,它是可以收敛到一个凸解的:
所以我们也可以直接这样求解,以获得最终的YU。但是在实际的应用过程中,由于矩阵求逆需要O(n3)的复杂度,所以如果unlabeled数据非常多,那么I – PUU矩阵的求逆将会非常耗时,因此这时候一般选择迭代算法来实现。
#***************************************************************************
#*
#* Description: label propagation
#* Author: Zou Xiaoyi ([email protected])
#* Date: 2015-10-15
#* HomePage: http://blog.csdn.net/zouxy09
#*
#**************************************************************************
import time
import numpy as np
# return k neighbors index
def navie_knn(dataSet, query, k):
numSamples = dataSet.shape[0]
## step 1: calculate Euclidean distance
diff = np.tile(query, (numSamples, 1)) - dataSet
squaredDiff = diff ** 2
squaredDist = np.sum(squaredDiff, axis = 1) # sum is performed by row
## step 2: sort the distance
sortedDistIndices = np.argsort(squaredDist)
if k > len(sortedDistIndices):
k = len(sortedDistIndices)
return sortedDistIndices[0:k]
# build a big graph (normalized weight matrix)
def buildGraph(MatX, kernel_type, rbf_sigma = None, knn_num_neighbors = None):
num_samples = MatX.shape[0]
affinity_matrix = np.zeros((num_samples, num_samples), np.float32)
if kernel_type == 'rbf':
if rbf_sigma == None:
raise ValueError('You should input a sigma of rbf kernel!')
for i in xrange(num_samples):
row_sum = 0.0
for j in xrange(num_samples):
diff = MatX[i, :] - MatX[j, :]
affinity_matrix[i][j] = np.exp(sum(diff**2) / (-2.0 * rbf_sigma**2))
row_sum += affinity_matrix[i][j]
affinity_matrix[i][:] /= row_sum
elif kernel_type == 'knn':
if knn_num_neighbors == None:
raise ValueError('You should input a k of knn kernel!')
for i in xrange(num_samples):
k_neighbors = navie_knn(MatX, MatX[i, :], knn_num_neighbors)
affinity_matrix[i][k_neighbors] = 1.0 / knn_num_neighbors
else:
raise NameError('Not support kernel type! You can use knn or rbf!')
return affinity_matrix
# label propagation
def labelPropagation(Mat_Label, Mat_Unlabel, labels, kernel_type = 'rbf', rbf_sigma = 1.5, \
knn_num_neighbors = 10, max_iter = 500, tol = 1e-3):
# initialize
num_label_samples = Mat_Label.shape[0]
num_unlabel_samples = Mat_Unlabel.shape[0]
num_samples = num_label_samples + num_unlabel_samples
labels_list = np.unique(labels)
num_classes = len(labels_list)
MatX = np.vstack((Mat_Label, Mat_Unlabel))
clamp_data_label = np.zeros((num_label_samples, num_classes), np.float32)
for i in xrange(num_label_samples):
clamp_data_label[i][labels[i]] = 1.0
label_function = np.zeros((num_samples, num_classes), np.float32)
label_function[0 : num_label_samples] = clamp_data_label
label_function[num_label_samples : num_samples] = -1
# graph construction
affinity_matrix = buildGraph(MatX, kernel_type, rbf_sigma, knn_num_neighbors)
# start to propagation
iter = 0; pre_label_function = np.zeros((num_samples, num_classes), np.float32)
changed = np.abs(pre_label_function - label_function).sum()
while iter < max_iter and changed > tol:
if iter % 1 == 0:
print "---> Iteration %d/%d, changed: %f" % (iter, max_iter, changed)
pre_label_function = label_function
iter += 1
# propagation
label_function = np.dot(affinity_matrix, label_function)
# clamp
label_function[0 : num_label_samples] = clamp_data_label
# check converge
changed = np.abs(pre_label_function - label_function).sum()
# get terminate label of unlabeled data
unlabel_data_labels = np.zeros(num_unlabel_samples)
for i in xrange(num_unlabel_samples):
unlabel_data_labels[i] = np.argmax(label_function[i+num_label_samples])
return unlabel_data_labels
边栏推荐
- 华为机试题-20190417
- A summary of a middle-aged programmer's study of modern Chinese history
- Conversion of numerical amount into capital figures in PHP
- 实现接口 Interface Iterable&lt;T&gt;
- 超时停靠视频生成
- [paper introduction] r-drop: regulated dropout for neural networks
- [introduction to information retrieval] Chapter 7 scoring calculation in search system
- 程序的执行
- Two table Association of pyspark in idea2020 (field names are the same)
- Three principles of architecture design
猜你喜欢
SSM laboratory equipment management
SSM student achievement information management system
架构设计三原则
机器学习理论学习:感知机
SSM supermarket order management system
生成模型与判别模型的区别与理解
常见的机器学习相关评价指标
基于pytorch的YOLOv5单张图片检测实现
view的绘制机制(一)
[medical] participants to medical ontologies: Content Selection for Clinical Abstract Summarization
随机推荐
Pratique et réflexion sur l'entrepôt de données hors ligne et le développement Bi
Agile development of software development pattern (scrum)
Conversion of numerical amount into capital figures in PHP
【信息检索导论】第二章 词项词典与倒排记录表
使用Matlab实现:弦截法、二分法、CG法,求零点、解方程
sparksql数据倾斜那些事儿
自然辩证辨析题整理
【论文介绍】R-Drop: Regularized Dropout for Neural Networks
view的绘制机制(一)
ERNIE1.0 与 ERNIE2.0 论文解读
Two dimensional array de duplication in PHP
Sparksql data skew
Implement interface Iterable & lt; T&gt;
程序的内存模型
view的绘制机制(二)
基于pytorch的YOLOv5单张图片检测实现
【信息检索导论】第三章 容错式检索
Implementation of purchase, sales and inventory system with ssm+mysql
Alpha Beta Pruning in Adversarial Search
Interpretation of ernie1.0 and ernie2.0 papers