当前位置:网站首页>Machine learning perceptron and k-nearest neighbor
Machine learning perceptron and k-nearest neighbor
2022-06-24 10:10:00 【Cpsu】
1. Perceptron
import numpy as np
from matplotlib import pyplot as plt
# Seed NumPy's global RNG so the random samples below are reproducible.
np.random.seed(3)

# Shared x coordinates for both classes: np.arange(0, 5, 0.1), one value per
# sample drawn for each class below.  Kept as a plain Python list because
# `x1 + x1` further down relies on list concatenation, not element-wise
# addition.  (The original author notes np.linspace(0, 5, 50) as an
# alternative; it would produce slightly different spacing.)
x1 = list(np.arange(0, 5, 0.1))
# Second coordinate of the positive samples: |standard normal| draws.
x2 = np.abs(np.random.randn(50))
# Second coordinate of the negative samples: |standard normal + 8| draws.
x3 = np.abs(np.random.randn(50) + 8)
plt.scatter(x1, x2)
plt.scatter(x1, x3)

# Assemble the data matrix: one row per sample, columns are
# (feature 1, feature 2, label).
# Feature 1: the x coordinates, repeated once for each class.
x_1 = x1 + x1
# Feature 2: positive samples first, then negative samples.
x_2 = list(x2) + list(x3)
# Labels: +1 for the first 50 rows, -1 for the last 50.
y = [1] * 50 + [-1] * 50
# np.c_ stacks the three sequences as columns.
data = np.c_[x_1, x_2, y]
data.shape  # notebook-style echo; has no effect when run as a script
# (100, 3)
class Perceptron():
    """Linear perceptron trained by repeated passes over misclassified samples.

    :param data: ndarray of shape (N, P + 1); the first P columns are the
        features and the last column is the label.  The misclassification
        test in :meth:`get_wrong` is a sign test, so labels are expected to
        be +1 / -1.
    :param lr: learning rate (step size of each weight update)
    :param maxiter: maximum number of passes over the data
    :param w_vect: initial weight column vector of shape (P + 1, 1); the
        last entry is the bias
    """

    def __init__(self, data, lr, maxiter, w_vect):
        self.data = data
        self.w = w_vect
        self.lr = lr
        self.maxiter = maxiter

    def get_wrong(self):
        """Find the samples the current weights misclassify.

        :return: tuple ``(rows, index)`` where ``rows`` are the misclassified
            rows of ``self.data`` and ``index`` are their row indices
        """
        # Append a constant 1 to every sample so the bias is learned as the
        # last weight.
        features = np.c_[self.data[:, :-1], np.ones(self.data.shape[0])]
        labels = self.data[:, -1].reshape(-1, 1)
        # A sample is wrong when label * score <= 0 (points lying exactly on
        # the hyperplane count as wrong).
        wrong_index = np.where(features.dot(self.w) * labels <= 0)[0]
        # Index this instance's own data, not a module-level global.
        return self.data[wrong_index, :], wrong_index

    def fit(self):
        """Update ``self.w`` in place until no sample is misclassified.

        Stops early when a pass finds no misclassified sample, otherwise
        after ``maxiter`` passes.  Returns ``None``.
        """
        for _ in range(self.maxiter):
            wrong_data, _ = self.get_wrong()
            n_wrong = wrong_data.shape[0]
            if n_wrong == 0:
                break
            # Shuffle whole rows so every feature vector stays paired with
            # its own label.  The misclassified set is fixed for the whole
            # pass, so the visiting order only influences floating-point
            # rounding of the accumulated update.
            wrong_data = wrong_data[np.random.permutation(n_wrong)]
            x = np.c_[wrong_data[:, :-1], np.ones(n_wrong)]
            for i in range(n_wrong):
                # Gradient of the perceptron loss -y * (w . x) w.r.t. w.
                gradient = (-wrong_data[i, -1] * x[i, :]).reshape(-1, 1)
                self.w = self.w - self.lr * gradient
