Neural Networks and Deep Learning - 5 - Perceptron - PyTorch
2022-07-08 00:16:00 [明朝百晓生]
References:
《神经网络与深度学习》 (Neural Networks and Deep Learning)
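Preface:
The perceptron, proposed by Frank Rosenblatt in 1957, is a widely used linear classifier.
It is an error-driven algorithm: the weights change only when a training sample is misclassified.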
Related reading: Scikit-Learn (1. Common datasets provided by Sklearn - built-in small datasets), Micheal超's blog, CSDN.
1 Perceptron

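Parameter learning

The algorithm tries to find a set of parameters $\mathbf{w}$ such that every training sample $(\mathbf{x}^{(n)}, y^{(n)})$, with label $y^{(n)} \in \{-1, +1\}$, satisfies

$$y^{(n)}\,\mathbf{w}^\top \mathbf{x}^{(n)} > 0, \quad n = 1, \dots, N.$$

Here $\mathbf{x}$ is the augmented feature vector (a constant 1 is appended, so the bias is the last entry of $\mathbf{w}$).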
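The condition above can be checked for all samples at once. This is a minimal NumPy sketch, assuming the samples are stacked row-wise in an augmented matrix `X` and the labels in a vector `y` with values in {-1, +1}; the function name `separates` and the toy data are hypothetical.

```python
import numpy as np

def separates(w, X, y):
    """Return True if y_n * w^T x_n > 0 holds for every sample n.

    X: (N, D) array of augmented samples (last column is the constant 1),
    y: (N,) array of labels in {-1, +1}, w: (D,) weight vector.
    """
    margins = y * (X @ w)          # y_n * w^T x_n for all n at once
    return bool(np.all(margins > 0))

# hypothetical toy data: the label is the sign of the first coordinate
X = np.array([[ 2.0,  1.0, 1.0],
              [ 1.0, -1.0, 1.0],
              [-1.0,  0.5, 1.0],
              [-2.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])
print(separates(np.array([1.0, 0.0, 0.0]), X, y))   # True
print(separates(np.array([0.0, 1.0, 0.0]), X, y))   # False
```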

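1.1 Loss Function

For a sample $(\mathbf{x}, y)$ the perceptron loss is

$$\mathcal{L}(\mathbf{w}; \mathbf{x}, y) = \max\left(0, -y\,\mathbf{w}^\top\mathbf{x}\right),$$

and its gradient with respect to $\mathbf{w}$ is

$$\frac{\partial \mathcal{L}(\mathbf{w}; \mathbf{x}, y)}{\partial \mathbf{w}} = \begin{cases} 0 & \text{if } y\,\mathbf{w}^\top\mathbf{x} > 0, \\ -y\,\mathbf{x} & \text{if } y\,\mathbf{w}^\top\mathbf{x} < 0. \end{cases}$$

Training uses stochastic gradient descent: whenever a sample is misclassified ($y\,\mathbf{w}^\top\mathbf{x} \le 0$), the weights are updated as $\mathbf{w} \leftarrow \mathbf{w} + y\,\mathbf{x}$.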
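A minimal NumPy sketch of this loss and of one stochastic gradient step, assuming the loss $\max(0, -y\,\mathbf{w}^\top\mathbf{x})$ with augmented inputs; the helper names are hypothetical. At $y\,\mathbf{w}^\top\mathbf{x} = 0$ the sketch picks the subgradient $-y\,\mathbf{x}$, so a sample lying exactly on the hyperplane still triggers an update, matching the $\le 0$ test used in training.

```python
import numpy as np

def perceptron_loss(w, x, y):
    """L(w; x, y) = max(0, -y * w^T x)."""
    return max(0.0, -y * float(w @ x))

def perceptron_grad(w, x, y):
    """Subgradient of the loss: -y * x when y * w^T x <= 0, else the zero vector."""
    if y * float(w @ x) <= 0:
        return -y * x
    return np.zeros_like(x)

def sgd_step(w, x, y, lr=1.0):
    """One stochastic gradient step on a single sample; with lr=1 this is w <- w + y x."""
    return w - lr * perceptron_grad(w, x, y)

w = np.zeros(3)
x = np.array([2.0, 1.0, 1.0])    # augmented sample (last entry = 1)
y = 1
print(perceptron_loss(w, x, y))  # 0.0, yet w^T x = 0 counts as misclassified
w = sgd_step(w, x, y)
print(w)                         # [2. 1. 1.]
```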

1.2 Algorithm
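The training procedure can be sketched as follows. This is a minimal NumPy version, assuming augmented inputs and labels in {-1, +1}; the function name `train_perceptron`, the random visiting order and the "stop after one error-free pass" rule are illustrative choices rather than the author's exact flow.

```python
import numpy as np

def train_perceptron(X, y, max_epochs=100, seed=0):
    """Error-driven perceptron training.

    X: (N, D) augmented samples, y: (N,) labels in {-1, +1}.
    Returns the weight vector and the total number of updates.
    """
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])               # 1. initialise w = 0
    updates = 0
    for _ in range(max_epochs):
        errors = 0
        for n in rng.permutation(len(y)):  # 2. visit the samples in random order
            if y[n] * (w @ X[n]) <= 0:     # 3. misclassified (or on the hyperplane)
                w = w + y[n] * X[n]        # 4. update: w <- w + y x
                errors += 1
                updates += 1
        if errors == 0:                    # 5. a full error-free pass: stop
            break
    return w, updates

X = np.array([[ 2.0,  1.0, 1.0],
              [ 1.0, -1.0, 1.0],
              [-1.0,  0.5, 1.0],
              [-2.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])
w, k = train_perceptron(X, y)
print(w, k)
print(np.all(y * (X @ w) > 0))   # True once the data is separated
```

On non-separable data the `max_epochs` cap is what ends the loop, which is the practical answer to the non-convergence problem listed below.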

1.3 Convergence
In 1963 Novikoff proved that on a linearly separable dataset the algorithm converges after a finite number of updates.
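Novikoff's result can be written as a bound on the number of updates. Assume every augmented sample satisfies $\|\mathbf{x}\| \le R$, and some unit vector $\mathbf{w}^*$ ($\|\mathbf{w}^*\| = 1$) separates the data with margin $\gamma > 0$, i.e. $y\,\mathbf{w}^{*\top}\mathbf{x} \ge \gamma$ for every sample. Starting from $\mathbf{w}_0 = 0$, let $(\mathbf{x}, y)$ be the sample that triggers the $k$-th update:

```latex
\begin{aligned}
\mathbf{w}^{*\top}\mathbf{w}_k &= \mathbf{w}^{*\top}\mathbf{w}_{k-1} + y\,\mathbf{w}^{*\top}\mathbf{x}
  \;\ge\; \mathbf{w}^{*\top}\mathbf{w}_{k-1} + \gamma
  \;\Rightarrow\; \mathbf{w}^{*\top}\mathbf{w}_k \ge k\gamma, \\
\|\mathbf{w}_k\|^2 &= \|\mathbf{w}_{k-1}\|^2 + 2y\,\mathbf{w}_{k-1}^\top\mathbf{x} + \|\mathbf{x}\|^2
  \;\le\; \|\mathbf{w}_{k-1}\|^2 + R^2
  \;\Rightarrow\; \|\mathbf{w}_k\|^2 \le kR^2, \\
k\gamma &\le \mathbf{w}^{*\top}\mathbf{w}_k \le \|\mathbf{w}^*\|\,\|\mathbf{w}_k\| \le \sqrt{k}\,R
  \;\Rightarrow\; k \le \frac{R^2}{\gamma^2}.
\end{aligned}
```

The second line uses $y\,\mathbf{w}_{k-1}^\top\mathbf{x} \le 0$ (the sample was misclassified) and $y^2 = 1$. The bound does not depend on the number of samples, but it says nothing when no separating $\mathbf{w}^*$ exists.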
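Drawbacks:
- Poor generalization.
- The hyperplane depends on the order of the samples: a different visiting order in each run yields a different hyperplane.
- If the data is not linearly separable, the algorithm never converges.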
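2 Voted Perceptron
The weight vector learned by the perceptron depends on the order of the training samples.
To improve the robustness and generalization of the perceptron, we can save all $K$ weight vectors produced during learning (one per update, i.e. one each time a sample is misclassified) and assign each weight vector $\mathbf{w}_k$ a confidence coefficient $c_k$. The final classification is decided by a vote of the $K$ perceptrons with different weights; this model is called the voted perceptron.
Let $\tau_k$ be the iteration count (the number of training samples processed so far) at the $k$-th weight update, which produces $\mathbf{w}_k$, and let $\tau_{k+1}$ be the iteration count at the next update. The confidence coefficient $c_k$ of $\mathbf{w}_k$ is set to the number of iterations between $\tau_k$ and $\tau_{k+1}$, i.e. $c_k = \tau_{k+1} - \tau_k$. The larger $c_k$ is, the more samples $\mathbf{w}_k$ classified correctly after it was obtained, and the more trustworthy it is. The voted prediction is

$$\hat{y} = \operatorname{sgn}\left(\sum_{k=1}^{K} c_k \operatorname{sgn}\left(\mathbf{w}_k^\top \mathbf{x}\right)\right).$$

This is an ensemble-learning idea: $K$ classifiers vote to decide a single result.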
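A minimal NumPy sketch of the voted perceptron, assuming augmented inputs and labels in {-1, +1}; `train_voted` and `predict_voted` are hypothetical names. Each processed sample adds one to the confidence of the weight vector that is current after that sample, so a new vector starts at 1 and ends with $c_k = \tau_{k+1} - \tau_k$; the counts therefore sum to the total number of iterations.

```python
import numpy as np

def train_voted(X, y, epochs=10, seed=0):
    """Store every weight vector w_k together with its confidence c_k."""
    rng = np.random.default_rng(seed)
    weights = [np.zeros(X.shape[1])]
    counts = [0]
    for _ in range(epochs):
        for n in rng.permutation(len(y)):
            w = weights[-1]
            if y[n] * (w @ X[n]) <= 0:       # error: a new weight vector is created
                weights.append(w + y[n] * X[n])
                counts.append(1)
            else:                            # correct: current vector survives one more sample
                counts[-1] += 1
    return np.array(weights), np.array(counts)

def predict_voted(weights, counts, x):
    """y_hat = sgn( sum_k c_k * sgn(w_k^T x) )."""
    votes = np.sign(weights @ x)             # one +1 / -1 (or 0) vote per stored vector
    return int(np.sign(counts @ votes))

X = np.array([[ 2.0,  1.0, 1.0],
              [ 1.0, -1.0, 1.0],
              [-1.0,  0.5, 1.0],
              [-2.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])
W, C = train_voted(X, y)
print(C)                                     # confidences c_k
print([predict_voted(W, C, x) for x in X])
```

The cost of this design is memory and prediction time proportional to $K$, since every stored vector is evaluated for each prediction.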
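3 Averaged Perceptron
Dropping the inner $\operatorname{sgn}$ in the voted prediction (and scaling by $1/T$, which does not change the sign) gives the averaged perceptron:

$$\hat{y} = \operatorname{sgn}\left(\frac{1}{T}\sum_{k=1}^{K} c_k\,\mathbf{w}_k^\top\mathbf{x}\right) = \operatorname{sgn}\left(\Big(\frac{1}{T}\sum_{k=1}^{K} c_k\,\mathbf{w}_k\Big)^{\top}\mathbf{x}\right) = \operatorname{sgn}\left(\Big(\frac{1}{T}\sum_{t=1}^{T} \mathbf{w}_t\Big)^{\top}\mathbf{x}\right) = \operatorname{sgn}\left(\bar{\mathbf{w}}^\top\mathbf{x}\right),$$

where $T$ is the total number of iterations, $\mathbf{w}_t$ in the third expression is the weight vector in use at iteration $t$, and $\bar{\mathbf{w}}$ is the weight vector averaged over the $T$ iterations. The third equality holds because $\mathbf{w}_k$ stays the current weight vector for exactly $c_k$ iterations. Only one vector has to be kept, instead of all $K$.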
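A matching NumPy sketch of the averaged perceptron, assuming augmented inputs and labels in {-1, +1}; `train_averaged` and `predict_averaged` are hypothetical names. It adds the current weight vector to a running sum after every iteration, so it computes $\bar{\mathbf{w}} = \frac{1}{T}\sum_t \mathbf{w}_t$ without storing the individual vectors.

```python
import numpy as np

def train_averaged(X, y, epochs=10, seed=0):
    """Return w_bar = (1/T) * sum_t w_t over all T iterations."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    w_sum = np.zeros_like(w)
    T = 0
    for _ in range(epochs):
        for n in rng.permutation(len(y)):
            if y[n] * (w @ X[n]) <= 0:       # usual perceptron update
                w = w + y[n] * X[n]
            w_sum += w                       # accumulate w_t, the vector after iteration t
            T += 1
    return w_sum / T

def predict_averaged(w_bar, x):
    """y_hat = sgn(w_bar^T x)."""
    return int(np.sign(w_bar @ x))

X = np.array([[ 2.0,  1.0, 1.0],
              [ 1.0, -1.0, 1.0],
              [-1.0,  0.5, 1.0],
              [-2.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])
w_bar = train_averaged(X, y)
print(w_bar, [predict_averaged(w_bar, x) for x in X])
```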
4 Generating a Linearly Separable Dataset
```python
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 6 12:07:09 2022
@author: chengxf2
"""
import csv

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification


def saveCsv(trainData):
    # newline='' prevents blank rows between records on Windows
    with open('trainData.csv', 'w', newline='') as f:
        wr = csv.writer(f)
        wr.writerows(trainData)


def makeTrain(batch=1000):
    '''
    Classification sample generator
    Reference: https://cloud.tencent.com/developer/ask/sof/1961912/answer/2665673
    make_classification arguments:
        n_samples: total number of samples to generate
        n_features: total number of features; informative, redundant and repeated
                    features come first, the remaining ones are random noise
        n_classes: number of classes (default 2)
        n_informative: number of informative features
        n_redundant: number of redundant features
        n_repeated: number of duplicated features
        shuffle: whether to shuffle the samples and features
        n_clusters_per_class: number of clusters that make up each class
        flip_y: fraction of samples whose label is assigned randomly (0 = no label noise)
    return
        trainData: list of rows [x1, x2, label], label in {0, 1}
    '''
    separable = False
    while not separable:
        data, labels = make_classification(n_samples=batch, n_features=2, n_redundant=0,
                                           n_informative=1, n_clusters_per_class=1, flip_y=0)
        red = data[labels == 0]
        blue = data[labels == 1]
        # accept the sample only if one feature axis alone separates the two classes
        separable = any(red[:, k].max() < blue[:, k].min() or red[:, k].min() > blue[:, k].max()
                        for k in range(2))
    trainData = []
    for i in range(batch):
        item = list(data[i])       # features x1, x2
        item.append(labels[i])     # label in the last column
        trainData.append(item)
    plt.plot(red[:, 0], red[:, 1], 'r.')
    plt.plot(blue[:, 0], blue[:, 1], 'b.')
    plt.show()
    return trainData


data = makeTrain(100)
saveCsv(data)
```

5 Parameter Learning Example for a Two-Class Perceptron
Train on the dataset generated above (trainData.csv).

```python
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 6 14:08:27 2022
@author: chengxf2
"""
import csv
import os

import torch


class perceptron():
    """Perceptron (an artificial network that mimics how the brain tells things apart)."""

    def __init__(self):
        self.m = 0                        # number of samples
        self.n = 0                        # sample dimension (including the bias entry)
        self.maxIter = 10                 # maximum number of epochs
        self.fileName = "trainData.csv"

    def loadData(self):
        """Load the dataset from the CSV file at self.fileName."""
        if not os.path.exists(self.fileName):
            raise FileNotFoundError("file not found: " + self.fileName)
        feature = []    # labels Y: sklearn uses {0, 1}, converted to {-1, +1}
        trainData = []  # training samples X
        with open(self.fileName, newline="") as f:
            for row in csv.reader(f):
                if not row:               # skip blank lines
                    continue
                Y = int(row[-1]) * 2 - 1  # map {0, 1} -> {-1, +1}
                X = [float(v) for v in row[:-1]]
                X.append(1.0)             # augmented vector: constant 1 for the bias
                trainData.append(X)
                feature.append(Y)
        self.m, self.n = len(trainData), len(trainData[0])
        print("\n ---- step 1: dataset loaded ---------")
        return torch.tensor(trainData, dtype=torch.float32), torch.tensor(feature, dtype=torch.int64)

    def forecast(self, w, x):
        """Return the score w^T x and its sign in {-1, 0, +1}."""
        hatY = torch.matmul(w.T, x.reshape(-1, 1)).item()   # (1, n) @ (n, 1) -> scalar
        sgnY = 0
        if hatY > 0:
            sgnY = 1
        elif hatY < 0:
            sgnY = -1
        return hatY, sgnY

    def test(self, trainData, feature, w):
        """Count the samples that w misclassifies."""
        err = 0
        for i in range(self.m):
            x = trainData[i]              # sample i, shape (n,)
            y = int(feature[i])           # its label in {-1, +1}
            _, sgnY = self.forecast(w, x)
            if sgnY * y <= 0:             # wrong sign, or exactly on the hyperplane
                err += 1
        print("\n number of misclassified samples:", err)
        return err

    def train(self, trainData, feature):
        """Perceptron training: w <- w + y x for every misclassified sample."""
        w = torch.zeros((self.n, 1))      # weight vector (last entry is the bias)
        print("\n ----- step 2: training started ---------------")
        for t in range(1, self.maxIter + 1):   # t: epoch counter
            perm = torch.randperm(self.m)      # shuffle: visit samples in random order
            k = 0                              # misclassified samples in this epoch
            for index in perm.tolist():
                x = trainData[index]
                y = int(feature[index])
                _, sgnY = self.forecast(w, x)
                if y * sgnY <= 0:              # misclassified: update the weights
                    k += 1
                    w = w + y * x.reshape(-1, 1)
            print("\n k:%d t:%d " % (k, t), "w: ", w.flatten())
            if k == 0:                         # an error-free epoch: converged
                print("\n --- converged, stop training ---------")
                break
        else:
            print("\n --- reached maxIter, stop training ---------")
        return w


if __name__ == "__main__":
    model = perceptron()
    trainData, feature = model.loadData()
    w = model.train(trainData, feature)
    model.test(trainData, feature, w)
```