Neural Networks and Deep Learning - 5 - Perceptron - PyTorch
2022-07-08 00:16:00 【明朝百晓生】
Reference:
《神经网络与深度学习》 (Neural Networks and Deep Learning)
Preface:
The perceptron, proposed by Frank Rosenblatt in 1957, is a widely used linear classifier.
It is an error-driven learning algorithm.
Related reading: Scikit-Learn (1. Common datasets provided by sklearn - the built-in small datasets), CSDN blog by Micheal超
1 Perceptron
1.1 Parameter Learning
The algorithm tries to find a set of parameters w such that for every training sample (x^(n), y^(n)) we have y^(n) w^T x^(n) > 0.
1.2 Loss Function
For a sample (x, y) the perceptron loss is L(w; x, y) = max(0, -y w^T x). The parameters are learned with stochastic gradient descent: whenever a sample is misclassified (y w^T x <= 0), the weights are updated as w <- w + y x.
1.3 Algorithm Flow
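The flow chart from the original post is not reproduced here. As a substitute, the following is a minimal NumPy sketch of the training loop described above (the function name and signature are illustrative, not from the original post):

import numpy as np

def perceptron_train(X, y, max_epochs=100):
    """Two-class perceptron training.
    X: (m, n) array of augmented samples (last column equal to 1 for the bias).
    y: (m,) array of labels in {-1, +1}.
    """
    w = np.zeros(X.shape[1])                     # initialize the weight vector to zero
    for epoch in range(max_epochs):
        errors = 0
        for i in np.random.permutation(len(X)):  # visit the samples in random order
            if y[i] * np.dot(w, X[i]) <= 0:      # misclassified: error-driven update
                w = w + y[i] * X[i]              # w <- w + y * x
                errors += 1
        if errors == 0:                          # a full pass with no mistakes: converged
            break
    return w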
1.4 Convergence
In 1963, Novikoff proved that the algorithm converges on any linearly separable dataset: if every sample satisfies ||x|| <= R and the data can be separated with margin γ, the number of weight updates is at most (R/γ)^2.
Drawbacks:
Its generalization ability is poor: the learned hyperplane depends on the order in which the samples are visited, so different iteration orders give different hyperplanes.
If the data are not linearly separable, the algorithm never converges.
2 Voting Perceptron
The weight vector learned by the perceptron depends on the order of the training samples.
To improve the perceptron's robustness and generalization ability, we can keep all K weight vectors produced during training (one for each misclassification) and assign each of them a confidence coefficient; the final classification is then decided by a vote of the K perceptrons with different weights. This model is called the voting perceptron.
Let τ_k be the iteration count (the number of samples processed so far) when the weight vector w_k is produced by the k-th update, and let τ_{k+1} be the iteration count at the next update. The confidence coefficient c_k of w_k is set to the number of iterations between τ_k and τ_{k+1}, i.e. c_k = τ_{k+1} - τ_k. The larger c_k is, the more samples w_k classified correctly before the next mistake, and the more trustworthy it is. The final prediction is ŷ = sgn( Σ_{k=1}^{K} c_k · sgn(w_k^T x) ).
This is an ensemble-learning idea: K classifiers vote to decide a single result, as in the sketch below.
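A minimal sketch, under the same assumptions as the earlier sketch (augmented samples, labels in {-1, +1}), of how the K weight vectors and their confidence coefficients could be collected during training and then used for voting; the function names are my own illustration, not from the original post:

import numpy as np

def voted_perceptron_train(X, y, max_epochs=10):
    w = np.zeros(X.shape[1])
    weights, confidences = [], []
    c = 1                                    # survival time of the current weight vector
    for _ in range(max_epochs):
        for i in range(len(X)):
            if y[i] * np.dot(w, X[i]) <= 0:
                weights.append(w.copy())     # store the old weight vector ...
                confidences.append(c)        # ... together with how long it survived
                w = w + y[i] * X[i]
                c = 1
            else:
                c += 1
    weights.append(w.copy())                 # keep the final weight vector as well
    confidences.append(c)
    return weights, confidences

def voted_predict(weights, confidences, x):
    # each of the K perceptrons casts a vote weighted by its confidence coefficient
    votes = sum(c * np.sign(np.dot(w, x)) for w, c in zip(weights, confidences))
    return 1 if votes > 0 else -1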
3 Averaged Perceptron
Storing and voting with K weight vectors is expensive; averaging them gives a cheaper approximation with similar robustness. Let T be the total number of iterations. The averaged weight vector over the T iterations is w̄ = (1/T) Σ_{t=1}^{T} w_t, and the final prediction is ŷ = sgn(w̄^T x).
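A minimal sketch of the averaged perceptron under the same assumptions as above, accumulating a running sum of the weight vector instead of storing all K vectors (illustrative code, not from the original post):

import numpy as np

def averaged_perceptron_train(X, y, max_epochs=10):
    w = np.zeros(X.shape[1])        # current weight vector
    w_sum = np.zeros(X.shape[1])    # running sum of the weight vector over all iterations
    T = 0
    for _ in range(max_epochs):
        for i in range(len(X)):
            if y[i] * np.dot(w, X[i]) <= 0:
                w = w + y[i] * X[i]
            w_sum += w              # accumulate the weight vector at every iteration
            T += 1
    return w_sum / T                # averaged weight vector, used as sgn(w_bar^T x)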
4 Generating a Linearly Separable Dataset
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 6 12:07:09 2022
@author: chengxf2
"""
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
import csv


def saveCsv(trainData):
    # newline='' keeps csv.writer from inserting blank rows on Windows
    with open('trainData.csv', 'w', newline='') as f:
        wr = csv.writer(f)
        wr.writerows(trainData)


'''
Classification data generator
Reference: https://cloud.tencent.com/developer/ask/sof/1961912/answer/2665673
args
    n_samples:  total number of samples to generate
    n_features: dimension of each sample = n_informative + n_redundant + n_repeated
    n_classes:  number of classes
    n_informative: number of informative features
    n_redundant:   number of redundant features
    n_repeated:    number of repeated features
    shuffle:       whether to shuffle the samples
    n_clusters_per_class: number of clusters per class
return
    trainData: list of rows [x1, x2, label]
'''
def makeTrain(batch=1000):
    separable = False
    trainData = []
    while not separable:
        # keep drawing until the two classes are separable along at least one axis;
        # flip_y=0 disables random label flipping (the original flip_y=-1 may be
        # rejected by newer scikit-learn versions)
        samples = make_classification(n_samples=batch, n_features=2, n_redundant=0,
                                      n_informative=1, n_clusters_per_class=1, flip_y=0)
        red = samples[0][samples[1] == 0]
        blue = samples[0][samples[1] == 1]
        separable = any([red[:, k].max() < blue[:, k].min() or
                         red[:, k].min() > blue[:, k].max() for k in range(2)])
    data = samples[0]      # samples X
    feature = samples[1]   # labels Y
    for i in range(batch):
        item = list(data[i])
        item.append(feature[i])
        trainData.append(item)
    plt.plot(red[:, 0], red[:, 1], 'r.')
    plt.plot(blue[:, 0], blue[:, 1], 'b.')
    plt.show()
    return trainData


data = makeTrain(100)
saveCsv(data)
5 Parameter Learning Example for a Two-Class Perceptron
Train on the dataset generated above.
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 6 14:08:27 2022
@author: chengxf2
"""
import numpy as np
import torch
import csv
import os


# Perceptron (an artificial network modelled on how the brain distinguishes objects)
class perceptron():

    def __init__(self):
        self.m = 0                        # number of samples
        self.n = 0                        # sample dimension (features + bias)
        self.maxIter = 10                 # maximum number of epochs
        self.fileName = "trainData.csv"

    '''
    Load the dataset from the csv file
    args
        self.fileName: path of the csv file
    '''
    def loadData(self):
        if not os.path.exists(self.fileName):
            print("\n ------ file path does not exist -----")
            return None
        feature = []    # labels Y: sklearn uses {0, 1}, convert to {-1, 1}
        trainData = []  # training samples X
        with open(self.fileName) as f:
            f_csv = csv.reader(f)
            for row in f_csv:
                Y = int(row[-1]) * 2 - 1             # map {0, 1} -> {-1, 1}
                X = [float(v) for v in row[0:-1]]
                X.append(1)                          # augmented feature for the bias
                trainData.append(X)
                feature.append(Y)
        self.m, self.n = np.shape(trainData)
        print("\n ---- step 1: dataset loaded ----")
        return torch.FloatTensor(trainData), torch.IntTensor(feature)

    '''
    Prediction: return the raw score w^T x and its sign
    '''
    def forecast(self, w, x):
        hatY = torch.matmul(w.T, x)
        sgnY = 0
        if hatY > 0:
            sgnY = 1
        elif hatY < 0:
            sgnY = -1
        return hatY, sgnY

    '''
    Evaluation: count the misclassified samples
    '''
    def test(self, trainData, feature, w):
        err = 0
        for i in range(self.m):
            x = trainData[i].view(self.n, 1)   # column vector
            y = feature[i]                     # corresponding label
            hatY, sgnY = self.forecast(w, x)
            if sgnY * y <= 0:
                err = err + 1
        print("\n number of misclassified samples: ", err)

    '''
    Training
    '''
    def train(self, trainData, feature):
        w = torch.zeros((self.n, 1))   # weight vector (includes the bias term)
        t = 0                          # epoch counter
        print("\n ----- step 2: training started -----")
        while True:
            perm = torch.randperm(self.m)   # shuffle the dataset each epoch
            k = 0                           # misclassifications in this epoch
            t = t + 1
            for i in range(self.m):
                index = perm[i].item()
                x = trainData[index].view(self.n, 1)   # column vector
                y = feature[index]                     # corresponding label
                hatY, sgnY = self.forecast(w, x)
                if y * sgnY <= 0:                      # misclassified
                    k = k + 1
                    w = w + (y * x).view(self.n, 1)    # update: w <- w + y * x
            print("\n k:%d t:%d " % (k, t), "w: ", w)
            if k == 0 or t == self.maxIter:
                # stop when a full epoch has no errors (converged) or the epoch limit is reached
                print("\n --- training stopped ---")
                break
        return w


if __name__ == "__main__":
    model = perceptron()
    trainData, feature = model.loadData()
    w = model.train(trainData, feature)
    model.test(trainData, feature, w)