当前位置:网站首页>tensorflow的session和内存溢出
tensorflow的session和内存溢出
2022-08-05 05:15:00 【sc0024】
我有一个九十多张图片的成对测试集,要计算输出图片和gt的PSNR和SSIM,这两个的计算tf是有接口的,直接调用就可以,问题是每次跑几张图片就内存溢出,进程被kill掉了,经过一番调研终于解决,来此记录。
TensorFlow的session
首先介绍一下session(会话)。
会话的作用是处理内存分配和优化,使我们能够实际执行由计算图指定的计算。你可以将计算图想象为我们想要执行的计算的「模版」:它列出了所有步骤。为了使用计算图,我们需要启动一个会话,它使我们能够实际地完成任务;例如,遍历模版的所有节点来分配一堆用于存储计算输出的存储器。为了使用 TensorFlow 进行各种计算,你既需要计算图也需要会话。
会话包含一个指向全局图的指针,该指针通过指向所有节点的指针不断更新。这意味着在创建节点之前还是之后创建会话都无所谓。
创建会话对象后,可以使用 sess.run(node) 返回节点的值,并且 TensorFlow 将执行确定该值所需的所有计算。
常用的两种用法:
1、sess = tf.Session()使用完需要显式关闭:
sess.close()
2、with tf.Session() as sess: 自动完成关闭动作
问题分析
之前我以为,每次循环会覆盖掉psnr和ssim两个变量的内容,所以训练集有多大应该不影响内存的占用。但每次跑十几张图片进程就被kill掉,说明在循环的过程中有一些内存没有被释放掉,产生了积累。查了一些资料之后,我找到了tf.reset_default_graph()这个函数。
TensorFlow执行的时候会自动新建很多节点,而重置计算图就可以解决这个问题。
for test_id in test_ids:
with tf.Session() as sess:
img1 = cv2.imread(result_dir + '%12s_out.png' % test_id)
img2 = cv2.imread(result_dir + '%12s_gt.png' % test_id)
psnr = tf.image.psnr(img1, img2, max_val=255)
ssim = tf.image.ssim(tf.convert_to_tensor(img1), tf.convert_to_tensor(img2), max_val=255)
i = i + 1
print(i, test_id, psnr.eval(), ssim.eval())
tf.reset_default_graph()#在每个session执行完重置计算图,注意一定不能在sess代码块里面调用
边栏推荐
- [Database and SQL study notes] 10. (T-SQL language) functions, stored procedures, triggers
- My 的第一篇博客!!!
- CVPR 2022 |节省70%的显存,训练速度提高2倍
- 初识机器学习
- 基于Flink CDC实现实时数据采集(二)-Source接口实现
- el-pagination左右箭头替换成文字上一页和下一页
- HQL statement execution process
- 【Pytorch学习笔记】10.如何快速创建一个自己的Dataset数据集对象(继承Dataset类并重写对应方法)
- RecycleView和ViewPager2
- Pandas(五)—— 分类数据、读取数据库
猜你喜欢

Thread handler handle IntentServvice handlerThread

spingboot 容器项目完成CICD部署

MaskDistill - Semantic segmentation without labeled data

Calling Matlab configuration in pycharm: No module named 'matlab.engine'; 'matlab' is not a package

神经网络也能像人类利用外围视觉一样观察图像

el-pagination左右箭头替换成文字上一页和下一页

【Pytorch学习笔记】10.如何快速创建一个自己的Dataset数据集对象(继承Dataset类并重写对应方法)

【论文精读】Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation(R-CNN)

大型Web网站高并发架构方案

AWS 常用服务
随机推荐
拿出接口数组对象中的所有name值,取出同一个值
【数据库和SQL学习笔记】9.(T-SQL语言)定义变量、高级查询、流程控制(条件、循环等)
ECCV2022 | RU&谷歌提出用CLIP进行zero-shot目标检测!
【论文阅读-表情捕捉】ExpNet: Landmark-Free, Deep, 3D Facial Expressions
Kubernetes常备技能
redis persistence
npm搭建本地服务器,直接运行build后的目录
Mysql-连接https域名的Mysql数据源踩的坑
【Pytorch学习笔记】11.取Dataset的子集、给Dataset打乱顺序的方法(使用Subset、random_split)
服务网格istio 1.12.x安装
Pandas(五)—— 分类数据、读取数据库
CVPR最佳论文得主清华黄高团队提出首篇动态网络综述
flink基本原理及应用场景分析
基于Flink CDC实现实时数据采集(一)-接口设计
MySQL
day8字典作业
【Reading】Long-term update
【数据库和SQL学习笔记】10.(T-SQL语言)函数、存储过程、触发器
CVPR 2022 | 70% memory savings, 2x faster training
day6-列表作业