Classifying Iris with a Perceptron in NumPy
2022-07-03 10:38:00 【Machine learning Xiaobai】
import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets


def sign(x):
    # Perceptron activation: +1 for x >= 0, otherwise -1
    if x >= 0:
        return 1
    else:
        return -1


def get_loss(dataset, target, w):
    # Raw scores w^T x for every sample
    results = [np.matmul(w.T, data) for data in dataset]
    # Predicted labels
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    print('Misclassified points remaining: {}'.format(len(error_index)))
    # Loss: for a misclassified point predict - target = -2 * target,
    # so each term is -2 * y_i * (w^T x_i) >= 0 (twice the usual perceptron loss)
    loss = 0
    for i in error_index:
        loss += results[i] * (predict[i] - target[i])
    return loss


def get_w(dataset, target, w, learning_rate):
    # Raw scores and predicted labels
    results = [np.matmul(w.T, data) for data in dataset]
    predict = [sign(i) for i in results]
    # Indices of misclassified points
    error_index = [i for i in range(len(target)) if target[i] != predict[i]]
    # Nothing to correct: random.choice([]) would raise IndexError
    if not error_index:
        return w
    # Pick one misclassified point at random
    index = random.choice(error_index)
    # Stochastic gradient descent step; since predict - target = -2 * target here,
    # this equals w + 2 * learning_rate * y_i * x_i
    w = w - learning_rate * (predict[index] - target[index]) * dataset[index]
    return w


if __name__ == '__main__':
    # Features: sepal length, sepal width, petal length, petal width
    # Classes 0 (setosa) and 1 (versicolor) are linearly separable; keep only these two
    iris = datasets.load_iris()
    dataset, target = iris['data'], iris['target']
    # Labels: class 1 -> +1, class 0 -> -1
    target = target[:100]
    target = np.array([1 if i == 1 else -1 for i in target])
    # Keep the first 100 samples and only sepal length and sepal width
    dataset = dataset[:100, :2]
    # Prepend a column of ones so that w[0] acts as the bias
    dataset = np.hstack((np.ones((dataset.shape[0], 1)), dataset))
    w = np.random.rand(len(dataset[0]))
    # Loss of the random initial weights
    loss = get_loss(dataset, target, w)
    print(loss)
    epoch = 10000
    learning_rate = 0.1
    for i in range(epoch):
        # Update w
        w = get_w(dataset, target, w, learning_rate)
        # Compute the loss
        loss = get_loss(dataset, target, w)
        print('After iteration {}, loss: {}'.format(i, loss))
        if loss == 0:
            break
    print(w)
    # Columns 1 and 2 hold sepal length and sepal width (column 0 is the bias)
    x_point = dataset[:, 1]
    y_point = dataset[:, 2]
    # Decision boundary w0 + w1*x + w2*y = 0, i.e. y = (-w1*x - w0) / w2
    x = np.linspace(4, 7.2, 10)
    y = (-w[1] * x - w[0]) / w[2]
    plt.scatter(x_point[:50], y_point[:50], label='0')
    plt.scatter(x_point[-50:], y_point[-50:], label='1')
    plt.legend()  # Show legend
    plt.plot(x, y)
    plt.show()

The weights above are trained with stochastic gradient descent: each step corrects one randomly chosen misclassified point. As a result, the perceptron's solution is not unique; a different initial w or a different sequence of chosen points generally yields a different separating line.
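To illustrate the non-uniqueness, here is a minimal vectorized sketch of the same training loop that assumes only NumPy. Instead of Iris it uses a small hypothetical two-cluster toy dataset, and the helper names `train_perceptron` and `count_errors` are illustrative, not from the post. It applies the textbook update w ← w + η·y_i·x_i, which matches the post's `get_w` up to a factor of 2 in the step size. Training from several seeds should reach zero training errors each time, but with a different weight vector, i.e., a different separating line.

```python
import numpy as np


def train_perceptron(X, y, learning_rate=0.1, max_steps=10000, seed=0):
    """Perceptron trained by correcting one random misclassified point per step.

    X: (n, d) features with a leading column of ones for the bias.
    y: (n,) labels in {-1, +1}.
    """
    rng = np.random.default_rng(seed)
    w = rng.random(X.shape[1])  # random initial weights, as in the post
    for _ in range(max_steps):
        # Same tie rule as sign(): a score of exactly 0 is predicted as +1
        predict = np.where(X @ w >= 0, 1, -1)
        errors = np.flatnonzero(predict != y)
        if errors.size == 0:
            break  # every point is on the correct side
        i = rng.choice(errors)
        w = w + learning_rate * y[i] * X[i]  # standard perceptron update
    return w


def count_errors(X, y, w):
    """Number of points that w misclassifies."""
    predict = np.where(X @ w >= 0, 1, -1)
    return int(np.sum(predict != y))


# Hypothetical toy data: two linearly separable 2-D clusters
X_raw = np.array([[1.0, 1.0], [1.5, 0.5], [2.0, 1.5],
                  [4.0, 4.0], [4.5, 3.5], [5.0, 5.0]])
y = np.array([-1, -1, -1, 1, 1, 1])
# Bias column of ones, as in the post
X = np.hstack((np.ones((X_raw.shape[0], 1)), X_raw))

weights = [train_perceptron(X, y, seed=s) for s in (0, 1, 2)]
for s, w in zip((0, 1, 2), weights):
    print(s, w, count_errors(X, y, w))
```

Each printed weight vector should separate the toy data with zero errors while differing from the others, since both the starting point and the order of corrected points depend on the seed.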