当前位置:网站首页>Keras 分割网络自定义评估函数 - mean iou
Keras 分割网络自定义评估函数 - mean iou
2022-08-05 11:34:00 【为为为什么】
Keras训练网络过程中需要实时观察性能,mean iou不是keras自带的评估函数,tf的又觉得不好用,自己写了一个,经过测试没有问题,本文记录自定义keras mean iou评估的实现方法。
计算 IoU
用numpy计算的,作为IoU的ground truth用作测试使用:
def iou_numpy(y_true,y_pred):
intersection = np.sum(np.multiply(y_true.astype('bool'),y_pred == 1))
union = np.sum((y_true.astype('bool')+y_pred.astype('bool'))>0)
return intersection/unionkeras metric IoU
def iou_keras(y_true, y_pred):
"""
Return the Intersection over Union (IoU).
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
Returns:
the IoU for the given label
"""
label = 1
# extract the label values using the argmax operator then
# calculate equality of the predictions and truths to the label
y_true = K.cast(K.equal(y_true, label), K.floatx())
y_pred = K.cast(K.equal(y_pred, label), K.floatx())
# calculate the |intersection| (AND) of the labels
intersection = K.sum(y_true * y_pred)
# calculate the |union| (OR) of the labels
union = K.sum(y_true) + K.sum(y_pred) - intersection
# avoid divide by zero - if the union is zero, return 1
# otherwise, return the intersection over union
return K.switch(K.equal(union, 0), 1.0, intersection / union)计算 mean IoU
mean IoU 简便起见,选取 (0,1,0.05) 作为不同的IoU阈值,计算平均IoU numpy 真实值计算
def mean_iou_numpy(y_true,y_pred):
iou_list = []
for thre in list(np.arange(0.0000001,0.99,0.05)):
y_pred_temp = y_pred >= thre
iou = iou_numpy(y_true, y_pred_temp)
iou_list.append(iou)
return np.mean(iou_list)Keras mean IoU
def mean_iou_keras(y_true, y_pred):
"""
Return the mean Intersection over Union (IoU).
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
Returns:
the mean IoU
"""
label = 1
# extract the label values using the argmax operator then
# calculate equality of the predictions and truths to the label
y_true = K.cast(K.equal(y_true, label), K.floatx())
mean_iou = K.variable(0)
thre_list = list(np.arange(0.0000001,0.99,0.05))
for thre in thre_list:
y_pred_temp = K.cast(y_pred >= thre, K.floatx())
y_pred_temp = K.cast(K.equal(y_pred_temp, label), K.floatx())
# calculate the |intersection| (AND) of the labels
intersection = K.sum(y_true * y_pred_temp)
# calculate the |union| (OR) of the labels
union = K.sum(y_true) + K.sum(y_pred_temp) - intersection
iou = K.switch(K.equal(union, 0), 1.0, intersection / union)
mean_iou = mean_iou + iou
return mean_iou / len(thre_list)测试
## 随机生成预测值
y_true_np = np.ones([10,10])
y_pred_np = np.random.rand(10,10)
## 真实IoU值
print(f' iou : {iou_numpy(y_true_np, y_pred_np)}')
print(f' mean_iou_numpy : {mean_iou_numpy(y_true_np, y_pred_np)}')
y_true = tf.Variable(y_true_np)
y_pred = tf.Variable(y_pred_np)
## 计算节点
iou_res = iou_keras (y_true, y_pred)
m_iou_res = mean_iou_keras (y_true, y_pred)
## 变量初始化
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
## 由于存在误差,结果在0.0000001范围内即可认为相同
result = sess.run(iou_res)
print(f'result : {result} \nsame with ground truth: {abs(iou_numpy(y_true_np, y_pred_np) - result)< 0.0000001}')
result = sess.run(m_iou_res)
print(f'result : {result} \nsame with ground truth: {abs(mean_iou_numpy(y_true_np, y_pred_np) - result) < 0.0000001}') 输出:
iou : 0.0
mean_iou_numpy : 0.5295
result : 0.0
same with ground truth: True
result : 0.5295000076293945
same with ground truth: True源码下载
边栏推荐
- nyoj757 期末考试 (优先队列)
- 提问题进不去。想问大家一个关于返回值的问题(图的遍历),求给小白解答啊
- Learning Deep Compact Image Representations for Visual Tracking
- 并非富人专属,一文让你对NFT改观
- Android 开发用 Kotlin 编程语言 二 条件控制
- 一张图理解EOS是什么
- 解决2022Visual Studio中scanf返回值被忽略问题
- MMDetection in action: MMDetection training and testing
- Exploration and practice of transaction link under multi-service mode
- 解决 json.dump 报错:TypeError - Object of type xxx is not JSON serializable
猜你喜欢
随机推荐
安全软件Avast与赛门铁克诺顿NortonLifeLock合并获英国批准
2022杭电多校联赛第六场 题解
PostgreSQL 2022 Report: Rising popularity, open source, reliability and scaling key
Byte Qiu Zhao confused me on both sides, and asked me under what circumstances would the SYN message be discarded?
Can't get in to ask questions.I want to ask you a question about the return value (traversal of the graph), please give Xiaobai an answer.
时间格式2020-01-13T16:00:00.000Z中的T和Z分别表示什么,如何处理
Support Vector Machine SVM
微服务结合领域驱动设计落地
Http-Sumggling Cache Vulnerability Analysis
常见的 web 安全问题总结
手把手教你定位线上MySQL慢查询问题,包教包会
Letter from Silicon Valley: Act fast, Facebook, Quora and other successful "artifacts"!
Introduction to the Evolution of Data Governance System
hdu1455 Sticks (search+pruning+pruning+.....+pruning)
平安萌娃卡保险怎么样?让父母读懂几个识别产品的方法
Android development with Kotlin programming language three loop control
hdu4545 魔法串
JS 从零手写实现一个call、apply方法
硅谷来信:快速行动,Facebook、Quora等成功的“神器”!
WPF开发随笔收录-WriteableBitmap绘制高性能曲线图









