Classifying Iris with a Perceptron in NumPy
2022-07-03 10:38:00 【Machine learning Xiaobai】
import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets


def sign(x):
    # Sign function; sign(0) is treated as +1
    if x >= 0:
        return 1
    else:
        return -1

def get_loss(dataset, target, w):
    # Scores w·x for every sample (dataset already contains a bias column)
    results = dataset @ w
    # Predicted labels in {-1, +1}
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    print('{} misclassified points remain'.format(len(error_index)))
    # Perceptron loss: for a misclassified point predict - target = -2 * target,
    # so every term equals -2 * target * (w·x) and is non-negative
    loss = 0
    for i in error_index:
        loss += results[i] * (predict[i] - target[i])
    return loss

def get_w(dataset, target, w, learning_rate):
    results = dataset @ w
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    # Nothing to correct: random.choice would fail on an empty list
    if not error_index:
        return w
    # Stochastic gradient descent: pick one misclassified point at random
    index = random.choice(error_index)
    # predict - target = -2 * target here, so this is w + 2 * learning_rate * target * x
    w = w - learning_rate * (predict[index] - target[index]) * dataset[index]
    return w

if __name__ == '__main__':
    # Features: sepal length, sepal width, petal length, petal width.
    # Classes 0 (setosa) and 1 (versicolor) are linearly separable, so keep the
    # first 100 samples (these two classes) and the first two features.
    iris = datasets.load_iris()
    dataset, target = iris.data[:100, :2], iris.target[:100]
    # Relabel: class 1 -> +1, class 0 -> -1
    target = np.where(target == 1, 1, -1)
    # Prepend a column of ones so that w[0] acts as the bias
    dataset = np.hstack((np.ones((dataset.shape[0], 1)), dataset))
    w = np.random.rand(dataset.shape[1])
    loss = get_loss(dataset, target, w)
    print('Initial loss: {}'.format(loss))
    epoch = 10000
    learning_rate = 0.1
    for i in range(epoch):
        # Stop as soon as no point is misclassified
        if loss == 0:
            break
        # Update w
        w = get_w(dataset, target, w, learning_rate)
        # Recompute the loss
        loss = get_loss(dataset, target, w)
        print('After {} iterations, loss: {}'.format(i + 1, loss))
    print(w)
    # Column 0 is the bias; columns 1 and 2 are sepal length and sepal width
    x_point = dataset[:, 1]
    y_point = dataset[:, 2]
    # Decision boundary w0 + w1*x + w2*y = 0, i.e. y = (-w1*x - w0) / w2
    x = np.linspace(4, 7.2, 10)
    y = (-w[1] * x - w[0]) / w[2]
    plt.scatter(x_point[:50], y_point[:50], label='0')
    plt.scatter(x_point[-50:], y_point[-50:], label='1')
    plt.legend()  # show legend
    plt.plot(x, y)
    plt.show()
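The list comprehensions in get_loss and get_w can also be written as array operations. The sketch below is a vectorized variant, not the author's code; the names predict_labels, perceptron_loss and sgd_step are hypothetical. It keeps the article's convention sign(0) = +1, returns the same value as get_loss (without the print), and leaves w unchanged when nothing is misclassified.

```python
import numpy as np


def predict_labels(X, w):
    # sign(X @ w) for every row, with sign(0) = +1 as in the article's sign()
    return np.where(X @ w >= 0, 1, -1)


def perceptron_loss(X, y, w):
    # Same value get_loss returns: sum of score * (prediction - label)
    # over the misclassified points (no printing)
    scores = X @ w
    pred = np.where(scores >= 0, 1, -1)
    wrong = pred != y
    return float(np.sum(scores[wrong] * (pred[wrong] - y[wrong])))


def sgd_step(X, y, w, learning_rate, rng):
    # One update on a randomly chosen misclassified point; w is returned
    # unchanged when every point is already classified correctly
    wrong = np.flatnonzero(predict_labels(X, w) != y)
    if wrong.size == 0:
        return w
    i = rng.choice(wrong)
    # For a misclassified point prediction - label = -2 * label, so this equals
    # the article's w - learning_rate * (predict - target) * x
    return w + 2 * learning_rate * y[i] * X[i]


# Tiny hand-made example (bias column first), not the iris data
X = np.array([[1.0, 1.0, 1.0], [1.0, 2.0, 0.0], [1.0, 0.0, 2.0]])
y = np.array([1, -1, 1])
w = np.array([0.0, 1.0, -1.0])
print(perceptron_loss(X, y, w))  # 8.0
```

Here the scores are 0, 2 and -2, so the last two points are misclassified and each contributes 4 to the loss.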
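Note: w is updated by stochastic gradient descent. Each step picks one misclassified point at random and moves w toward classifying it correctly.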
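The update in get_w follows from the usual perceptron loss. The article does not state the loss explicitly, so the standard textbook form below is a reconstruction. With M the set of misclassified points, labels y_i in {-1, +1} and learning rate η:

$$
\begin{aligned}
L(w) &= -\sum_{i \in M} y_i \, w^\top x_i \\
\nabla_w L(w) &= -\sum_{i \in M} y_i \, x_i \\
w &\leftarrow w + \eta \, y_i \, x_i \qquad \text{for one } i \in M \text{ chosen at random}
\end{aligned}
$$

In the code, a misclassified point has prediction \(\hat{y}_i\) (predict[i]) equal to \(-y_i\), so \(\hat{y}_i - y_i = -2y_i\). get_loss therefore returns \(2L(w)\), and get_w computes

$$
w - \eta\,(\hat{y}_i - y_i)\,x_i = w + 2\eta\, y_i\, x_i ,
$$

which is the textbook step with step size \(2\eta\).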
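Conclusion: the perceptron's result is not unique. Different initial weights, or a different random choice of misclassified points, can end in a different separating line.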
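A small experiment makes the non-uniqueness concrete. The sketch below trains the same kind of perceptron (random misclassified point, textbook update w + lr·y·x) from several random seeds. It uses a tiny hand-made, linearly separable dataset instead of iris, an assumption chosen so the example runs quickly and without sklearn; train_perceptron is a hypothetical helper, not part of the article.

```python
import numpy as np


def train_perceptron(X, y, learning_rate=0.1, max_updates=10_000, seed=0):
    """Train a perceptron by SGD on randomly chosen misclassified points.

    X must already contain a bias column of ones; y holds labels in {-1, +1}.
    Returns the weights and the number of updates that were made.
    """
    rng = np.random.default_rng(seed)
    w = rng.random(X.shape[1])  # uniform [0, 1) start, like np.random.rand in the article
    for step in range(max_updates):
        pred = np.where(X @ w >= 0, 1, -1)  # sign(0) = +1
        wrong = np.flatnonzero(pred != y)
        if wrong.size == 0:  # every point classified correctly
            return w, step
        i = rng.choice(wrong)
        w = w + learning_rate * y[i] * X[i]
    return w, max_updates


# Hypothetical toy data: two well-separated clusters (not the iris data)
points = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.0],
                   [4.0, 4.0], [5.0, 4.5], [4.5, 5.0]])
y = np.array([-1, -1, -1, 1, 1, 1])
X = np.hstack((np.ones((len(points), 1)), points))

for seed in range(3):
    w, updates = train_perceptron(X, y, seed=seed)
    # Each w defines a separating line w0 + w1*x + w2*y = 0
    print(f"seed={seed} updates={updates} w={np.round(w, 3)}")
```

Every run ends with zero misclassified points, yet the printed weights differ: the perceptron stops at the first separating line it reaches and has no preference among the many lines that separate the data.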