当前位置:网站首页>Keras 分割网络自定义评估函数 - mean iou
Keras 分割网络自定义评估函数 - mean iou
2022-08-05 11:34:00 【为为为什么】
Keras训练网络过程中需要实时观察性能,mean iou不是keras自带的评估函数,tf的又觉得不好用,自己写了一个,经过测试没有问题,本文记录自定义keras mean iou评估的实现方法。
计算 IoU
用numpy计算的,作为IoU的ground truth用作测试使用:
def iou_numpy(y_true,y_pred):
intersection = np.sum(np.multiply(y_true.astype('bool'),y_pred == 1))
union = np.sum((y_true.astype('bool')+y_pred.astype('bool'))>0)
return intersection/unionkeras metric IoU
def iou_keras(y_true, y_pred):
"""
Return the Intersection over Union (IoU).
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
Returns:
the IoU for the given label
"""
label = 1
# extract the label values using the argmax operator then
# calculate equality of the predictions and truths to the label
y_true = K.cast(K.equal(y_true, label), K.floatx())
y_pred = K.cast(K.equal(y_pred, label), K.floatx())
# calculate the |intersection| (AND) of the labels
intersection = K.sum(y_true * y_pred)
# calculate the |union| (OR) of the labels
union = K.sum(y_true) + K.sum(y_pred) - intersection
# avoid divide by zero - if the union is zero, return 1
# otherwise, return the intersection over union
return K.switch(K.equal(union, 0), 1.0, intersection / union)计算 mean IoU
mean IoU 简便起见,选取 (0,1,0.05) 作为不同的IoU阈值,计算平均IoU numpy 真实值计算
def mean_iou_numpy(y_true,y_pred):
iou_list = []
for thre in list(np.arange(0.0000001,0.99,0.05)):
y_pred_temp = y_pred >= thre
iou = iou_numpy(y_true, y_pred_temp)
iou_list.append(iou)
return np.mean(iou_list)Keras mean IoU
def mean_iou_keras(y_true, y_pred):
"""
Return the mean Intersection over Union (IoU).
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
Returns:
the mean IoU
"""
label = 1
# extract the label values using the argmax operator then
# calculate equality of the predictions and truths to the label
y_true = K.cast(K.equal(y_true, label), K.floatx())
mean_iou = K.variable(0)
thre_list = list(np.arange(0.0000001,0.99,0.05))
for thre in thre_list:
y_pred_temp = K.cast(y_pred >= thre, K.floatx())
y_pred_temp = K.cast(K.equal(y_pred_temp, label), K.floatx())
# calculate the |intersection| (AND) of the labels
intersection = K.sum(y_true * y_pred_temp)
# calculate the |union| (OR) of the labels
union = K.sum(y_true) + K.sum(y_pred_temp) - intersection
iou = K.switch(K.equal(union, 0), 1.0, intersection / union)
mean_iou = mean_iou + iou
return mean_iou / len(thre_list)测试
## 随机生成预测值
y_true_np = np.ones([10,10])
y_pred_np = np.random.rand(10,10)
## 真实IoU值
print(f' iou : {iou_numpy(y_true_np, y_pred_np)}')
print(f' mean_iou_numpy : {mean_iou_numpy(y_true_np, y_pred_np)}')
y_true = tf.Variable(y_true_np)
y_pred = tf.Variable(y_pred_np)
## 计算节点
iou_res = iou_keras (y_true, y_pred)
m_iou_res = mean_iou_keras (y_true, y_pred)
## 变量初始化
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
## 由于存在误差,结果在0.0000001范围内即可认为相同
result = sess.run(iou_res)
print(f'result : {result} \nsame with ground truth: {abs(iou_numpy(y_true_np, y_pred_np) - result)< 0.0000001}')
result = sess.run(m_iou_res)
print(f'result : {result} \nsame with ground truth: {abs(mean_iou_numpy(y_true_np, y_pred_np) - result) < 0.0000001}') 输出:
iou : 0.0
mean_iou_numpy : 0.5295
result : 0.0
same with ground truth: True
result : 0.5295000076293945
same with ground truth: True源码下载
边栏推荐
- 提问题进不去。想问大家一个关于返回值的问题(图的遍历),求给小白解答啊
- 【名词】什么是PV和UV?
- 2022 极术通讯-基于安谋科技 “星辰” STAR-MC1的灵动MM32F2570开发板深度评测
- 多业务模式下的交易链路探索与实践
- Hands-on Deep Learning_GoogLeNet / Inceptionv1v2v3v4
- Android development with Kotlin programming language three loop control
- CenOS MySQL入门及安装
- 巴比特 | 元宇宙每日必读:中国1775万件数字藏品分析报告显示,85%的已发行数藏开通了转赠功能...
- STM32 entry development: write XPT2046 resistive touch screen driver (analog SPI)
- Guys, I am a novice. I use flinksql to write a simple count of user visits according to the document, but it ends after executing it once.
猜你喜欢
随机推荐
五大理由告诉你为什么开发人员选择代码质量静态分析工具Klocwork来实现软件安全
Android 开发用 Kotlin 编程语言 二 条件控制
Gao Zelong attended the Boao Global Tourism Ecology Conference to talk about Metaverse and Future Network Technology
脱光衣服待着就能减肥,当真有这好事?
Version Control | Longzhi invites you to go to the GOPS Global Operation and Maintenance Conference to explore the road of large-scale, agile, high-quality and open software development and operation
问题征集丨ECCV 2022中国预讲会 · Panel专题研讨会
智源社区AI周刊No.92:“计算复杂度”理论奠基人Juris Hartmanis逝世;美国AI学生九年涨2倍,大学教师短缺;2022智源大会观点报告发布[附下载]
常见的 web 安全问题总结
“蘑菇书”是怎样磨出来的?
TiDB 6.0 Placement Rules In SQL 使用实践
Scaling-law和模型结构的关系:不是所有的结构放大后都能保持最好性能
2022杭电多校联赛第六场 题解
可视化开发必看:智慧城市四大核心技术
软件测试之集成测试
Discover the joy of C language
hdu 1870 愚人节的礼物 (栈)
Linux:记一次CentOS7安装MySQL8(博客合集)
自定义过滤器和拦截器实现ThreadLocal线程封闭
MMDetection实战:MMDetection训练与测试
365天挑战LeetCode1000题——Day 050 在二叉树中增加一行 二叉树







