当前位置:网站首页>统计学习方法(2/22)感知机
统计学习方法(2/22)感知机
2022-06-29 01:06:00 【小帅吖】
感知机是二分类的线性分类模型,其输入为实例的特征向量,输出为实例的类别,取+1和-1二值。感知机对应于输入空间(特征空间)中将实例划分为正负两类的分离超平面,属于判别模型。感知机学习旨在求出将训练数据进行线性划分的分离超平面,为此,导入基于误分类的损失函数,利用梯度下降法对损失函数进行极小化,求得感知机模型。
深度之眼课程链接:https://ai.deepshare.net/detail/p_619b93d0e4b07ededa9fcca0/5
代码链接:
https://github.com/zxs-000202/Statistical-Learning-Methods





感知机是一个线性模型,处理线性数据集


针对一个线性可分数据据,最后没有误分类的点






参数初始值为0

import numpy as np
import time
from tqdm import tqdm
def loadData(fileName):
''' 加载Mnist数据集 :param fileName:要加载的数据集路径 :return: list形式的数据集及标记 '''
print('start to read data')
# 存放数据及标记的list
dataArr = []; labelArr = []
# 打开文件
fr = open(fileName, 'r')
# 将文件按行读取
for line in tqdm(fr.readlines()):
# 对每一行数据按切割福','进行切割,返回字段列表
curLine = line.strip().split(',')
# Mnsit有0-9是个标记,由于是二分类任务,所以将>=5的作为1,<5为-1
if int(curLine[0]) >= 5:
labelArr.append(1)
else:
labelArr.append(-1)
#存放标记
#[int(num) for num in curLine[1:]] -> 遍历每一行中除了以第一哥元素(标记)外将所有元素转换成int类型
#[int(num)/255 for num in curLine[1:]] -> 将所有数据除255归一化(非必须步骤,可以不归一化)
dataArr.append([int(num)/255 for num in curLine[1:]])
#返回data和label
return dataArr, labelArr
def perceptron(dataArr, labelArr, iter=50):
''' 感知器训练过程 :param dataArr:训练集的数据 (list) :param labelArr: 训练集的标签(list) :param iter: 迭代次数,默认50 :return: 训练好的w和b '''
print('start to trans')
#将数据转换成矩阵形式(在机器学习中因为通常都是向量的运算,转换称矩阵形式方便运算)
#转换后的数据中每一个样本的向量都是横向的
dataMat = np.mat(dataArr)
#将标签转换成矩阵,之后转置(.T为转置)。
#转置是因为在运算中需要单独取label中的某一个元素,如果是1xN的矩阵的话,无法用label[i]的方式读取
#对于只有1xN的label可以不转换成矩阵,直接label[i]即可,这里转换是为了格式上的统一
labelMat = np.mat(labelArr).T
#获取数据矩阵的大小,为m*n
m, n = np.shape(dataMat)
#创建初始权重w,初始值全为0。
#np.shape(dataMat)的返回值为m,n -> np.shape(dataMat)[1])的值即为n,与
#样本长度保持一致
w = np.zeros((1, np.shape(dataMat)[1]))
#初始化偏置b为0
b = 0
#初始化步长,也就是梯度下降过程中的n,控制梯度下降速率
h = 0.0001
#进行iter次迭代计算
for k in range(iter):
#对于每一个样本进行梯度下降
#李航书中在2.3.1开头部分使用的梯度下降,是全部样本都算一遍以后,统一
#进行一次梯度下降
#在2.3.1的后半部分可以看到(例如公式2.6 2.7),求和符号没有了,此时用
#的是随机梯度下降,即计算一个样本就针对该样本进行一次梯度下降。
#两者的差异各有千秋,但较为常用的是随机梯度下降。
for i in range(m):
#获取当前样本的向量
xi = dataMat[i]
#获取当前样本所对应的标签
yi = labelMat[i]
#判断是否是误分类样本
#误分类样本特诊为: -yi(w*xi+b)>=0,详细可参考书中2.2.2小节
#在书的公式中写的是>0,实际上如果=0,说明改点在超平面上,也是不正确的
if -1 * yi * (w * xi.T + b) >= 0:
#对于误分类样本,进行梯度下降,更新w和b
w = w + h * yi * xi
b = b + h * yi
#打印训练进度
print('Round %d:%d training' % (k, iter))
#返回训练完的w、b
return w, b
def model_test(dataArr, labelArr, w, b):
''' 测试准确率 :param dataArr:测试集 :param labelArr: 测试集标签 :param w: 训练获得的权重w :param b: 训练获得的偏置b :return: 正确率 '''
print('start to test')
#将数据集转换为矩阵形式方便运算
dataMat = np.mat(dataArr)
#将label转换为矩阵并转置,详细信息参考上文perceptron中
#对于这部分的解说
labelMat = np.mat(labelArr).T
#获取测试数据集矩阵的大小
m, n = np.shape(dataMat)
#错误样本数计数
errorCnt = 0
#遍历所有测试样本
for i in range(m):
#获得单个样本向量
xi = dataMat[i]
#获得该样本标记
yi = labelMat[i]
#获得运算结果
result = -1 * yi * (w * xi.T + b)
#如果-yi(w*xi+b)>=0,说明该样本被误分类,错误样本数加一
if result >= 0: errorCnt += 1
#正确率 = 1 - (样本分类错误数 / 样本总数)
accruRate = 1 - (errorCnt / m)
#返回正确率
return accruRate
if __name__ == '__main__':
#获取当前时间
#在文末同样获取当前时间,两时间差即为程序运行时间
start = time.time()
#获取训练集及标签
trainData, trainLabel = loadData('../Mnist/mnist_train.csv')
#获取测试集及标签
testData, testLabel = loadData('../Mnist/mnist_test.csv')
#训练获得权重
w, b = perceptron(trainData, trainLabel, iter = 30)
#进行测试,获得正确率
accruRate = model_test(testData, testLabel, w, b)
#获取当前时间,作为结束时间
end = time.time()
#显示正确率
print('accuracy rate is:', accruRate)
#显示用时时长
print('time span:', end - start)
边栏推荐
- WPF 实现心电图曲线绘制
- [SV basics] some usage of queue
- EasyCVR播放视频出现卡顿花屏时如何解决?
- 如何进行数据库选型
- Reference materials in the process of using Excel
- 多维分析预汇总应该怎样做才管用?
- [MCU club] design of blind water cup based on MCU [simulation design]
- 同期群分析是什么?教你用 SQL 来搞定
- 戴口罩人臉數據集和戴口罩人臉生成方法
- Mask wearing face data set and mask wearing face generation method
猜你喜欢
UI高度自适应的修改方案
![[leetcode] 522. 最长特殊序列 II 暴力 + 双指针](/img/88/3ddeefaab7e29b8eeb412bb5c3e9b8.png)
[leetcode] 522. 最长特殊序列 II 暴力 + 双指针

