Implementing a Perceptron in NumPy to Classify the Iris Dataset
2022-07-03 10:38:00 【Machine learning Xiaobai】
```python
import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets


def sign(x):
    # Perceptron activation: +1 for x >= 0, otherwise -1
    if x >= 0:
        return 1
    else:
        return -1


def get_loss(dataset, target, w):
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    print('{} misclassified points remain'.format(len(error_index)))
    # Loss: sum over misclassified points of (w.x) * (predict - target)
    loss = 0
    for i in error_index:
        loss += results[i] * (predict[i] - target[i])
    # Also return the error count: the loss can be 0 while errors remain
    # (e.g. when w.x == 0 for a misclassified point)
    return loss, len(error_index)


def get_w(dataset, target, w, learning_rate):
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    if not error_index:
        return w  # nothing left to correct; random.choice([]) would raise
    # Pick one misclassified point at random
    index = random.choice(error_index)
    # Stochastic gradient descent step on that single point
    w = w - learning_rate * (predict[index] - target[index]) * dataset[index]
    return w


if __name__ == '__main__':
    # Features: sepal length, sepal width, petal length, petal width.
    # Classes 0 and 1 are linearly separable, so keep only those two
    # (they are the first 100 samples).
    iris = datasets.load_iris()
    dataset, target = iris['data'], iris['target']
    target = target[:100]
    # Relabel: class 1 -> +1, class 0 -> -1
    target = np.array([1 if i == 1 else -1 for i in target])
    # Keep only sepal length and sepal width so the result can be plotted in 2-D
    dataset = dataset[:100, :2]
    # Prepend a column of ones so that w[0] acts as the bias
    dataset = np.hstack((np.ones((dataset.shape[0], 1)), dataset))
    w = np.random.rand(dataset.shape[1])
    loss, n_errors = get_loss(dataset, target, w)
    print(loss)
    epoch = 10000
    learning_rate = 0.1
    for i in range(epoch):
        if n_errors == 0:
            break  # every training point is classified correctly
        # Update w
        w = get_w(dataset, target, w, learning_rate)
        # Recompute the loss
        loss, n_errors = get_loss(dataset, target, w)
        print('After iteration {}, loss: {}'.format(i, loss))
    print(w)
    # x-axis: sepal length, y-axis: sepal width
    x_point = dataset[:, 1]
    y_point = dataset[:, 2]
    # Decision boundary w[0] + w[1]*x + w[2]*y = 0, solved for y
    x = np.linspace(4, 7.2, 10)
    y = (-w[1] * x - w[0]) / w[2]
    plt.scatter(x_point[:50], y_point[:50], label='0')
    plt.scatter(x_point[-50:], y_point[-50:], label='1')
    plt.legend()  # show legend
    plt.plot(x, y)
    plt.show()
```
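The list comprehensions in `get_loss` and `get_w` above score one sample at a time. Below is a minimal vectorized sketch of the same loss and update, written with NumPy array operations. It assumes `X` already contains the leading bias column and `y` holds ±1 labels. The helper names `perceptron_loss` and `perceptron_step` are hypothetical, the tiny dataset is made up for illustration, and `np.random.default_rng` replaces the `random` module.

```python
import numpy as np


def perceptron_loss(X, y, w):
    """Return (loss, n_errors) using the same formula as get_loss.

    X: (n, d) samples with a leading column of ones; y: (n,) labels in {-1, +1}.
    """
    scores = X @ w                           # w.x for every sample at once
    predict = np.where(scores >= 0, 1, -1)   # same rule as sign()
    wrong = predict != y                     # mask of misclassified points
    loss = np.sum(scores[wrong] * (predict[wrong] - y[wrong]))
    return loss, int(wrong.sum())


def perceptron_step(X, y, w, learning_rate, rng):
    """One stochastic update on a randomly chosen misclassified point."""
    predict = np.where(X @ w >= 0, 1, -1)
    error_index = np.flatnonzero(predict != y)
    if error_index.size == 0:
        return w                             # already separates the data
    i = rng.choice(error_index)
    return w - learning_rate * (predict[i] - y[i]) * X[i]


# Usage on a tiny hand-made, linearly separable dataset (hypothetical values)
X = np.array([[1.0, 2.0, 1.0],
              [1.0, 3.0, 2.0],
              [1.0, -1.0, -2.0],
              [1.0, -2.0, -1.0]])
y = np.array([1, 1, -1, -1])
rng = np.random.default_rng(0)
w = np.zeros(3)
for _ in range(100):
    loss, n_errors = perceptron_loss(X, y, w)
    if n_errors == 0:
        break
    w = perceptron_step(X, y, w, 0.1, rng)
print(w, perceptron_loss(X, y, w))
```

Starting from `w = 0` shows why the loop tests the error count rather than the loss. Every score is 0, so every point is predicted +1. The two negative points are misclassified, yet their contribution to the loss is 0.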
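The weight update in `get_w` is stochastic gradient descent. Instead of using every misclassified point, each step picks one of them at random and moves `w` toward classifying that point correctly.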
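The derivation below uses notation introduced here for illustration, not taken from the original:

- $x_i$ is a sample including the leading 1.
- $y_i \in \{-1, +1\}$ is its label.
- $\hat y_i = \operatorname{sign}(w^\top x_i)$ is the prediction.
- $M$ is the set of misclassified indices.
- $\eta$ is `learning_rate`.

With these symbols, the quantity accumulated in `get_loss` and the step taken in `get_w` work out as follows. The gradient treats $\hat y_i$ as constant, which holds away from the decision boundary.

$$
L(w) = \sum_{i \in M} (w^\top x_i)(\hat y_i - y_i)
     = -2 \sum_{i \in M} y_i \, w^\top x_i \;\ge\; 0,
\qquad \text{since } \hat y_i = -y_i \text{ for } i \in M,
$$

$$
\nabla_w \big[(w^\top x_i)(\hat y_i - y_i)\big] = (\hat y_i - y_i)\, x_i
\quad\Longrightarrow\quad
w \leftarrow w - \eta\,(\hat y_i - y_i)\, x_i = w + 2\eta\, y_i\, x_i .
$$

So $L$ is twice the usual perceptron criterion $-\sum_{i \in M} y_i\, w^\top x_i$. Each step is the classic perceptron update $w \leftarrow w + \eta' y_i x_i$ with $\eta' = 2\eta$.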
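Because the initial `w` is random and each step picks a misclassified point at random, the separating line the perceptron ends up with is not unique. Different runs can produce different lines that all classify the training data correctly.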
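To see this non-uniqueness concretely, the sketch below trains the same random-misclassified-point perceptron twice with different seeds. The helper `train_perceptron` is hypothetical: it reimplements the loop above with NumPy's `Generator`. The two Gaussian clusters are synthetic and stand in for the iris data, so scikit-learn is not needed. Both runs should end with zero training errors, while the learned lines differ.

```python
import numpy as np


def train_perceptron(X, y, learning_rate=0.1, max_iter=10000, seed=0):
    """Perceptron trained by updating on one random misclassified point per step.

    Returns the learned weights and the number of remaining training errors.
    """
    rng = np.random.default_rng(seed)
    w = rng.random(X.shape[1])              # random start, like np.random.rand
    for _ in range(max_iter):
        predict = np.where(X @ w >= 0, 1, -1)
        error_index = np.flatnonzero(predict != y)
        if error_index.size == 0:
            break
        i = rng.choice(error_index)
        w = w - learning_rate * (predict[i] - y[i]) * X[i]
    n_errors = int(np.sum(np.where(X @ w >= 0, 1, -1) != y))
    return w, n_errors


# Two well-separated Gaussian clusters (hypothetical data) plus a bias column
data_rng = np.random.default_rng(42)
pos = data_rng.normal(loc=[2.0, 2.0], scale=0.3, size=(20, 2))
neg = data_rng.normal(loc=[-2.0, -2.0], scale=0.3, size=(20, 2))
X = np.hstack((np.ones((40, 1)), np.vstack((pos, neg))))
y = np.array([1] * 20 + [-1] * 20)

w_a, err_a = train_perceptron(X, y, seed=1)
w_b, err_b = train_perceptron(X, y, seed=2)

# Scaling w by a positive constant leaves the line unchanged, so compare
# unit-length directions rather than raw weights
dir_a = w_a / np.linalg.norm(w_a)
dir_b = w_b / np.linalg.norm(w_b)
print(err_a, err_b)
print(dir_a, dir_b)
```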