当前位置:网站首页>神经网络实现鸢尾花分类
神经网络实现鸢尾花分类
2022-07-28 05:22:00 【积雨辋川】
神经网络实现鸢尾花分类
一、数据集介绍
共有数据150组,每组包括花萼长、花萼宽、花瓣长、花瓣宽4个输入特征。
同时给出了,这一组特征对应的鸢尾花类别。类别包括Setosa Iris(狗尾草
鸢尾),Versicolour Iris(杂色鸢尾),Virginica Iris(弗吉尼亚鸢尾)三
类,分别用数字0,1,2表示。
从sklearn包 datasets 读入数据集,语法为:
from sklearn.datasets import load_iris
x_data = datasets.load_iris().data 返回iris数据集所有输入特征
y_data = datasets.load_iris().target 返回iris数据集所有标签
二、鸢尾花分类
2.1 准备数据
• 数据集读入
• 数据集乱序
• 生成训练集和测试集(即 x_train / y_train)
• 配成 (输入特征,标签) 对,每次读入一小撮(batch)
2.2 搭建网络
• 定义神经网路中所有可训练参数
2.3 参数优化
• 嵌套循环迭代,with结构更新参数,显示当前loss
2.4 测试效果
• 计算当前参数前向传播后的准确率,显示当前acc
2.5 acc/loss可视化
三、代码实现
# -*- coding: UTF-8 -*-
# 利用鸢尾花数据集,实现前向传播、反向传播,可视化loss曲线
# 导入所需模块
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116) # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 生成神经网络的参数,4个输入特征,故输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
# 使用seed使每次生成的随机数相同(方便教学,使大家结果都一致,在现实使用时不写seed)
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))
lr = 0.1 # 学习率为0.1
train_loss_results = [] # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = [] # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 500 # 循环500轮
loss_all = 0 # 每轮分4个step,loss_all记录四个step生成的4个loss的和
# 训练部分
for epoch in range(epoch): # 数据集级别的循环,每个epoch循环一次数据集
for step, (x_train, y_train) in enumerate(train_db): # batch级别的循环 ,每个step循环一个batch
with tf.GradientTape() as tape: # with结构记录梯度信息
y = tf.matmul(x_train, w1) + b1 # 神经网络乘加运算
y = tf.nn.softmax(y) # 使输出y符合概率分布(此操作后与独热码同量级,可相减求loss)
y_ = tf.one_hot(y_train, depth=3) # 将标签值转换为独热码格式,方便计算loss和accuracy
loss = tf.reduce_mean(tf.square(y_ - y)) # 采用均方误差损失函数mse = mean(sum(y-out)^2)
loss_all += loss.numpy() # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确
# 计算loss对各个参数的梯度
grads = tape.gradient(loss, [w1, b1])
# 实现梯度更新 w1 = w1 - lr * w1_grad b = b - lr * b_grad
w1.assign_sub(lr * grads[0]) # 参数w1自更新
b1.assign_sub(lr * grads[1]) # 参数b自更新
# 每个epoch,打印loss信息
print("Epoch {}, loss: {}".format(epoch, loss_all/4))
train_loss_results.append(loss_all / 4) # 将4个step的loss求平均记录在此变量中
loss_all = 0 # loss_all归零,为记录下一个epoch的loss做准备
# 测试部分
# total_correct为预测对的样本个数, total_number为测试的总样本数,将这两个变量都初始化为0
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
# 使用更新后的参数进行预测
y = tf.matmul(x_test, w1) + b1
y = tf.nn.softmax(y)
pred = tf.argmax(y, axis=1) # 返回y中最大值的索引,即预测的分类
# 将pred转换为y_test的数据类型
pred = tf.cast(pred, dtype=y_test.dtype)
# 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
# 将每个batch的correct数加起来
correct = tf.reduce_sum(correct)
# 将所有batch中的correct数加起来
total_correct += int(correct)
# total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
test_acc.append(acc)
print("Test_acc:", acc)
print("--------------------------")
# 绘制 loss 曲线
plt.title('Loss Function Curve') # 图片标题
plt.xlabel('Epoch') # x轴变量名称
plt.ylabel('Loss') # y轴变量名称
plt.plot(train_loss_results, label="$Loss$") # 逐点画出trian_loss_results值并连线,连线图标是Loss
plt.legend() # 画出曲线图标
plt.show() # 画出图像
# 绘制 Accuracy 曲线
plt.title('Acc Curve') # 图片标题
plt.xlabel('Epoch') # x轴变量名称
plt.ylabel('Acc') # y轴变量名称
plt.plot(test_acc, label="$Accuracy$") # 逐点画出test_acc值并连线,连线图标是Accuracy
plt.legend()
plt.show()
运行结果如下:


边栏推荐
- 卷积神经网络
- 小程序开发哪家更靠谱呢?
- What are the points for attention in the development and design of high-end atmospheric applets?
- raise RuntimeError(‘DataLoader worker (pid(s) {}) exited unexpectedly‘.format(pids_str))RuntimeErro
- Alpine, Debian replacement source
- 微信小程序制作模板套用时需要注意什么呢?
- 强化学习——价值学习中的DQN
- Micro service architecture cognition and service governance Eureka
- 分布式集群架构场景优化解决方案:分布式ID解决方案
- Distributed lock database implementation
猜你喜欢
随机推荐
On how digital collections and entities can empower each other
SQLAlchemy使用相关
Xshell suddenly failed to connect to the virtual machine
Use Python to encapsulate a tool class that sends mail regularly
【1】 Introduction to redis
使用pyhon封装一个定时发送邮件的工具类
Sales notice: on July 22, the "great heat" will be sold, and the [traditional national wind 24 solar terms] will be sold in summer.
微信团购小程序怎么做?一般要多少钱?
pytorch深度学习单卡训练和多卡训练
深度学习(自监督:MoCo V3):An Empirical Study of Training Self-Supervised Vision Transformers
简单理解一下MVC和三层架构
Pytorch deep learning single card training and multi card training
What should we pay attention to when making template application of wechat applet?
Various programming languages decimal | time | Base64 and other operations of the quick look-up table
There is a problem with MySQL paging
1: Why should databases be divided into databases and tables
如何选择小程序开发企业
Idempotent component
【一】redis简介
Two methods of covering duplicate records in tables in MySQL









