Classifying iris with a perceptron in NumPy
2022-07-03 10:38:00 【Machine learning Xiaobai】
```python
import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets


def sign(x):
    # Perceptron activation: +1 if x >= 0, otherwise -1
    if x >= 0:
        return 1
    else:
        return -1


def get_loss(dataset, target, w):
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    print('{} misclassified points remain'.format(len(error_index)))
    # Loss: sum of w·x * (prediction - label) over the misclassified points
    loss = 0
    for i in error_index:
        loss += results[i] * (predict[i] - target[i])
    return loss


def get_w(dataset, target, w, learning_rate):
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    # random.choice raises on an empty list, so stop updating once every point is correct
    if not error_index:
        return w
    # Pick one misclassified point at random
    index = random.choice(error_index)
    # Stochastic gradient descent step on that single point
    w = w - learning_rate * (predict[index] - target[index]) * dataset[index]
    return w


if __name__ == '__main__':
    # Features: sepal length, sepal width, petal length, petal width.
    # Classes 0 and 1 (the first 100 samples) are linearly separable, so keep only those two.
    iris = datasets.load_iris()
    dataset, target = iris['data'], iris['target']
    target = np.array([1 if i == 1 else -1 for i in target[:100]])
    # Keep the first two features: sepal length and sepal width
    dataset = dataset[:100, :2]
    # Prepend a column of ones so that w[0] acts as the bias
    dataset = np.hstack((np.ones((dataset.shape[0], 1)), dataset))
    w = np.random.rand(len(dataset[0]))
    loss = get_loss(dataset, target, w)
    print(loss)
    epoch = 10000
    learning_rate = 0.1
    for i in range(epoch):
        # Update w
        w = get_w(dataset, target, w, learning_rate)
        # Recompute the loss
        loss = get_loss(dataset, target, w)
        print('after {} iterations, loss: {}'.format(i + 1, loss))
        if loss == 0:
            break
    print(w)
    x_point = dataset[:, 1]
    y_point = dataset[:, 2]
    # Decision boundary: w[0] + w[1] * x + w[2] * y = 0
    x = np.linspace(4, 7.2, 10)
    y = (-w[1] * x - w[0]) / w[2]
    plt.scatter(x_point[:50], y_point[:50], label='0')
    plt.scatter(x_point[-50:], y_point[-50:], label='1')
    plt.xlabel('sepal length (cm)')
    plt.ylabel('sepal width (cm)')
    plt.legend()  # show legend
    plt.plot(x, y)
    plt.show()
```

This code uses stochastic gradient descent: each call to get_w updates w using a single, randomly chosen misclassified point.
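The loss in get_loss and the update in get_w follow from the usual perceptron criterion. A short derivation in the code's own terms (an illustrative restatement, not taken from the post): each $x_i$ already includes the leading 1, so $w_0$ is the bias; $y_i \in \{-1, +1\}$ is the label, $\eta$ the learning rate, and $M$ the set of misclassified indices.

$$
\hat{y}_i = \operatorname{sign}(w^\top x_i), \qquad
L(w) = \sum_{i \in M} (\hat{y}_i - y_i)\, w^\top x_i
$$

For $i \in M$ we have $\hat{y}_i = -y_i$, hence $\hat{y}_i - y_i = -2y_i$ and

$$
L(w) = -2 \sum_{i \in M} y_i\, w^\top x_i \;\ge\; 0 .
$$

Treating $\hat{y}_i$ as fixed (it is piecewise constant in $w$), the gradient of one term and the resulting step are

$$
\nabla_w \big[(\hat{y}_i - y_i)\, w^\top x_i\big] = (\hat{y}_i - y_i)\, x_i
\quad\Rightarrow\quad
w \leftarrow w - \eta\,(\hat{y}_i - y_i)\, x_i = w + 2\eta\, y_i\, x_i .
$$

So get_loss returns twice the textbook perceptron loss $-\sum_{i \in M} y_i\, w^\top x_i$, and the factor of 2 only rescales the effective learning rate.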
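The perceptron's solution is not unique: because both the initial w and the choice of misclassified point are random, different runs can end with different separating lines.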
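That non-uniqueness can be checked directly. Below is a minimal vectorized sketch, assuming only NumPy: `train_perceptron` and `count_errors` are hypothetical helpers that apply the same random-misclassified-point update as get_w, and the six toy points are made-up stand-ins for the iris data. Two runs with different seeds both separate the data, yet the normalized weight vectors (that is, the separating lines) differ.

```python
import numpy as np


def train_perceptron(X, y, learning_rate=0.1, max_steps=10_000, seed=0):
    """Stochastic perceptron training; X must already contain a leading column of ones."""
    rng = np.random.default_rng(seed)
    w = rng.random(X.shape[1])                  # random start, as in the post
    for _ in range(max_steps):
        predict = np.where(X @ w >= 0, 1, -1)   # same convention as sign(): sign(0) = +1
        errors = np.flatnonzero(predict != y)   # indices of misclassified points
        if errors.size == 0:
            break
        i = rng.choice(errors)                  # one random misclassified point
        w = w - learning_rate * (predict[i] - y[i]) * X[i]
    return w


def count_errors(X, y, w):
    """Number of points whose predicted label differs from y."""
    return int(np.sum(np.where(X @ w >= 0, 1, -1) != y))


# Hypothetical toy data: two well-separated clusters (not the iris data)
X_raw = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.0],
                  [5.0, 5.0], [6.0, 5.5], [5.5, 6.5]])
y = np.array([-1, -1, -1, 1, 1, 1])
X = np.hstack((np.ones((X_raw.shape[0], 1)), X_raw))

if __name__ == "__main__":
    for seed in (0, 1):
        w = train_perceptron(X, y, seed=seed)
        # Both runs reach 0 errors, but the direction of w (the separating line) differs
        print(seed, w / np.linalg.norm(w), count_errors(X, y, w))
```

Vectorizing with `X @ w` replaces the per-sample `np.matmul` list comprehension; the update rule itself is unchanged.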