当前位置:网站首页>Complete image segmentation efficiently based on MindSpore and realize Dice!
Complete image segmentation efficiently based on MindSpore and realize Dice!
2022-08-05 10:01:00 【Ascension MindSpore】
Dice Introduction and implementation of coefficients
DiceCoefficient principle
DiceIt is the most frequently used metric in medical image competitions,It is an ensemble similarity measure,通常用于计算两个样本的相似度,The value threshold is[0, 1].Often used for image segmentation in medical images,The best result of segmentation is 1,The worst time result is 0.
Dice系数计算公式如下:
当然DiceThere is also another expression,is used in the confusion matrixTP,FP,FN来表达:
The principle of this formula is shown in the figure below:
MindSpore代码实现
先简单介绍一下MindSpore——新一代AI开源计算框架.创新编程范式,AIScientists and engineers more易使用,便于开放式创新;This computational framework satisfies终端、边缘计算、云全场景需求,能更好保护数据隐私;可开源,形成广阔应用生态.
2020年3月28日,华为在开发者大会2020上宣布,全场景AI计算框架MindSpore在码云正式开源.MindSpore着重提升易用性并降低AI开发者的开发门槛,MindSpore原生适应每个场景包括端、边缘和云,并能够在按需协同的基础上,通过实现AI算法即代码,使开发态变得更加友好,显著减少模型开发时间,降低模型开发门槛.
通过MindSpore自身的技术创新及MindSpore与华为昇腾AI处理器的协同优化,实现了Efficient operation,大大提高了计算性能;MindSpore也支持GPU、CPU等其它处理器.
"""Dice"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class Dice(Metric):
def __init__(self, smooth=1e-5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self._dice_coeff_sum = 0
self._samples_num = 0
self.clear()
def clear(self):
# Yes to clear historical data
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
# 更新输入数据,y_pred和y,The data entry type can beTensor,lisy或numpy,维度必须相等
if len(inputs) != 2:
raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
# 将数据进行转换,统一转换为numpy
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
# Seek the intersection first,利用dotThe corresponding points are multiplied and added together
intersection = np.dot(y_pred.flatten(), y.flatten())
# 求并集,先将输入shapeAll pulled to one dimension,Then do point multiplication respectively,The two inputs are then added together
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
# 利用公式进行计算,加smooth是为了防止分母为0,避免当pred和true都为0时,分子被0除的问题,同时减少过拟合
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
# The coefficients for each batch are accumulated
self._dice_coeff_sum += single_dice_coeff
def eval(self):
# 进行计算
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
使用方法如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
dice = metric.eval()
print(dice)
0.20467791371802546
每个batch(两组数据)进行计算的时候如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
x1= Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y1 = Tensor(np.array([[1, 0], [1, 1], [1, 0]]))
metric.update(x1, y1)
avg_dice = metric.eval()
print(dice)
Dice Loss 介绍及实现
Dice Loss原理
Dice Loss 原理是在 Dice Calculated on the basis of coefficients,用1去减Dice系数
This is the case where there is only one image per batch in the binary classification,当一个批次有N张图片时,可以将图片压缩为一维向量,如下图:
对应的label也会相应变化,最后一起计算N张图片的Dice系数和Dice Loss.
MindSpore 二分类 DiceLoss 代码实现
class DiceLoss(_Loss):
def __init__(self, smooth=1e-5):
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 求交集,和dicecoefficients in the same way
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
# 求并集,和dicecoefficients in the same way
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
# 利用公式进行计算
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0]
return dice_loss.mean()
@constexpr
def _check_shape(logits_shape, label_shape):
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.DiceLoss(smooth=1e-5)
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7953220862819745]
MindSpore 多分类 MultiClassDiceLoss 代码实现
在MindSporeThere are various loss functions to choose from in semantic segmentation,However, the most commonly used loss function is to use cross entropy.
class MultiClassDiceLoss(_Loss):
def __init__(self, weights=None, ignore_indiex=None, activation=A.Softmax(axis=1)):
super(MultiClassDiceLoss, self).__init__()
# 利用Dice系数
self.binarydiceloss = DiceLoss(smooth=1e-5)
# 权重是一个Tensor,Should be the same dimension as the number of categories:Tensor of shape `[num_classes, dim]`.
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
# The ordinal number of the category to ignore
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
# 使用激活函数
self.activation = A.get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, Cell):
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
self.activation_flag = self.activation is not None
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 先定义一个loss,初始值为0
total_loss = 0
# 如果使用激活函数
if self.activation_flag:
logits = self.activation(logits)
# Iterates by the first number of the dimension of the label
for i in range(label.shape[1]):
if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None:
_check_weights(self.weights, label)
dice_loss *= self.weights[i]
total_loss += dice_loss
return total_loss/label.shape[1]
使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7761003]
Dice Loss 存在的问题
训练误差曲线非常混乱,很难看出关于收敛的信息.尽管可以检查在验证集上的误差来避开此问题.
边栏推荐
- Imitation SBUS fixed with serial data conversion
- 2022-08-01 回顾基础二叉树以及操作
- QSS 选择器
- 技术干货 | 基于 MindSpore 实现图像分割之豪斯多夫距离
- Bias lock/light lock/heavy lock lock is healthier. How is locking and unlocking accomplished?
- 还在找网盘资源吗?快点收藏如下几个值得收藏的网盘资源搜索神器吧!
- Qiu Jun, CEO of Eggplant Technology: Focus on users and make products that users really need
- 2022华数杯数学建模A题环形振荡器的优化设计思路思路代码分享
- 无题三
- 无题十二
猜你喜欢
随机推荐
Analysis and practice of antjian webshell dynamic encrypted connection
无题三
如何实现按键的短按、长按检测?
Dry goods!Generative Model Evaluation and Diagnosis
C语言的高级用法
Oracle temporary table space role
浅析WSGI协议
无题二
Egg framework usage (2)
Happens-before rules for threads
eKuiper Newsletter 2022-07|v1.6.0:Flow 编排 + 更好用的 SQL,轻松表达业务逻辑
歌词整理
【AGC】增长服务1-远程配置示例
Redis源码解析:Redis Cluster
PAT乙级-B1020 月饼(25)
Pytorch深度学习快速入门教程 -- 土堆教程笔记(三)
HStreamDB Newsletter 2022-07|分区模型优化、数据集成框架进一步完善
IO stream articles -- based on io stream to realize folder copy (copy subfolders and files in subfolders) full of dry goods
leetcode: 529. Minesweeper Game
皕杰报表的下拉框联动