当前位置:网站首页>深度学习---三好学生各成绩所占权重问题(2)
深度学习---三好学生各成绩所占权重问题(2)
2022-06-29 18:13:00 【knighthood2001】
🥰 博客首页:knighthood2001
欢迎点赞评论️
️ 热爱python,期待与大家一同进步成长!!️
上文中深度学习(初识tensorflow2.版本)之三好学生成绩问题(1) 我们可以发现,搭建的神经网络已经可以运行,但显然还不能真正使用,因为它最终的计算结果是存在误差的。神经网络在投入使用前,都要经过训练的过程。那么,如何来训练神经网络呢?
目录
训练神经网络步骤步骤
①输入数据:例如例子中输入的x1、x2、x3,也就是两位学生各自的德育、智育、体育3项分数。
②计算结果:神经网络根据输入的数据和当前的可变参数值计算出结果,本文例子中为y。
③计算误差:将计算出来的结果y与我们期待的结果( 或者说标准答案,把它暂时称为yTrain进行比对,看看误差(loss)是多少;在例子中,yTrain 的值也就是两位学生各自已知的总分。
④调整神经网络可变参数:根据误差的大小,使用反向传播算法,对神经网络中的可变参数(也就是本章例子中的w1、w2、w3)进行相应的调节。
⑤再次训练:在调整完可变参数后,重复上述步骤重新进行训练,直至误差低于我们的理想水平,神经网络的训练就完成了。
上篇文章编写的程序已经实现了这个流程中的前两个步骤,下面我们来实现剩余的步骤。
代码展示
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
x1 = tf.compat.v1.placeholder(dtype=tf.float32)
x2 = tf.compat.v1.placeholder(dtype=tf.float32)
x3 = tf.compat.v1.placeholder(dtype=tf.float32)
# 设置标准答案
yTrain = tf.compat.v1.placeholder(dtype=tf.float32)
w1 = tf.Variable(0.1, dtype=tf.float32)
w2 = tf.Variable(0.1, dtype=tf.float32)
w3 = tf.Variable(0.1, dtype=tf.float32)
n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3
y = n1 + n2 + n3
loss = tf.abs(y - yTrain)
optimizer = tf.compat.v1.train.RMSPropOptimizer(0.001)
train = optimizer.minimize(loss)
sess = tf.compat.v1.Session()
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
for i in range(10000):
result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 85})
print(result)
result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 98, x2: 95, x3: 87, yTrain: 96})
print(result)
变化之处
①定义了占位符yTrain,这是用来在训练时传入争对每一组输入数据我们期待的对应计算结果值的,后面一般把它简称为“目标计算结果”或“目标值”。
# 目标计算结果(目标值)
yTrain = tf.compat.v1.placeholder(dtype=tf.float32)
②Ⅰ在计算出结果y后,我们用tf.abs(y-yTrain)来计算误差,
Ⅱ然后定义了一个优化器变量optimizer。所谓优化器,就是用来调整神经网络可变参数的对象。我们采用的是RMSPropOptimizer,参数0.001是这个优化器的学习率(learn rate)。所谓学习率,我们在这里可以先简单的理解为:学习率决定了优化器每次调整参数的幅度大小。
Ⅲ定义完优化器后,我们又定义了一个训练对象train,它代表了我们准备如何来训练这个神经网络。我们把它定义为optimizer.minimize(loss),也就是要求优化器按照把loss最小化的原则来调整可变参数。
loss = tf.abs(y - yTrain)
optimizer = tf.compat.v1.train.RMSPropOptimizer(0.001)
train = optimizer.minimize(loss)
接下来我们就可以进行训练了,训练的代码和之前计算的很相似。
for i in range(10000):
result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 85})
print(result)
result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain, loss], feed_dict={x1: 98, x2: 95, x3: 87, yTrain: 96})
print(result)
不同之处主要有两个,一是在feed_dict参数中多指定一个yTrain的数值,也就是对应每一组输入数据x1,x2,x3,我们指定的目标结果值;二是在sess.run函数的第一个参数,也就是我们要求输出的结果数组当中,多加了一个train对象,在结果数组中有train对象,意味着要求程序要执行train对象所包含的训练过程,那么在这个过程中,y、loss等计算结果自然也会被计算出来;所以在结果数组中即使只写一个train,其他的结果也会都被计算出来。只不过我们看不到而已。
只有在结果数组中加上了训练对象,这次sess.run函数的执行才能被称为一次“训练”,否则只是“运行”一次神经网络或者说是用神经网络进行一次“计算”。
尽管两次训练的x1,x2,x3不同,但是神经网络的训练具备适应能力,能够在训练过程中逐步调整可变参数,试图去缩小所有输入数据的计算结果误差。
我们采用for循环,来个5000轮。最后两条结果如下:
loss缩小到0.023246765-0.0332489,w1,w2,w3的数值也很接近我们期待的0.6,0.3,0.1(我们之前假设的权重)。
之后,笔者将会讲解如何优化这里的神经网络模型。
边栏推荐
- WBF: new method of NMS post filter frame for detection task?
- Sd6.25 summary of intensive training
- C comparison of the performance of dapper efcore sqlsugar FreeSQL hisql sqlserver, an ORM framework at home and abroad
- Mac: MySQL 66 questions, 20000 words + 50 pictures!
- How do I add SmartArt to slides in PowerPoint?
- Shell basic syntax -- process control
- Adobe Premiere基础-批量素材导入序列-变速和倒放(回忆)-连续动作镜头切换-字幕要求(十三)
- Adobe Premiere foundation - time remapping (10)
- 行程卡“摘星”热搜第一!刺激旅游产品搜索量齐上涨
- Us judge ruled that the former security director of Uber accused of covering up hacking must face fraud charges
猜你喜欢
Know that Chuangyu has helped the energy industry in asset management and was selected into the 2021 IOT demonstration project of the Ministry of industry and information technology
Shandong University project training (VIII) design rotation map entry page
源码安装MAVROS
MySQL -connector/j driver download
C comparison of the performance of dapper efcore sqlsugar FreeSQL hisql sqlserver, an ORM framework at home and abroad
AMAZING PANDAVERSE:META”无国界,来2.0新征程激活时髦属性
Application and practice of DDD in domestic hotel transaction -- Theory
RocketMQ的tag过滤和sql过滤
Cannot retrieve repository metadata 处理记录
Adobe Premiere foundation - time remapping (10)
随机推荐
Apache InLong百万亿级数据流处理
Mac: MySQL 66 questions, 20000 words + 50 pictures!
Error building sqlsession problem
Adobe Premiere Foundation - réglage du son (correction du volume, réduction du bruit, tonalité téléphonique, changement de hauteur, égaliseur de paramètres) (XVIII)
idea怎么使用?
山东大学项目实训(八)设计轮播图进入页面
Adobe Premiere基础-素材嵌套(制作抖音结尾头像动画)(九)
如何在树莓派上使用OAK相机?
BeanUtils属性复制的用法
jdbc_相关代码
Basis of data analysis -- prediction model
对强缓存和协商缓存的理解
数据仓库模型分层ODS、DWD、DWM实战
The table ‘table_name‘ is full 异常排查及解决方案
[how the network is connected] Chapter 3 explores hubs, switches and routers
JWT登录验证
Shandong University project training (VI) Click event display line chart
What is a multi paradigm programming language and what does "multi paradigm" mean?
C Primer Plus Chapter 12_ Storage categories, links, and memory management_ Codes and exercises
SD6.23集训总结