Statistical learning method (2/22) perceptron
2022-06-29 01:11:00 【Xiaoshuai acridine】
The perceptron is a linear model for binary classification. Its input is the feature vector of an instance and its output is the instance's class, taking the two values +1 and -1. The perceptron corresponds to a separating hyperplane that divides the instances in the input space (the feature space) into positive and negative classes; it is a discriminative model. Perceptron learning aims to find the separating hyperplane that linearly separates the training data. To this end, a loss function based on misclassification is introduced and minimized by gradient descent, yielding the perceptron model.
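The decision function and the misclassification-driven update described above can be sketched as follows (a minimal illustration, not the author's code; `predict` and `sgd_step` are made-up names):

```python
import numpy as np

def predict(w, b, x):
    """Perceptron decision function: sign(w . x + b), treating 0 as -1."""
    return 1 if np.dot(w, x) + b > 0 else -1

def sgd_step(w, b, x, y, eta=1.0):
    """If (x, y) is misclassified, i.e. y * (w . x + b) <= 0, take one
    stochastic gradient step: w <- w + eta * y * x, b <- b + eta * y."""
    if y * (np.dot(w, x) + b) <= 0:
        w = w + eta * y * np.asarray(x, dtype=float)
        b = b + eta * y
    return w, b

# one update on a point that is misclassified under the zero initialization
w, b = np.zeros(2), 0.0
w, b = sgd_step(w, b, x=[3.0, 3.0], y=1)
print(w, b)
```

After the step the point lies on the positive side of the updated hyperplane.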
DeepShare (深度之眼) course link: https://ai.deepshare.net/detail/p_619b93d0e4b07ededa9fcca0/5
Code link :
https://github.com/zxs-000202/Statistical-Learning-Methods





The perceptron is a linear model; it handles linearly separable data sets.


For linearly separable data, training eventually leaves no misclassified points.
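That convergence claim can be checked with a small self-contained run (not the author's code) on the linearly separable data of example 2.1 in Li Hang's book; training stops once a full pass over the data produces no misclassification:

```python
import numpy as np

# Toy linearly separable data from example 2.1 of the book:
X = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])
y = np.array([1, 1, -1])

w, b, eta = np.zeros(2), 0.0, 1.0
for _ in range(100):                       # safety cap on the number of passes
    mistakes = 0
    for xi, yi in zip(X, y):
        if yi * (np.dot(w, xi) + b) <= 0:  # misclassified (or on the hyperplane)
            w, b = w + eta * yi * xi, b + eta * yi
            mistakes += 1
    if mistakes == 0:                      # converged: a clean pass, stop
        break

errors = sum(yi * (np.dot(w, xi) + b) <= 0 for xi, yi in zip(X, y))
print('misclassified points after training:', errors)   # -> 0
```

With this visiting order the run ends at w = (1, 1), b = -3, one of the solutions given in the book; a different visiting order can end at a different separating hyperplane.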






The parameters are initialized to 0.

import numpy as np
import time
from tqdm import tqdm


def loadData(fileName):
    '''
    Load the Mnist data set
    :param fileName: path of the data set to load
    :return: data and labels as lists
    '''
    print('start to read data')
    # lists holding the data and the labels
    dataArr = []; labelArr = []
    # open the file
    fr = open(fileName, 'r')
    # read the file line by line
    for line in tqdm(fr.readlines()):
        # split each row on ',' to get the list of fields
        curLine = line.strip().split(',')
        # Mnist has labels 0-9; since this is a binary classification task,
        # labels >= 5 are mapped to 1 and labels < 5 to -1
        if int(curLine[0]) >= 5:
            labelArr.append(1)
        else:
            labelArr.append(-1)
        # store the sample
        # [int(num) for num in curLine[1:]] -> convert every element except the first (the label) to int
        # [int(num)/255 for num in curLine[1:]] -> divide all pixel values by 255 to normalize (optional)
        dataArr.append([int(num)/255 for num in curLine[1:]])
    # return the data and labels
    return dataArr, labelArr
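The binarization and scaling that loadData applies to each CSV row can be checked in isolation; the sample row below is made up for illustration:

```python
# One hypothetical CSV row: label 7 followed by three pixel values.
curLine = "7,0,128,255".strip().split(',')
label = 1 if int(curLine[0]) >= 5 else -1         # digits 5-9 -> +1, digits 0-4 -> -1
pixels = [int(num) / 255 for num in curLine[1:]]  # scale pixels into [0, 1]
print(label, pixels)
```

Digit 7 maps to the positive class, and the pixel values 0, 128, 255 land in [0, 1].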
def perceptron(dataArr, labelArr, iter=50):
    '''
    Perceptron training procedure
    :param dataArr: training data (list)
    :param labelArr: training labels (list)
    :param iter: number of iterations, default 50
    :return: trained w and b
    '''
    print('start to trans')
    # Convert the data to matrix form (machine learning usually operates on
    # vectors; matrix form makes the computation convenient).
    # Each sample becomes a row vector of the matrix.
    dataMat = np.mat(dataArr)
    # Convert the labels to a matrix and transpose it (.T is the transpose).
    # The transpose is needed because the computation reads one label at a
    # time; with a 1xN matrix, label[i] would not index a single element.
    # A plain 1xN label list could be indexed with label[i] directly; the
    # conversion here just keeps the formats uniform.
    labelMat = np.mat(labelArr).T
    # get the size m*n of the data matrix
    m, n = np.shape(dataMat)
    # Create the initial weights w, all zeros.
    # np.shape(dataMat)[1] is n, matching the length of a sample.
    w = np.zeros((1, np.shape(dataMat)[1]))
    # initialize the bias b to 0
    b = 0
    # step size (learning rate), controlling the gradient descent rate
    h = 0.0001
    # run iter iterations
    for k in range(iter):
        # Perform a gradient step per sample.
        # At the start of section 2.3.1 of Li Hang's book, gradient descent
        # sums over all samples before taking a single step. In the second
        # half of 2.3.1 (e.g. formulas 2.6 and 2.7) the summation sign is
        # gone: that is stochastic gradient descent, which takes a step
        # after each sample. The two differ, but stochastic gradient
        # descent is the common choice.
        for i in range(m):
            # vector of the current sample
            xi = dataMat[i]
            # label of the current sample
            yi = labelMat[i]
            # Check whether the sample is misclassified.
            # The condition for a misclassified sample is -yi(w*xi+b) >= 0;
            # see section 2.2.2 for details.
            # The book writes > 0; but if the value equals 0, the point lies
            # on the hyperplane, which is also not correct.
            if -1 * yi * (w * xi.T + b) >= 0:
                # for a misclassified sample, take a gradient step to update w and b
                w = w + h * yi * xi
                b = b + h * yi
        # print training progress
        print('Round %d:%d training' % (k, iter))
    # return the trained w and b
    return w, b
def model_test(dataArr, labelArr, w, b):
    '''
    Test the accuracy
    :param dataArr: test data
    :param labelArr: test labels
    :param w: trained weights w
    :param b: trained bias b
    :return: accuracy
    '''
    print('start to test')
    # convert the data set to matrix form for convenient computation
    dataMat = np.mat(dataArr)
    # convert the labels to a matrix and transpose;
    # see the explanation in perceptron above
    labelMat = np.mat(labelArr).T
    # get the size of the test data matrix
    m, n = np.shape(dataMat)
    # misclassified-sample counter
    errorCnt = 0
    # iterate over all test samples
    for i in range(m):
        # get one sample vector
        xi = dataMat[i]
        # get its label
        yi = labelMat[i]
        # compute the result
        result = -1 * yi * (w * xi.T + b)
        # if -yi(w*xi+b) >= 0, the sample is misclassified; increment the counter
        if result >= 0: errorCnt += 1
    # accuracy = 1 - (number of misclassified samples / total number of samples)
    accruRate = 1 - (errorCnt / m)
    # return the accuracy
    return accruRate
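For reference, the per-sample loop in model_test can be cross-checked against a vectorized NumPy version (a sketch; `accuracy` is a made-up helper, not from the repo):

```python
import numpy as np

def accuracy(dataArr, labelArr, w, b):
    """Vectorized equivalent of model_test: classify every sample at once.
    w is a 1 x n array, b a scalar."""
    X = np.asarray(dataArr, dtype=float)           # m x n data matrix
    y = np.asarray(labelArr, dtype=float)          # m labels in {+1, -1}
    margins = y * (X @ np.asarray(w).ravel() + b)  # yi * (w . xi + b) per sample
    # model_test counts -yi(w*xi+b) >= 0 as an error, i.e. margin <= 0,
    # so the accuracy is the fraction of strictly positive margins.
    return float(np.mean(margins > 0))

# toy check with hand-picked w, b (hypothetical values, not trained ones)
w, b = np.array([[1.0, 1.0]]), -3.0
X = [[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]]
y = [1, 1, -1]
print(accuracy(X, y, w, b))   # -> 1.0
```

The matrix product replaces the Python-level loop, which is noticeably faster on the 10,000-sample Mnist test set.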
if __name__ == '__main__':
    # Record the current time; another timestamp is taken at the end,
    # and the difference is the program's running time.
    start = time.time()
    # load the training data and labels
    trainData, trainLabel = loadData('../Mnist/mnist_train.csv')
    # load the test data and labels
    testData, testLabel = loadData('../Mnist/mnist_test.csv')
    # train to obtain the weights
    w, b = perceptron(trainData, trainLabel, iter=30)
    # test and obtain the accuracy
    accruRate = model_test(testData, testLabel, w, b)
    # record the current time as the end time
    end = time.time()
    # display the accuracy
    print('accuracy rate is:', accruRate)
    # display the elapsed time
    print('time span:', end - start)