Structure of the actual combat battalion | module 5

Esmm reading notes

GUI Graphical user interface programming example - color selection box

cocoscreator动态切换SkeletonData实现骨骼更新

Uvm:field automation mechanism

【温度检测】基于matlab GUI热红外图像温度检测系统【含Matlab源码 1920期】

FSS object storage how to access the Intranet

【leetcode】1719. Number of schemes for reconstructing a tree
随机推荐
WPF 实现心电图曲线绘制
Count the number of different palindrome subsequences in the string
Is l1-031 too fat (10 points)
[MCU club] design of GSM version of range hood based on MCU [physical design]
Is it safe to open a securities account at qiniu business school in 2022?
How many locks are added to an update statement? Take you to understand the underlying principles
What is contemporaneous group analysis? Teach you to use SQL to handle
Esmm reading notes
【RRT三维路径规划】基于matlab快速扩展随机树无人机三维路径规划【含Matlab源码 1914期】
Streaming media cluster application and configuration: how to deploy multiple easycvr on one server?
用户登录(记住用户)&用户注册(验证码) [运用Cookie Session技术]
分析框架——用户体验度量数据体系搭建
[UVM] my main_ Why can't the case exit when the phase runs out? Too unreasonable!
Program environment and pretreatment
Reference materials in the process of using Excel
GUI Graphical user interface programming example - color selection box
【温度检测】基于matlab GUI热红外图像温度检测系统【含Matlab源码 1920期】
大智慧上开户是安全的吗
【SV 基础】queue 的一些用法
多维分析预汇总应该怎样做才管用?