当前位置:网站首页>神经网络实现鸢尾花分类
神经网络实现鸢尾花分类
2022-07-28 05:22:00 【积雨辋川】
神经网络实现鸢尾花分类
一、数据集介绍
共有数据150组,每组包括花萼长、花萼宽、花瓣长、花瓣宽4个输入特征。
同时给出了,这一组特征对应的鸢尾花类别。类别包括Setosa Iris(狗尾草
鸢尾),Versicolour Iris(杂色鸢尾),Virginica Iris(弗吉尼亚鸢尾)三
类,分别用数字0,1,2表示。
从sklearn包 datasets 读入数据集,语法为:
from sklearn.datasets import load_iris
x_data = datasets.load_iris().data 返回iris数据集所有输入特征
y_data = datasets.load_iris().target 返回iris数据集所有标签
二、鸢尾花分类
2.1 准备数据
• 数据集读入
• 数据集乱序
• 生成训练集和测试集(即 x_train / y_train)
• 配成 (输入特征,标签) 对,每次读入一小撮(batch)
2.2 搭建网络
• 定义神经网路中所有可训练参数
2.3 参数优化
• 嵌套循环迭代,with结构更新参数,显示当前loss
2.4 测试效果
• 计算当前参数前向传播后的准确率,显示当前acc
2.5 acc/loss可视化
三、代码实现
# -*- coding: UTF-8 -*-
# 利用鸢尾花数据集,实现前向传播、反向传播,可视化loss曲线
# 导入所需模块
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116) # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 生成神经网络的参数,4个输入特征,故输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
# 使用seed使每次生成的随机数相同(方便教学,使大家结果都一致,在现实使用时不写seed)
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))
lr = 0.1 # 学习率为0.1
train_loss_results = [] # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = [] # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 500 # 循环500轮
loss_all = 0 # 每轮分4个step,loss_all记录四个step生成的4个loss的和
# 训练部分
for epoch in range(epoch): # 数据集级别的循环,每个epoch循环一次数据集
for step, (x_train, y_train) in enumerate(train_db): # batch级别的循环 ,每个step循环一个batch
with tf.GradientTape() as tape: # with结构记录梯度信息
y = tf.matmul(x_train, w1) + b1 # 神经网络乘加运算
y = tf.nn.softmax(y) # 使输出y符合概率分布(此操作后与独热码同量级,可相减求loss)
y_ = tf.one_hot(y_train, depth=3) # 将标签值转换为独热码格式,方便计算loss和accuracy
loss = tf.reduce_mean(tf.square(y_ - y)) # 采用均方误差损失函数mse = mean(sum(y-out)^2)
loss_all += loss.numpy() # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确
# 计算loss对各个参数的梯度
grads = tape.gradient(loss, [w1, b1])
# 实现梯度更新 w1 = w1 - lr * w1_grad b = b - lr * b_grad
w1.assign_sub(lr * grads[0]) # 参数w1自更新
b1.assign_sub(lr * grads[1]) # 参数b自更新
# 每个epoch,打印loss信息
print("Epoch {}, loss: {}".format(epoch, loss_all/4))
train_loss_results.append(loss_all / 4) # 将4个step的loss求平均记录在此变量中
loss_all = 0 # loss_all归零,为记录下一个epoch的loss做准备
# 测试部分
# total_correct为预测对的样本个数, total_number为测试的总样本数,将这两个变量都初始化为0
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
# 使用更新后的参数进行预测
y = tf.matmul(x_test, w1) + b1
y = tf.nn.softmax(y)
pred = tf.argmax(y, axis=1) # 返回y中最大值的索引,即预测的分类
# 将pred转换为y_test的数据类型
pred = tf.cast(pred, dtype=y_test.dtype)
# 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
# 将每个batch的correct数加起来
correct = tf.reduce_sum(correct)
# 将所有batch中的correct数加起来
total_correct += int(correct)
# total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
test_acc.append(acc)
print("Test_acc:", acc)
print("--------------------------")
# 绘制 loss 曲线
plt.title('Loss Function Curve') # 图片标题
plt.xlabel('Epoch') # x轴变量名称
plt.ylabel('Loss') # y轴变量名称
plt.plot(train_loss_results, label="$Loss$") # 逐点画出trian_loss_results值并连线,连线图标是Loss
plt.legend() # 画出曲线图标
plt.show() # 画出图像
# 绘制 Accuracy 曲线
plt.title('Acc Curve') # 图片标题
plt.xlabel('Epoch') # x轴变量名称
plt.ylabel('Acc') # y轴变量名称
plt.plot(test_acc, label="$Accuracy$") # 逐点画出test_acc值并连线,连线图标是Accuracy
plt.legend()
plt.show()
运行结果如下:


边栏推荐
- Centos7 installing MySQL
- Xshell suddenly failed to connect to the virtual machine
- 【一】redis简介
- 知识点21-泛型
- Use Python to encapsulate a tool class that sends mail regularly
- Two methods of covering duplicate records in tables in MySQL
- 强化学习——多智能体强化学习
- 3: MySQL master-slave replication setup
- Kubesphere installation version problem
- tf.keras搭建神经网络功能扩展
猜你喜欢

高端大气的小程序开发设计有哪些注意点?

Applet development

matplotlib数据可视化

self-attention学习笔记

强化学习——价值学习中的SARSA

深度学习(自监督:SimCLR)——A Simple Framework for Contrastive Learning of Visual Representations

On July 7, the national wind 24 solar terms "Xiaoshu" came!! Attachment.. cooperation.. completion.. advance.. report

Four perspectives to teach you to choose applet development tools?

Service reliability guarantee -watchdog

Manually create a simple RPC (< - < -)
随机推荐
tensorboard可视化
分布式集群架构场景化解决方案:集群时钟同步问题
The project does not report an error, operates normally, and cannot request services
XML parsing entity tool class
Mars number * word * Tibet * product * Pingtai defender plan details announced
Mysql5.6 (according to.Ibd,.Frm file) restore single table data
如何选择小程序开发企业
分布式集群架构场景优化解决方案:Session共享问题
强化学习——连续控制
CertPathValidatorException:validity check failed
CertPathValidatorException:validity check failed
Linux(centOs7) 下安装redis
强化学习——不完全观测问题、MCTS
Kubesphere installation version problem
Assembly packaging
matplotlib数据可视化
深度学习——Patches Are All You Need
uniapp webview监听页面加载后回调
微信小程序开发制作注意这几个重点方面
速查表之转MD5