Efficient image segmentation with MindSpore: implementing Dice!
2022-08-05 09:55:00 【昇思MindSpore】
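Introduction to and implementation of the Dice coefficient
How the Dice coefficient works
Dice is one of the most frequently used metrics in medical imaging competitions. It is a set-similarity measure, typically used to compute the similarity of two samples, and its value lies in [0, 1]. In medical imaging it is mostly used to evaluate image segmentation: a perfect segmentation scores 1 and the worst possible segmentation scores 0.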
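The Dice coefficient is computed as follows, where X is the predicted foreground and Y the ground-truth foreground:
$$\mathrm{Dice}(X, Y) = \frac{2|X \cap Y|}{|X| + |Y|}$$
Dice can also be expressed with the confusion-matrix entries TP, FP, and FN:
$$\mathrm{Dice} = \frac{2TP}{2TP + FP + FN}$$
For soft (probabilistic) predictions, the MindSpore code later in this article replaces |X| and |Y| with sums of squared values.
The idea behind this formula is illustrated in the figure below: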
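As a quick check that the set form and the TP/FP/FN form describe the same quantity, here is a minimal NumPy sketch on two small binary masks. The function names dice_sets and dice_confusion are hypothetical, and neither function guards against two empty masks (which would give 0/0).

```python
import numpy as np


def dice_sets(pred: np.ndarray, true: np.ndarray) -> float:
    """Set form: 2|X ∩ Y| / (|X| + |Y|) for binary masks (hypothetical helper)."""
    pred = pred.astype(bool)
    true = true.astype(bool)
    intersection = np.logical_and(pred, true).sum()
    return 2.0 * intersection / (pred.sum() + true.sum())


def dice_confusion(pred: np.ndarray, true: np.ndarray) -> float:
    """Confusion-matrix form: 2TP / (2TP + FP + FN) (hypothetical helper)."""
    pred = pred.astype(bool)
    true = true.astype(bool)
    tp = np.logical_and(pred, true).sum()    # predicted 1, truly 1
    fp = np.logical_and(pred, ~true).sum()   # predicted 1, truly 0
    fn = np.logical_and(~pred, true).sum()   # predicted 0, truly 1
    return 2.0 * tp / (2 * tp + fp + fn)


pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
true = np.array([[1, 0, 0],
                 [0, 1, 1]])
# |X| = 3, |Y| = 3, |X ∩ Y| = 2; TP = 2, FP = 1, FN = 1
print(dice_sets(pred, true), dice_confusion(pred, true))  # both 4/6, about 0.667
```

Since |X| = TP + FP and |Y| = TP + FN, the two denominators are always equal, which is why both functions return the same value.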
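MindSpore code implementation
First, a brief introduction to MindSpore, a new-generation open-source AI computing framework. Its new programming paradigm makes it easier for AI scientists and engineers to use and encourages open innovation; it covers all scenarios, including device, edge, and cloud, while better protecting data privacy; and it is open source, building a broad application ecosystem.
On March 28, 2020, Huawei announced at the Huawei Developer Conference 2020 that the all-scenario AI computing framework MindSpore was officially open-sourced on Gitee. MindSpore focuses on ease of use and on lowering the barrier to entry for AI developers. It natively supports every scenario, including device, edge, and cloud, with on-demand collaboration between them, and by realizing "AI algorithm as code" it makes development friendlier, significantly shortening model development time and lowering the barrier to model development.
Through its own technical innovations and its co-optimization with Huawei's Ascend AI processors, MindSpore achieves efficient runtime execution and substantially improves computing performance. MindSpore also supports other processors such as GPUs and CPUs.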
"""Dice"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class Dice(Metric):
def __init__(self, smooth=1e-5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self._dice_coeff_sum = 0
self._samples_num = 0
self.clear()
def clear(self):
# 是来清除历史数据
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
# 更新输入数据,y_pred和y,数据输入类型可以是Tensor,lisy或numpy,维度必须相等
if len(inputs) != 2:
raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
# 将数据进行转换,统一转换为numpy
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
# 先求交集,利用dot对应点相乘再相加
intersection = np.dot(y_pred.flatten(), y.flatten())
# 求并集,先将输入shape都拉到一维,然后分别进行点乘,再将两个输入进行相加
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
# 利用公式进行计算,加smooth是为了防止分母为0,避免当pred和true都为0时,分子被0除的问题,同时减少过拟合
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
# 对每一批次的系数进行累加
self._dice_coeff_sum += single_dice_coeff
def eval(self):
# 进行计算
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
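For readers without a MindSpore install, the update/eval logic above can be mirrored in plain NumPy. This is a sketch, not MindSpore's code: the class name NumpyDice is hypothetical, np.asarray stands in for _convert_data, and a plain check replaces validator. It also makes one property of the listed eval() easy to see: it sums one coefficient per update() call (per batch) but divides by the number of samples, so a single batch of three rows yields the batch coefficient 2.8 / 4.56 ≈ 0.614 divided by 3, about 0.2047, which is the value printed by the MindSpore usage example.

```python
import numpy as np


class NumpyDice:
    """Hypothetical NumPy-only mirror of the Dice metric above (no MindSpore needed)."""

    def __init__(self, smooth: float = 1e-5):
        if smooth <= 0:
            raise ValueError("smooth must be a positive float")
        self.smooth = smooth
        self.clear()

    def clear(self):
        # Reset the accumulated coefficient sum and the sample counter
        self._dice_coeff_sum = 0.0
        self._samples_num = 0

    def update(self, y_pred, y):
        # np.asarray stands in for MindSpore's _convert_data
        y_pred = np.asarray(y_pred, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        if y_pred.shape != y.shape:
            raise RuntimeError(f"shape mismatch: {y_pred.shape} vs {y.shape}")
        # Samples are counted along the first (batch) axis
        self._samples_num += y.shape[0]
        p, t = y_pred.ravel(), y.ravel()
        intersection = np.dot(p, t)
        unionset = np.dot(p, p) + np.dot(t, t)
        # One coefficient per update() call, i.e. per batch
        self._dice_coeff_sum += 2.0 * intersection / (unionset + self.smooth)

    def eval(self):
        if self._samples_num == 0:
            raise RuntimeError("Total samples num must not be 0.")
        # Sum of per-batch coefficients divided by the number of samples
        return self._dice_coeff_sum / self._samples_num


metric = NumpyDice(smooth=1e-5)
metric.update([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]], [[0, 1], [1, 0], [0, 1]])
print(metric.eval())  # about 0.2047
```

If a per-sample average is what you actually want, the division in eval() would need to match how coefficients are accumulated; the sketch deliberately keeps the listed behavior so its numbers line up with the article.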
Usage:
```python
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics import Dice

metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
dice = metric.eval()
print(dice)
# Output: 0.20467791371802546
```
When the metric is accumulated over several batches (here, two batches of data), the code looks like this:
```python
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics import Dice

metric = Dice(smooth=1e-5)
metric.clear()
# First batch
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
# Second batch
x1 = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y1 = Tensor(np.array([[1, 0], [1, 1], [1, 0]]))
metric.update(x1, y1)
avg_dice = metric.eval()
print(avg_dice)
```
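The two-batch output is not shown above, but under the update/eval logic listed earlier it can be worked out by hand. The sketch below does that arithmetic in NumPy with a hypothetical helper, batch_dice: one coefficient per batch, the two summed, then divided by the total sample count 3 + 3 = 6.

```python
import numpy as np


def batch_dice(y_pred, y, smooth=1e-5):
    """Hypothetical helper: one batch-level coefficient, as computed in Dice.update()."""
    p = np.asarray(y_pred, dtype=np.float64).ravel()
    t = np.asarray(y, dtype=np.float64).ravel()
    return 2.0 * np.dot(p, t) / (np.dot(p, p) + np.dot(t, t) + smooth)


x = [[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]
y = [[0, 1], [1, 0], [0, 1]]
y1 = [[1, 0], [1, 1], [1, 0]]

d1 = batch_dice(x, y)    # 2 * 1.4 / (1.56 + 3 + smooth)
d2 = batch_dice(x, y1)   # 2 * 1.5 / (1.56 + 4 + smooth)
# eval(): summed batch coefficients divided by the total sample count, 3 + 3
avg_dice = (d1 + d2) / (len(y) + len(y1))
print(d1, d2, avg_dice)
```

Under these assumptions avg_dice comes out at roughly 0.1923.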
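Introduction to and implementation of Dice Loss
How Dice Loss works
Dice Loss is built directly on the Dice coefficient: it is 1 minus the Dice coefficient.
$$\mathrm{DiceLoss} = 1 - \frac{2|X \cap Y|}{|X| + |Y|}$$
This covers binary segmentation with a single image per batch. When a batch contains N images, each image can be flattened into a one-dimensional vector, as shown in the figure below:
The labels are flattened in the same way, and the Dice coefficient and Dice Loss are then computed over all N images together.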
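The flattening step can be sketched in NumPy as follows, assuming a batch of N binary masks of shape (N, H, W) and predicted probabilities of the same shape. The function name flattened_dice_loss is hypothetical. It uses the squared-sum denominator from the MindSpore code in this article together with the plain 1 − Dice definition, so it does not include the extra division by label.shape[0] that appears in the DiceLoss code.

```python
import numpy as np


def flattened_dice_loss(probs: np.ndarray, labels: np.ndarray, smooth: float = 1e-5) -> float:
    """Hypothetical sketch: 1 - Dice computed over a whole batch of N masks at once.

    probs, labels: arrays of shape (N, H, W); probs in [0, 1], labels in {0, 1}.
    """
    if probs.shape != labels.shape:
        raise ValueError("probs and labels must have the same shape")
    # Flatten all N images (and their labels) into one long vector each
    p = probs.reshape(-1).astype(np.float64)
    t = labels.reshape(-1).astype(np.float64)
    intersection = np.sum(p * t)
    # Squared-sum denominator, as in the MindSpore code in this article
    union = np.sum(p * p) + np.sum(t * t)
    dice = 2.0 * intersection / (union + smooth)
    return 1.0 - dice


rng = np.random.default_rng(0)
labels = (rng.random((4, 8, 8)) > 0.5).astype(np.float64)  # N = 4 toy masks
print(flattened_dice_loss(labels, labels))        # perfect prediction: close to 0
print(flattened_dice_loss(1.0 - labels, labels))  # fully inverted prediction: 1.0
```

A perfect prediction gives a loss close to 0 (only the smooth term keeps it from being exactly 0), and a fully inverted mask gives exactly 1, because the intersection is then zero.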
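Binary DiceLoss implementation in MindSpore
```python
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import Validator as validator
# _Loss is the base class from MindSpore's loss module; since DiceLoss does not define
# self.reduce_sum or self.mul itself, both operators come from this base class
from mindspore.nn.loss.loss import _Loss


class DiceLoss(_Loss):
    def __init__(self, smooth=1e-5):
        super(DiceLoss, self).__init__()
        self.smooth = validator.check_positive_float(smooth, "smooth")
        self.reshape = P.Reshape()

    def construct(self, logits, label):
        # Shape check: both inputs must be Tensors of the same shape
        _check_shape(logits.shape, label.shape)
        # Intersection, computed the same way as for the Dice coefficient
        intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
        # "Union", computed the same way as for the Dice coefficient
        unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
                   self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
        # Apply the formula
        single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
        # Note: the coefficient is divided by the batch size label.shape[0] before subtracting from 1
        dice_loss = 1 - single_dice_coeff / label.shape[0]
        return dice_loss.mean()


@constexpr
def _check_shape(logits_shape, label_shape):
    validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
```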
Usage:
```python
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor

loss = nn.DiceLoss(smooth=1e-5)
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
# Output: [0.7953220862819745]
```
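To see where the printed 0.7953 comes from, the arithmetic of DiceLoss.construct can be replayed in NumPy on the same inputs. This is a hand check, not MindSpore code; it uses float64 where the example uses float32, so trailing digits may differ. It shows that the printed value includes the division by label.shape[0] (3 here): plain 1 − Dice on these inputs would be about 0.386.

```python
import numpy as np

y_pred = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])
y = np.array([[0, 1], [1, 0], [0, 1]], dtype=np.float64)
smooth = 1e-5

# Flatten, as logits.view(-1) and label.view(-1) do
p, t = y_pred.reshape(-1), y.reshape(-1)
intersection = np.sum(p * t)              # 0.5 + 0.3 + 0.6 = 1.4
unionset = np.sum(p * p) + np.sum(t * t)  # 1.56 + 3 = 4.56
single_dice_coeff = 2 * intersection / (unionset + smooth)  # about 0.614

# As listed, DiceLoss divides by label.shape[0] before subtracting from 1
dice_loss = 1 - single_dice_coeff / y.shape[0]
plain_loss = 1 - single_dice_coeff  # textbook 1 - Dice, for comparison

print(dice_loss)   # about 0.7953
print(plain_loss)  # about 0.386
```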
Multi-class MultiClassDiceLoss implementation in MindSpore
MindSpore offers several loss functions for semantic segmentation, although cross-entropy is still the most commonly used one.
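```python
from mindspore import Tensor
from mindspore.nn import Cell
# Also assumed from the surrounding MindSpore loss module (not shown in the article):
# P, validator, _Loss, DiceLoss and _check_shape as in the DiceLoss listing; A, the activation
# module providing Softmax and get_activation; and the helper _check_weights.


class MultiClassDiceLoss(_Loss):
    def __init__(self, weights=None, ignore_indiex=None, activation=A.Softmax(axis=1)):
        super(MultiClassDiceLoss, self).__init__()
        # Reuse the binary Dice loss for each class
        self.binarydiceloss = DiceLoss(smooth=1e-5)
        # weights: a Tensor sized to the number of classes, i.e. Tensor of shape `[num_classes, dim]`
        self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
        # Index of the class to ignore
        self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
            validator.check_value_type("ignore_indiex", ignore_indiex, [int])
        # Activation: a string name is resolved with get_activation; a Cell is used as-is
        self.activation = A.get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, Cell):
            raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
        self.activation_flag = self.activation is not None
        self.reshape = P.Reshape()

    def construct(self, logits, label):
        # Shape check: both inputs must be Tensors of the same shape
        _check_shape(logits.shape, label.shape)
        # Running total of the loss, starting at 0
        total_loss = 0
        # Apply the activation (softmax by default) if one is set
        if self.activation_flag:
            logits = self.activation(logits)
        # Loop over the class axis, i.e. dimension 1 of label
        for i in range(label.shape[1]):
            if i != self.ignore_indiex:
                dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
                if self.weights is not None:
                    _check_weights(self.weights, label)
                    dice_loss *= self.weights[i]
                total_loss += dice_loss
        # Average over all classes (an ignored class still counts in the denominator)
        return total_loss/label.shape[1]
```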
Usage:
```python
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor

loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
# Output: [0.7761003]
```
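The multi-class logic can also be mirrored in NumPy to check the printed 0.7761. This sketch assumes inputs of shape (N, C, ...) with one-hot labels along axis 1 and follows the listed code: softmax over axis 1, a binary Dice loss per class (including the division by the batch size that DiceLoss performs), optional per-class weights, one skipped class, and a final division by the total number of classes. All function names are hypothetical; ignore_index is spelled normally here, whereas MindSpore's parameter is ignore_indiex, and weights is simplified to a plain per-class sequence.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    # Numerically stable softmax along the class axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def binary_dice_loss(p: np.ndarray, t: np.ndarray, smooth: float = 1e-5) -> float:
    # Mirrors DiceLoss.construct as listed, including the division by the batch size
    n = t.shape[0]  # plays the role of label.shape[0]
    p, t = p.reshape(-1), t.reshape(-1)
    coeff = 2 * np.sum(p * t) / (np.sum(p * p) + np.sum(t * t) + smooth)
    return 1 - coeff / n


def multiclass_dice_loss(logits, labels, weights=None, ignore_index=None, smooth=1e-5):
    """Hypothetical sketch of MultiClassDiceLoss: softmax, per-class binary loss, average."""
    if logits.shape != labels.shape:
        raise ValueError("logits and labels must have the same shape")
    probs = softmax(logits.astype(np.float64), axis=1)
    num_classes = labels.shape[1]
    total = 0.0
    for i in range(num_classes):  # loop over the class axis
        if i == ignore_index:
            continue  # a skipped class contributes 0
        loss = binary_dice_loss(probs[:, i], labels[:, i].astype(np.float64), smooth)
        if weights is not None:
            loss *= weights[i]
        total += loss
    # Divide by all C classes, ignored ones included, as the listed code does
    return total / num_classes


y_pred = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])
y = np.array([[0, 1], [1, 0], [0, 1]])
print(multiclass_dice_loss(y_pred, y))  # about 0.7761
```

As in the listed code, an ignored class still counts in the final division by C, so ignoring a class lowers the loss rather than re-averaging over the remaining classes.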
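Problems with Dice Loss
The training loss curve can be very noisy, which makes it hard to read anything about convergence from it. This can be worked around by monitoring the error on a validation set instead.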