Neural Networks and Deep Learning 5: The Perceptron (PyTorch)
2022-07-08 00:16:00 【明朝百晓生】
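Reference:
《神经网络与深度学习》 (Neural Networks and Deep Learning)
Preface:
The perceptron was proposed by Frank Rosenblatt in 1957 and is a widely used linear classifier.
It is an error-driven algorithm: the weights change only when a sample is misclassified.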
Related reading: Scikit-Learn (1. Common datasets provided by sklearn: the small built-in datasets), Micheal超's blog on CSDN
1 The Perceptron

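1.1 Parameter learning

Given a training set of N samples $\{(x^{(n)}, y^{(n)})\}_{n=1}^{N}$ with labels $y^{(n)} \in \{+1, -1\}$, the algorithm tries to find a set of parameters $w$ such that, for every sample $(x^{(n)}, y^{(n)})$,
$$y^{(n)} w^\top x^{(n)} > 0.$$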
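As a small illustration of this condition, the sketch below checks whether a candidate weight vector satisfies $y\, w^\top x > 0$ on every sample. The helper name `separates` and the toy data are assumptions made for the example; inputs are taken to be augmented with a constant 1 so that the bias is part of `w`.

```python
import numpy as np

def separates(w, X, y):
    """Return True if y_n * w^T x_n > 0 holds for every row x_n of X."""
    margins = y * (X @ w)
    return bool(np.all(margins > 0))

# Hypothetical toy data: two features plus a constant 1 for the bias term.
X = np.array([[2.0, 1.0, 1.0],
              [1.0, 3.0, 1.0],
              [-1.0, -2.0, 1.0],
              [-3.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])

print(separates(np.array([1.0, 1.0, 0.0]), X, y))   # all margins positive
print(separates(np.array([1.0, -1.0, 0.0]), X, y))  # some margins negative
```

Any `w` for which the check passes defines a separating hyperplane; the perceptron only has to find one of them.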

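1.2 Loss function

The perceptron loss for a sample $(x, y)$ is
$$\mathcal{L}(w; x, y) = \max(0, -y\, w^\top x).$$
It is minimized with stochastic gradient descent. The gradient for a single sample is
$$\frac{\partial \mathcal{L}}{\partial w} = \begin{cases} 0 & \text{if } y\, w^\top x > 0 \\ -y\, x & \text{if } y\, w^\top x < 0 \end{cases}$$
so a misclassified sample triggers the update $w \leftarrow w + y\, x$.

1.3 Algorithm

Initialize $w = 0$. In each pass, shuffle the training samples and visit them one by one; whenever $y^{(n)} w^\top x^{(n)} \le 0$, update $w \leftarrow w + y^{(n)} x^{(n)}$. Repeat until the maximum number of iterations is reached. On linearly separable data the updates stop once every sample is classified correctly.

1.4 Convergence of the algorithm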
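In 1963, Novikoff proved that the algorithm converges on linearly separable datasets.
Drawbacks:
Generalization is poor: the algorithm finds some separating hyperplane, but nothing guarantees that it generalizes well.
The resulting hyperplane depends on the order in which the samples are visited.
If the data are not linearly separable, the algorithm never converges.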
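The loss, the error-driven update, and the convergence behaviour described above can be sketched in a few lines of NumPy. This is a minimal sketch, not the article's implementation: it uses the standard perceptron loss $\max(0, -y\, w^\top x)$, and the names `perceptron_loss` and `train_perceptron`, the toy data, and the `max_epochs` cap are assumptions made for the example.

```python
import numpy as np

def perceptron_loss(w, x, y):
    """Perceptron loss max(0, -y * w^T x) for one sample."""
    return max(0.0, -y * float(w @ x))

def train_perceptron(X, y, max_epochs=100, seed=0):
    """Error-driven training: w changes only on misclassified samples.

    X holds augmented inputs (last column is 1), y holds labels in {-1, +1}.
    Returns the weights and the number of passes used.
    """
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    for epoch in range(1, max_epochs + 1):
        mistakes = 0
        for i in rng.permutation(len(X)):   # visit samples in random order
            if y[i] * (w @ X[i]) <= 0:      # misclassified or on the boundary
                w = w + y[i] * X[i]         # SGD step on the perceptron loss
                mistakes += 1
        if mistakes == 0:                   # one clean pass: converged
            return w, epoch
    return w, max_epochs                    # cap reached, e.g. non-separable data

# Hypothetical toy data, linearly separable, with a constant 1 appended.
X = np.array([[2.0, 1.0, 1.0],
              [1.0, 3.0, 1.0],
              [-1.0, -2.0, 1.0],
              [-3.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])

w, epochs = train_perceptron(X, y)
print("weights:", w, "passes:", epochs)
print("total loss:", sum(perceptron_loss(w, x_i, y_i) for x_i, y_i in zip(X, y)))
```

Changing `seed` changes the visiting order and may therefore yield a different separating hyperplane, which is the order dependence listed among the drawbacks. The `max_epochs` cap is what keeps the loop finite on data that are not linearly separable.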
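2 Voted Perceptron
The weight vector learned by the perceptron depends on the order of the training samples.
To improve robustness and generalization, we can save all K weight vectors $w_1, \dots, w_K$ produced during training (one per mistake) and assign each weight vector $w_k$ a confidence coefficient $c_k$. The final classification is decided by a vote among the K perceptrons with different weights:
$$\hat{y} = \operatorname{sgn}\Big(\sum_{k=1}^{K} c_k \operatorname{sgn}(w_k^\top x)\Big).$$
This model is called the voted perceptron.
Let $\tau_k$ be the iteration count (the number of training samples seen so far) at the $k$-th weight update, which produces $w_k$, and let $\tau_{k+1}$ be the iteration count at the next update. The confidence coefficient $c_k$ of $w_k$ is set to the number of iterations between $\tau_k$ and $\tau_{k+1}$:
$$c_k = \tau_{k+1} - \tau_k.$$
The larger $c_k$ is, the more samples $w_k$ classified correctly after it was produced, and the more it can be trusted.
This is an ensemble learning idea: K classifiers vote to decide one result.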
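The bookkeeping of weight vectors and confidence coefficients can be sketched as follows. This is a minimal sketch under assumptions: the function names `train_voted_perceptron` and `predict_voted`, the fixed visiting order, and the toy data are all hypothetical, and the confidence of a weight vector is counted as the number of samples seen while it was in use.

```python
import numpy as np

def train_voted_perceptron(X, y, epochs=10):
    """Return a list of (w_k, c_k) pairs collected during training.

    c_k is the number of samples seen while w_k was the current weight
    vector, i.e. the gap between two consecutive updates.
    """
    w = np.zeros(X.shape[1])
    c = 0
    history = []
    for _ in range(epochs):
        for x_i, y_i in zip(X, y):
            if y_i * (w @ x_i) <= 0:     # mistake: retire w_k, start w_{k+1}
                history.append((w, c))
                w = w + y_i * x_i
                c = 1
            else:                        # correct: w_k earns one more vote
                c += 1
    history.append((w, c))               # keep the last weight vector too
    return history

def predict_voted(history, x):
    """Weighted vote: sgn(sum_k c_k * sgn(w_k^T x))."""
    vote = sum(c * np.sign(w @ x) for w, c in history)
    return int(np.sign(vote))

# Hypothetical toy data with a constant 1 appended for the bias.
X = np.array([[2.0, 1.0, 1.0],
              [1.0, 3.0, 1.0],
              [-1.0, -2.0, 1.0],
              [-3.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])

history = train_voted_perceptron(X, y)
print([predict_voted(history, x_i) for x_i in X])
```

Weight vectors that survive many samples without a mistake dominate the vote, while short-lived ones contribute little.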
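3 Averaged Perceptron

The voted perceptron has to store K weight vectors. The averaged perceptron replaces the vote with a single averaged weight vector:
$$\hat{y} = \operatorname{sgn}\Big(\frac{1}{T}\sum_{k=1}^{K} c_k\, w_k^\top x\Big) = \operatorname{sgn}(\bar{w}^\top x), \qquad \bar{w} = \frac{1}{T}\sum_{t=1}^{T} w_t,$$
where $T = \sum_{k=1}^{K} c_k$ is the total number of iterations and $\bar{w}$ is the weight vector averaged over the T iterations.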
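The averaging can be sketched directly by summing the weight vector after every sample and dividing by the number of iterations T. This is a minimal sketch, assuming the straightforward running-sum form rather than any memory-saving variant; the name `train_averaged_perceptron` and the toy data are hypothetical.

```python
import numpy as np

def train_averaged_perceptron(X, y, epochs=10, seed=0):
    """Return the average of the weight vectors over all T iterations."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    w_sum = np.zeros_like(w)
    T = 0
    for _ in range(epochs):
        for i in rng.permutation(len(X)):
            if y[i] * (w @ X[i]) <= 0:   # mistake: usual perceptron update
                w = w + y[i] * X[i]
            w_sum += w                   # w_t: the weights after step t
            T += 1
    return w_sum / T

# Hypothetical toy data with a constant 1 appended for the bias.
X = np.array([[2.0, 1.0, 1.0],
              [1.0, 3.0, 1.0],
              [-1.0, -2.0, 1.0],
              [-3.0, -1.0, 1.0]])
y = np.array([1, 1, -1, -1])

w_bar = train_averaged_perceptron(X, y)
print("averaged weights:", w_bar)
print("predictions:", np.sign(X @ w_bar))
```

Only one vector has to be kept for prediction, which is the practical advantage over storing all K weight vectors.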


4 Generating a Linearly Separable Dataset
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 6 12:07:09 2022
@author: chengxf2
"""
import csv

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification


def saveCsv(trainData):
    # newline='' avoids blank rows between records on Windows
    with open('trainData.csv', 'w', newline='') as f:
        wr = csv.writer(f)
        wr.writerows(trainData)


def makeTrain(batch=1000):
    '''
    Classification data generator
    Reference: https://cloud.tencent.com/developer/ask/sof/1961912/answer/2665673
    args (of make_classification)
      n_samples: total number of samples to generate
      n_features: dimension of one sample, n_informative + n_redundant + n_repeated
      n_classes: number of classes
      centers: number of cluster centers (a make_blobs parameter, not accepted by make_classification)
      n_informative: number of informative features
      n_redundant: number of redundant features
      n_repeated: number of repeated features
      shuffle: whether to shuffle the samples
      n_clusters_per_class: number of clusters that make up one class
    return
      trainData: list of rows [x1, x2, label], with label in {0, 1}
    '''
    separable = False
    trainData = []
    # Resample until a single coordinate axis separates the two classes
    while not separable:
        # flip_y=0 disables label flipping; the original flip_y=-1 fails
        # parameter validation in recent scikit-learn releases
        samples = make_classification(n_samples=batch, n_features=2, n_redundant=0,
                                      n_informative=1, n_clusters_per_class=1, flip_y=0)
        red = samples[0][samples[1] == 0]
        blue = samples[0][samples[1] == 1]
        separable = any([red[:, k].max() < blue[:, k].min() or red[:, k].min() > blue[:, k].max() for k in range(2)])
    data = samples[0]      # inputs X
    feature = samples[1]   # labels y
    for i in range(batch):
        item = [float(v) for v in data[i]]
        item.append(int(feature[i]))   # label goes in the last column
        trainData.append(item)
    plt.plot(red[:, 0], red[:, 1], 'r.')
    plt.plot(blue[:, 0], blue[:, 1], 'b.')
    plt.show()
    return trainData


data = makeTrain(100)
saveCsv(data)
5 Example: Parameter Learning for a Two-Class Perceptron
Train on the dataset generated above.

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 6 14:08:27 2022
@author: chengxf2
"""
import csv
import os

import numpy as np
import torch


# Perceptron (an artificial network that mimics how the brain recognizes things and tells them apart)
class perceptron():

    def __init__(self):
        self.m = 0          # number of samples
        self.n = 0          # sample dimension (after augmentation)
        self.maxIter = 10   # maximum number of passes over the data
        self.fileName = "trainData.csv"

    def loadData(self):
        '''
        Load the dataset from a csv file
        args
          self.fileName : file path
        '''
        if not os.path.exists(self.fileName):
            raise FileNotFoundError("file path does not exist: " + self.fileName)
        feature = []    # labels Y; sklearn uses 0/1, converted here to -1/+1
        trainData = []  # training inputs X
        with open(self.fileName) as f:
            f_csv = csv.reader(f)
            for row in f_csv:
                if not row:   # skip blank rows
                    continue
                Y = int(row[-1])*2 - 1             # map {0, 1} to {-1, +1}
                X = [float(v) for v in row[0:-1]]
                X.append(1)                        # augmented input: constant 1 for the bias
                trainData.append(X)
                feature.append(Y)
        self.m, self.n = np.shape(trainData)
        print("\n ----step1 dataset loaded---------")
        return torch.FloatTensor(trainData), torch.IntTensor(feature)

    def forecast(self, w, x):
        '''
        Predict: return the score w^T x and its sign (+1, -1, or 0 on the boundary)
        '''
        hatY = torch.matmul(w.T, x)
        sgnY = 0
        if hatY > 0:
            sgnY = 1
        elif hatY < 0:
            sgnY = -1
        return hatY, sgnY

    def test(self, trainData, feature, w):
        '''
        Evaluate: count the misclassified samples
        '''
        err = 0
        for i in range(self.m):
            x = trainData[i].view(-1, 1)   # column vector
            y = feature[i]                 # its label
            predY, sgnY = self.forecast(w, x)
            if sgnY*y <= 0:
                err = err + 1
        print("\n number of misclassified samples ", err)
        return err

    def train(self, trainData, feature):
        '''
        Train
        '''
        w = torch.zeros((self.n, 1))   # weights (bias included via the augmented 1)
        t = 0                          # number of passes so far
        print("\n -----step2 training started---------------")
        while t < self.maxIter:
            perm = torch.randperm(self.m)   # shuffle: visit the samples in random order
            k = 0                           # misclassified samples in this pass
            t = t + 1
            for i in range(self.m):
                index = perm[i]
                x = trainData[index].view(-1, 1)   # column vector
                y = feature[index]                 # its label
                hatY, sgnY = self.forecast(w, x)   # predict
                result = y*sgnY                    # positive only if the prediction is correct
                if result <= 0:                    # misclassified
                    k = k + 1
                    w = w + y*x                    # update the weights
            print("\n k:%d t:%d " % (k, t), "w: ", w)
            if k == 0:   # a full pass without mistakes: converged
                break
        print("\n ---training stopped---------")
        return w


if __name__ == "__main__":
    model = perceptron()
    trainData, feature = model.loadData()
    w = model.train(trainData, feature)
    model.test(trainData, feature, w)
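The per-sample loop in `test` can also be expressed as a single matrix product. The sketch below is one possible vectorized form, not part of the original class: `count_errors` is a hypothetical helper, and the tensors are toy values standing in for the data loaded from trainData.csv.

```python
import torch

def count_errors(trainData, feature, w):
    """Count samples with y * w^T x <= 0 using one matrix product."""
    scores = trainData @ w                        # shape (m, 1)
    margins = feature.float() * scores.squeeze(1)
    return int((margins <= 0).sum())

# Hypothetical toy tensors in the same layout as loadData returns:
# augmented inputs (m, 3) and integer labels in {-1, +1}.
X = torch.tensor([[2.0, 1.0, 1.0],
                  [1.0, 3.0, 1.0],
                  [-1.0, -2.0, 1.0],
                  [-3.0, -1.0, 1.0]])
y = torch.tensor([1, 1, -1, -1], dtype=torch.int32)
w = torch.tensor([[1.0], [1.0], [0.0]])

print("misclassified:", count_errors(X, y, w))
```

For a dataset of this size the loop and the vectorized form behave the same; the matrix form mainly pays off on larger inputs.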