# Initial weights: an all-zero (3, 1) column -- two feature weights plus the
# bias.  (Alternative kept from the original: np.zeros((data.shape[1], 1)).)
w_vect = np.array([[0], [0], [0]])
a = Perceptron(data, lr=0.01, maxiter=200, w_vect=w_vect)
a.fit()
print(a.w)

# Pull the scalar parameters out of the weight column.
weights = a.w
w1, w2, bias = weights[0][0], weights[1][0], weights[-1][0]

# Decision boundary w1 * x + w2 * y + bias = 0, solved for y.
x6 = -w1 / w2 * np.array(x1) - bias / w2

# Both classes plus the learned separating line.
plt.scatter(x1, x2)
plt.scatter(x1, x3)
plt.plot(x1, x6)

2. k-Nearest Neighbors (KNN)
# Build a k-d tree
import numpy as np
import matplotlib.pyplot as plt
class kdTree():
    """One node of a 2-D k-d tree; the root node represents the whole tree.

    Call :meth:`createNextNode` on a root created with ``kdTree(None)`` to
    build the tree, then :meth:`plotKdTree` to draw it.
    """

    def __init__(self, parent_node):
        # Node initialization
        self.nodedata = None  # data point stored at this node (two-dimensional)
        self.split = None  # index of the split axis: 0 = split along x, 1 = split along y
        self.range = None  # split threshold (median value on the split axis)
        self.left = None  # left child node
        self.right = None  # right child node
        self.parent = parent_node  # parent node (None for the root)
        self.leftdata = None  # all points that fall into the left subtree
        self.rightdata = None  # all points that fall into the right subtree
        self.isinvted = False  # "visited" flag; not read anywhere in this class

    def print(self):
        """Print this node's data point, split axis and split threshold."""
        print(self.nodedata, self.split, self.range)

    def getSplitAxis(self, all_data):
        """Return the axis (0 or 1) with the larger variance; ties give 1.

        :param all_data: ndarray of shape (N, 2)
        """
        var_all_data = np.var(all_data, axis=0)
        return 0 if var_all_data[0] > var_all_data[1] else 1

    def getRange(self, split_axis, all_data):
        """Return the median value of ``all_data`` along ``split_axis``.

        For an even number of points the upper of the two middle values is
        returned, so the result is always a value present in the data.
        """
        sorted_values = np.sort(all_data[:, split_axis])
        return sorted_values[sorted_values.shape[0] // 2]

    def getNodeLeftRigthData(self, all_data):
        """Partition ``all_data`` around ``self.range`` on axis ``self.split``.

        Points below the threshold go to ``self.leftdata``; the first point
        equal to the threshold becomes ``self.nodedata``; every other point
        (including further points equal to the threshold) goes to
        ``self.rightdata``.
        """
        # Start from a clean slate so a node can be rebuilt.
        self.nodedata = None
        ls_leftdata = []
        ls_rightdata = []
        for now_data in all_data:
            value = now_data[self.split]
            if value < self.range:
                ls_leftdata.append(now_data)
            # `is None` is required here: `ndarray == None` compares
            # element-wise and cannot be used as a truth value.
            elif self.nodedata is None and value == self.range:
                self.nodedata = now_data
            else:
                ls_rightdata.append(now_data)
        self.leftdata = np.array(ls_leftdata)
        self.rightdata = np.array(ls_rightdata)

    def createNextNode(self, all_data):
        """Recursively build the subtree rooted at this node from ``all_data``.

        :param all_data: ndarray of shape (N, 2); when it is empty a message
            is printed and the node is left untouched
        """
        if all_data.shape[0] == 0:
            print("create kd tree finished!")
            return None
        # Drop children from any previous build of this node.
        self.left = None
        self.right = None
        self.split = self.getSplitAxis(all_data)
        self.range = self.getRange(self.split, all_data)
        self.getNodeLeftRigthData(all_data)
        if self.leftdata.shape[0] != 0:
            self.left = kdTree(self)
            self.left.createNextNode(self.leftdata)
        if self.rightdata.shape[0] != 0:
            self.right = kdTree(self)
            self.right.createNextNode(self.rightdata)

    def plotKdTree(self):
        """Draw the tree with matplotlib: one line and arrow per parent-child link.

        The root opens a new figure with both axes fixed to [0, 10]; each
        node draws the links to its children in a random colour and then
        recurses into them (left first, then right).
        """
        if self.parent is None:
            plt.figure(dpi=300)
            plt.xlim([0.0, 10.0])
            plt.ylim([0.0, 10.0])
        color = np.random.random(3)
        for child in (self.left, self.right):
            if child is None:
                continue
            plt.plot(
                [self.nodedata[0], child.nodedata[0]],
                [self.nodedata[1], child.nodedata[1]],
                '-o',
                color=color,
            )
            # Arrow from this node half-way towards the child to show direction.
            plt.arrow(
                x=self.nodedata[0],
                y=self.nodedata[1],
                dx=(child.nodedata[0] - self.nodedata[0]) / 2.0,
                dy=(child.nodedata[1] - self.nodedata[1]) / 2.0,
                color=color,
                head_width=0.2,
            )
            child.plotKdTree()
        # Optional (disabled in the original): draw the splitting line with
        # plt.vlines(self.range, 0, 10, color=color, linestyles='--') when
        # self.split == 0, or the matching plt.hlines call otherwise.
# 30 random 2-D points with both coordinates scaled into [0, 10).
test_array = np.random.random((30, 2)) * 10.0
# Build the tree from a parentless root node, then draw it.
my_kd_tree = kdTree(parent_node=None)
my_kd_tree.createNextNode(all_data=test_array)
my_kd_tree.plotKdTree()
边栏推荐
- CVPR 2022 Oral | 英伟达提出自适应token的高效视觉Transformer网络A-ViT,不重要的token可以提前停止计算
- 被困英西中学的师生安全和食物有保障
- 美国电子烟巨头 Juul 遭遇灭顶之灾,所有产品强制下架
- Getting user information for applet learning (getuserprofile and getUserInfo)
- Amendment to VPP implementation policy routing
- 411 stack and queue (20. valid parentheses, 1047. delete all adjacent duplicates in the string, 150. inverse Polish expression evaluation, 239. sliding window maximum, 347. the first k high-frequency
- SQL-统计连续N天登陆的用户
- 时尚的弹出模态登录注册窗口
- Top issue tpami 2022! Behavior recognition based on different data modes: a recent review
- 一群骷髅在飞canvas动画js特效
猜你喜欢

美国电子烟巨头 Juul 遭遇灭顶之灾,所有产品强制下架

Three ways to use applicationcontextinitializer

indexedDB本地存储,首页优化

有关二叉树 的基本操作

canvas无限扫描js特效代码

队列Queue

ByteDance Interviewer: talk about the principle of audio and video synchronization. Can audio and video be absolutely synchronized?

Mise en œuvre du rendu de liste et du rendu conditionnel pour l'apprentissage des applets Wechat.

Geogebra instance clock

Phpstrom code formatting settings
随机推荐
415 binary tree (144. preorder traversal of binary tree, 145. postorder traversal of binary tree, 94. inorder traversal of binary tree)
整理接口性能优化技巧,干掉慢代码
Can the long-term financial products you buy be shortened?
读取csv(tsv)文件出错
Record the range of data that MySQL update will lock
413 binary tree Foundation
oracle池式连接请求超时问题排查步骤
涂鸦智能携多款重磅智能照明解决方案,亮相2022美国国际照明展
Recursive traversal of 414 binary tree
p5.js实现的炫酷交互式动画js特效
dedecms模板文件讲解以及首页标签替换
indexedDB本地存储,首页优化
静态链接库和动态链接库的区别
SQL Server AVG函数取整问题
小程序学习之获取用户信息(getUserProfile and getUserInfo)
正规方程、、、
二叉树第一部分
SQL-统计连续N天登陆的用户
一群骷髅在飞canvas动画js特效
学习整理在php中使用KindEditor富文本编辑器