当前位置:网站首页>【问题解决】Tensorflow中run究竟运行了哪些节点?
【问题解决】Tensorflow中run究竟运行了哪些节点?
2022-06-27 13:07:00 【Can__er】
我们都知道Tensorflow中是先定义图结构,再驱动节点进行运算。当构建完graph图后,需要在一个session会话中启动图,可以自定义创建一个对话,也可以使用默认对话。
其中典型的两个驱动节点的方法是session.run()和tensor.eval(),调用tensor.eval()相当于调用session().run(tensor)。二者的区别在这篇文章中有了较为详细的解释:
下面针对run(self, fetches, feed_dict=None, options=None, run_metadata=None) 方法做出详细介绍:
其中常用的fetches和feed_dict就是常用的传入参数。fetches主要指从计算图中取回计算结果进行放回的那些placeholder和变量,而feed_dict则是将对应的数据传入计算图中占位符,它是字典数据结构只在调用方法内有效。
其中需要注意的一点是tensorflow并不是计算了整个图,只是计算了与想要fetch的值相关的部分,这一点在很多答案中给出了错误的解答,可以看一下官方注释,是 “running the necessary graph fragment”:
''' This method runs one "step" of TensorFlow computation, by running the necessary graph fragment to execute every `Operation` and evaluate every `Tensor` in `fetches`, substituting the values in `feed_dict` for the corresponding input values. '''
这种类似预“用到什么才去计算什么”的感觉非常像拉式,这也说明了为什么定义了一个完整的网络,但有的时候传入X和y进行训练,有的时候只传入X来进行预测,因为没有用到相关参数的图节点,所以无需传入。
但是,其中有个例外,就是网络本身的层参数,看以下代码:
X = tf.placeholder(tf.float32, shape=[None, n_inputs])
cross_entropy_loss = NetWork_loss(X)
grads_and_vars = optimizer.compute_gradients(cross_entropy_loss)
# print(grads_and_vars)
gradients = [grad for grad, variable in grads_and_vars]
gradient_placeholders = []
grads_and_vars_feed = []
for grad, variable in grads_and_vars:
gradient_placeholder = tf.placeholder(tf.float32)
# gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())
gradient_placeholders.append(gradient_placeholder)
grads_and_vars_feed.append((gradient_placeholder, variable))
# training_op = optimizer.apply_gradients(grads_and_vars_feed)
training_op = grads_and_vars_feed
这个网络的定义将计算出的损失,也就是cross_entropy_loss传入优化器optimizer,进行梯度的计算。
这个梯度计算的方法会返回一个个的元组,分别代表的就是当层的梯度和参数。
后面那一大段是为了进行梯度的更新,不过这里使用的是 “策略梯度”,也就是强化学习中需要它执行过了若干步后再更新梯度。
这里的更新的梯度就不再是神经网络中直接计算出的导数了,而是和你强化学习的reward挂钩,不再细讲,总之就是通过一个循环,把每层计算出的梯度传出,再传入根据自己的策略计算出的新的梯度。
重点来了,当我使用下面这段代码进行执行的时候:
with tf.Session() as sess:
init.run()
feed_dict = {
}
for var_index, gradient_placeholder in enumerate(gradient_placeholders):
feed_dict[gradient_placeholder] = [9.]
a = sess.run(training_op, feed_dict=feed_dict)
居然没报错!我们明明看到要计算的 training_op 需要的是grads_and_vars_feed,而后面这个参数用到了variable,也就是最终需要的还是 cross_entropy_loss,这个loss又是根据传入的X计算出来的
那追本溯源下去,这个training_op用到了两个需要填充的地方,而我们仅仅传入了梯度,凭什么不传入X就可以执行?
经过我的一番探索,发现把variable变成grad就会报错需要传入X,可这两个明明是一个循环取出来的。
经过一步步断点调试,终于破案了,原来这个 variable 是网络参数,是定义的图结构中的一些成员变量,也就是在执行 init.run() 的时候已经被初始化了~图结构中会直接把这样的变量标记为“已计算”。而grad则需要通过X,也就是上面正常逻辑回溯一步一步计算过来,所以没传入X的时候这个变量会被标记为“未计算”,那在拉取的时候才会回溯到X。
所以tensorflow并不是计算了整个图,只是计算了与想要fetch的值相关的部分,而这个“相关的部分”指的是还没被标记为“已计算的部分”,尤其是对于网络参数这类“固有的”,虽然表面上看起来是通过一个个的函数得出,实际上则不然,很有误导性。
边栏推荐
猜你喜欢

让学指针变得更简单(二)

How to download pictures with hyperlinks

深信服X计划-系统基础总结

Daily question brushing record (6)

C语言 函数指针与回调函数

Airbnb double disk microservice

Deeply convinced plan X - system foundation summary

【动态规划】—— 背包问题

On the complexity of software development and the way to improve its efficiency
Kotlin函数使用示例教程
随机推荐
【周赛复盘】LeetCode第81场双周赛
清楚的自我定位
Database Series: MySQL index optimization and performance improvement summary (comprehensive version)
IJCAI 2022 | 用一行代码大幅提升零样本学习方法效果,南京理工&牛津提出即插即用分类器模块
Failed to execute NPM instruction, prompting ssh: Permission denied
Steps for win10 to completely and permanently turn off automatic updates
PLM还能怎么用?
云原生(三十) | Kubernetes篇之应用商店-Helm
Vs debugging skills
How to close windows defender Security Center
What else can PLM do?
Journal quotidien des questions (6)
[medical segmentation] unet3+
【医学分割】unet3+
【TcaplusDB知识库】TcaplusDB-tcapulogmgr工具介绍(一)
再懂已是曲中人
Realization of hospital medical record management system based on JSP
ViewPager2使用记录
使用bitnamiredis-sentinel部署Redis 哨兵模式
Bluetooth health management device based on stm32