当前位置:网站首页>在Tensorflow中把Tensor转换为ndarray时,循环中不断调用run或者eval函数,代码运行越来越慢!
在Tensorflow中把Tensor转换为ndarray时,循环中不断调用run或者eval函数,代码运行越来越慢!
2022-06-12 08:28:00 【躁动的风儿】
问题
我有一个这样的需求:我目前有一个已经训练好的encoder模型,它的输出是Tensor类型,我想把它转换成ndarray类型。通过查询资料,我发现可以利用sess.run()把Tensor转换为ndarray,于是在我的代码里调用sess.run()成功转换了数据类型。
但是,我这个数据转换在每一次的循环中都会调用,也就是循环中一直调用sess.run(),于是问题来了,每循环一次,sess.run的用时都比上一次要久,导致后面训练越来越慢。从第一次调用用时0.17s到后面第100次调用时0.27s,而且这才是100次,如果训练10000次,那不知道要等多久,所以这个问题必须解决!
问题原因
如果在某一个循环里不断建立tensorflow图节点再运行的话,会导致tensorflow运行越来越慢。具体问题请看代码注释,没有注释的代码行可以不用关注,问题代码如下:
import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time
class MyWrapper(gym.ObservationWrapper):
def __init__(self, env, encoder, latent_dim = 2):
super().__init__(env)
self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
self.observation_space = self._observation_space
self.encoder = encoder # 这是我已经提前训练好的模型
tf.InteractiveSession()
self.sess = tf.get_default_session()
self.sess.run(tf.global_variables_initializer())
def observation(self, obs):
obs = np.reshape(obs, (1, -1))
latent_z_tensor = self.encoder(obs)[2] # 问题就在与这里,这行代码在调用run时,会不断的创建图节点,所以越来越慢
t=time.time() # 测试运行用时
latent_z_arr = sels.sess.run(latent_z_tensor) # 每次run时,就会把上面的图重新构建一次
print(time.time()-t) # 测试运行用时
obs = np.reshape(obs, (-1,))
latent_z_arr = np.reshape(latent_z_arr, (-1,))
obs = obs.tolist()
obs.extend(latent_z_arr.tolist())
obs = np.array(obs)
return obs
解决思路
在初始化时,就建立好图结构,使用tf.placeholder占位符表示obs这个变量,具体方案示例如下(可以只关注带有注释的行):
import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time
class MyWrapper(gym.ObservationWrapper):
def __init__(self, env, encoder, latent_dim = 2):
super().__init__(env)
self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
self.observation_space = self._observation_space
self.encoder = encoder
tf.InteractiveSession()
self.sess = tf.get_default_session()
self.obs=tf.placeholder(dtype=tf.float32,shape=(1,7)) # 重点在于这两行代码,初始化时先构建好图,先用占位符表示obs,实际运行时只需喂数据obs就好了
self.latent_z_tensor = self.encoder(self.obs)[2] # 在初始化时构建图
self.sess.run(tf.global_variables_initializer())
def observation(self, obs):
obs = np.reshape(obs, (1, -1))
t=time.time() # 测试运行用时
latent_z_arr = self.sess.run(self.latent_z_tensor, feed_dict={
self.obs:obs}) # 这里只需喂数据,不会重新构建图了。
print(time.time()-t) # 测试运行用时
obs = np.reshape(obs, (-1,))
latent_z_arr = np.reshape(latent_z_arr, (-1,))
obs = obs.tolist()
obs.extend(latent_z_arr.tolist())
obs = np.array(obs)
return obs
现在,数据类型转换完成,代码运行慢也解决了!
这个问题至此解决完成,查了四五天的资料终于搞定,这一刻,解决问题带来的快乐把前两天失恋的阴霾都冲散了不少,真是太开心了。
边栏推荐
- MPLS的原理与配置
- Model Trick | CVPR 2022 Oral - Stochastic Backpropagation A Memory Efficient Strategy
- (P36-P39)右值和右值引用、右值引用的作用以及使用、未定引用类型的推导、右值引用的传递
- Vscade debug TS
- MATLAB image processing - cosine noise removal in image (with code)
- 报错:清除网站内搜索框中的历史记录?
- (p25-p26) three details of non range based for loop and range based for loop
- Seurat package addmodulescore is used for bulk RNA SEQ data
- Py & go programming skills: logic control to avoid if else
- How to write simple music program with MATLAB
猜你喜欢

(p21-p24) unified data initialization method: List initialization, initializing objects of non aggregate type with initialization list, initializer_ Use of Lisy template class

Hypergeometric cumulative distribution test overlap

ctfshow web4

Where does the driving force of MES system come from? What problems should be paid attention to in model selection?

Model compression | tip 2022 - Distillation position adaptation: spot adaptive knowledge distillation

企业为什么要实施MES?具体操作流程有哪些?

MPLS的原理与配置

JVM learning notes: three local method interfaces and execution engines

【指針進階三】實現C語言快排函數qsort&回調函數

Special notes on using NAT mode in VM virtual machine
随机推荐
ctfshow web3
[dynamic memory management] malloc & calloc and realloc and written test questions and flexible array
You get download the installation and use of artifact
(P13)final关键字的使用
Error: clear the history in the search box in the website?
Gtest/gmock introduction and Practice
(P19-P20)委托构造函数(代理构造函数)和继承构造函数(使用using)
What should be paid attention to when establishing MES system? What benefits can it bring to the enterprise?
What is the quality traceability function of MES system pursuing?
(P13) use of final keyword
(p19-p20) delegate constructor (proxy constructor) and inheritance constructor (using)
What is an extension method- What are Extension Methods?
How to write simple music program with MATLAB
MATLAB image processing - Otsu threshold segmentation (with code)
(p25-p26) three details of non range based for loop and range based for loop
Webrtc adding third-party libraries
vscode 下载慢解决办法
Special notes on using NAT mode in VM virtual machine
对企业来讲,MES设备管理究竟有何妙处?
At present, MES is widely used. Why are there few APS scheduling systems? Why?