Classifying the Iris dataset with a perceptron implemented in NumPy
2022-07-03 10:38:00 【Machine learning Xiaobai】
import random
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets


def sign(x):
    # Return +1 for non-negative inputs, -1 otherwise
    if x >= 0:
        return 1
    else:
        return -1


def get_loss(dataset, target, w):
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of the misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    print('There are currently {} misclassified points'.format(len(error_index)))
    # Accumulate the loss over the misclassified points
    loss = 0
    for i in error_index:
        loss += results[i] * (predict[i] - target[i])
    return loss


def get_w(dataset, target, w, learning_rate):
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of the misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    if not error_index:
        # Nothing is misclassified, so there is no update to make
        return w
    # Pick one misclassified point at random
    index = random.choice(error_index)
    # Stochastic gradient descent step on that point
    w = w - learning_rate * (predict[index] - target[index]) * dataset[index]
    return w


if __name__ == '__main__':
    # Features: sepal length, sepal width, petal length, petal width
    # Classes 0 and 1 are linearly separable, so only those two classes are kept
    dataset, target = datasets.load_iris()['data'], datasets.load_iris()['target']
    target = target[:100]
    target = np.array([1 if i == 1 else -1 for i in target])
    dataset = dataset[:100]
    # Keep only the first two features (sepal length and sepal width)
    dataset = np.array([np.array(data[:2]) for data in dataset])
    # Prepend a column of ones so that w[0] acts as the bias term
    dataset = np.hstack((np.ones((dataset.shape[0], 1)), dataset))
    w = np.random.rand(len(dataset[0]))
    loss = get_loss(dataset, target, w)
    print(loss)
    epoch = 10000
    learning_rate = 0.1
    for i in range(epoch):
        # Update w
        w = get_w(dataset, target, w, learning_rate)
        # Compute the loss
        loss = get_loss(dataset, target, w)
        print('After {} iterations, loss: {}'.format(i, loss))
        if loss == 0:
            break
    print(w)
    x_point = np.array([i[1] for i in dataset])
    y_point = np.array([i[2] for i in dataset])
    x = np.linspace(4, 7.2, 10)
    # Decision boundary: w[0] + w[1]*x + w[2]*y = 0
    y = (-w[1] * x - w[0]) / w[2]
    plt.scatter(x_point[:50], y_point[:50], label='0')
    plt.scatter(x_point[-50:], y_point[-50:], label='1')
    plt.legend()  # show the legend
    plt.plot(x, y)
    plt.show()
The update above is stochastic gradient descent on the perceptron criterion: each step picks one misclassified point at random and moves w using only that point. For a misclassified point, predict equals -target, so (predict - target) is -2 * target and the update is equivalent to the textbook rule w <- w + 2 * learning_rate * y_i * x_i.
Note that the perceptron's solution is not unique: different random initial weights, or a different order of visiting the misclassified points, produce different separating lines that all classify the training set correctly.
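As a quick illustration of that non-uniqueness, here is a minimal standalone sketch (not the author's script; the helper name train_perceptron and the use of NumPy's default_rng are my own choices) that trains the same two-feature model from two different seeds, using a vectorized prediction step, and prints the resulting weight vectors. Both runs typically end with zero misclassified points, yet the weights, and hence the separating lines, differ.

import numpy as np
from sklearn import datasets

def train_perceptron(seed, epochs=10000, lr=0.1):
    # Same preparation as the script above: classes 0 and 1, the first two
    # features (sepal length and width), and a leading column of ones for the bias.
    iris = datasets.load_iris()
    X = np.hstack((np.ones((100, 1)), iris['data'][:100, :2]))
    y = np.where(iris['target'][:100] == 1, 1, -1)

    rng = np.random.default_rng(seed)
    w = rng.random(X.shape[1])
    for _ in range(epochs):
        # Vectorized prediction: one matrix product scores every sample at once
        predict = np.where(X @ w >= 0, 1, -1)
        wrong = np.where(predict != y)[0]
        if len(wrong) == 0:
            break                              # every point is on the correct side
        i = rng.choice(wrong)                  # one random misclassified point
        w = w + lr * y[i] * X[i]               # textbook perceptron update
    errors = int(np.sum(np.where(X @ w >= 0, 1, -1) != y))
    return w, errors

for seed in (0, 1):
    w, errors = train_perceptron(seed)
    print('seed {}: w = {}, misclassified = {}'.format(seed, np.round(w, 3), errors))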