当前位置:网站首页>Complete image segmentation efficiently based on MindSpore and realize Dice!
Complete image segmentation efficiently based on MindSpore and realize Dice!
2022-08-05 10:01:00 【Ascension MindSpore】
Dice Introduction and implementation of coefficients
DiceCoefficient principle
DiceIt is the most frequently used metric in medical image competitions,It is an ensemble similarity measure,通常用于计算两个样本的相似度,The value threshold is[0, 1].Often used for image segmentation in medical images,The best result of segmentation is 1,The worst time result is 0.
Dice系数计算公式如下:
当然DiceThere is also another expression,is used in the confusion matrixTP,FP,FN来表达:
The principle of this formula is shown in the figure below:
MindSpore代码实现
先简单介绍一下MindSpore——新一代AI开源计算框架.创新编程范式,AIScientists and engineers more易使用,便于开放式创新;This computational framework satisfies终端、边缘计算、云全场景需求,能更好保护数据隐私;可开源,形成广阔应用生态.
2020年3月28日,华为在开发者大会2020上宣布,全场景AI计算框架MindSpore在码云正式开源.MindSpore着重提升易用性并降低AI开发者的开发门槛,MindSpore原生适应每个场景包括端、边缘和云,并能够在按需协同的基础上,通过实现AI算法即代码,使开发态变得更加友好,显著减少模型开发时间,降低模型开发门槛.
通过MindSpore自身的技术创新及MindSpore与华为昇腾AI处理器的协同优化,实现了Efficient operation,大大提高了计算性能;MindSpore也支持GPU、CPU等其它处理器.
"""Dice"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class Dice(Metric):
def __init__(self, smooth=1e-5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self._dice_coeff_sum = 0
self._samples_num = 0
self.clear()
def clear(self):
# Yes to clear historical data
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
# 更新输入数据,y_pred和y,The data entry type can beTensor,lisy或numpy,维度必须相等
if len(inputs) != 2:
raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
# 将数据进行转换,统一转换为numpy
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
# Seek the intersection first,利用dotThe corresponding points are multiplied and added together
intersection = np.dot(y_pred.flatten(), y.flatten())
# 求并集,先将输入shapeAll pulled to one dimension,Then do point multiplication respectively,The two inputs are then added together
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
# 利用公式进行计算,加smooth是为了防止分母为0,避免当pred和true都为0时,分子被0除的问题,同时减少过拟合
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
# The coefficients for each batch are accumulated
self._dice_coeff_sum += single_dice_coeff
def eval(self):
# 进行计算
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
使用方法如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
dice = metric.eval()
print(dice)
0.20467791371802546
每个batch(两组数据)进行计算的时候如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
x1= Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y1 = Tensor(np.array([[1, 0], [1, 1], [1, 0]]))
metric.update(x1, y1)
avg_dice = metric.eval()
print(dice)
Dice Loss 介绍及实现
Dice Loss原理
Dice Loss 原理是在 Dice Calculated on the basis of coefficients,用1去减Dice系数
This is the case where there is only one image per batch in the binary classification,当一个批次有N张图片时,可以将图片压缩为一维向量,如下图:
对应的label也会相应变化,最后一起计算N张图片的Dice系数和Dice Loss.
MindSpore 二分类 DiceLoss 代码实现
class DiceLoss(_Loss):
def __init__(self, smooth=1e-5):
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 求交集,和dicecoefficients in the same way
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
# 求并集,和dicecoefficients in the same way
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
# 利用公式进行计算
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0]
return dice_loss.mean()
@constexpr
def _check_shape(logits_shape, label_shape):
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.DiceLoss(smooth=1e-5)
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7953220862819745]
MindSpore 多分类 MultiClassDiceLoss 代码实现
在MindSporeThere are various loss functions to choose from in semantic segmentation,However, the most commonly used loss function is to use cross entropy.
class MultiClassDiceLoss(_Loss):
def __init__(self, weights=None, ignore_indiex=None, activation=A.Softmax(axis=1)):
super(MultiClassDiceLoss, self).__init__()
# 利用Dice系数
self.binarydiceloss = DiceLoss(smooth=1e-5)
# 权重是一个Tensor,Should be the same dimension as the number of categories:Tensor of shape `[num_classes, dim]`.
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
# The ordinal number of the category to ignore
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
# 使用激活函数
self.activation = A.get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, Cell):
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
self.activation_flag = self.activation is not None
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 先定义一个loss,初始值为0
total_loss = 0
# 如果使用激活函数
if self.activation_flag:
logits = self.activation(logits)
# Iterates by the first number of the dimension of the label
for i in range(label.shape[1]):
if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None:
_check_weights(self.weights, label)
dice_loss *= self.weights[i]
total_loss += dice_loss
return total_loss/label.shape[1]
使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7761003]
Dice Loss 存在的问题
训练误差曲线非常混乱,很难看出关于收敛的信息.尽管可以检查在验证集上的误差来避开此问题.
边栏推荐
- 2022.8.3
- Custom filters and interceptors implement ThreadLocal thread closure
- Happens-before rules for threads
- IO stream articles -- based on io stream to realize folder copy (copy subfolders and files in subfolders) full of dry goods
- 百年北欧奢华家电品牌ASKO智能三温区酒柜臻献七夕,共品珍馐爱意
- First Decentralized Heist?Loss of nearly 200 million US dollars: analysis of the attack on the cross-chain bridge Nomad
- CCVR eases heterogeneous federated learning based on classifier calibration
- 2022-08-01 回顾基础二叉树以及操作
- Egg framework usage (2)
- 手把手教你纯c实现异常捕获try-catch组件
猜你喜欢
CCVR eases heterogeneous federated learning based on classifier calibration
leetcode: 529. 扫雷游戏
2022.8.3
MySQL内部函数介绍
上海控安技术成果入选市经信委《2021年上海市网络安全产业创新攻关成果目录》
js 图形操作一(兼容pc、移动端实现 draggable属性 拖放效果)
How to realize the short press and long press detection of the button?
科普大佬说 | 港大黄凯斌老师带你解锁黑客帝国与6G的关系
2022华数杯数学建模A题环形振荡器的优化设计思路思路代码分享
电竞、便捷、高效、安全,盘点OriginOS功能的关键词
随机推荐
无题七
How ali cloud storage database automatically to speed up the loading speed of www.cxsdkt.cn how to set up the case?
egg框架使用(一)
Wei Dongshan Digital Photo Frame Project Learning (6) Transplantation of tslib
21 Days of Deep Learning - Convolutional Neural Networks (CNN): Weather Recognition (Day 5)
5.部署web项目到云服务器
无题三
ffmpeg drawtext add text watermark
Oracle临时表空间作用
Marketing Suggestions | You have an August marketing calendar to check! Suggest a collection!
【温度预警程序de开发】事件驱动模型实例运用
欧盟 | 地平线 2020 ENSEMBLE:D2.13 SOTIF Safety Concept(上)
Egg framework usage (1)
无题十
使用工具类把对象中的null值转换为空字符串(集合也可以使用)
华为轻量级神经网络架构GhostNet再升级,GPU上大显身手的G-GhostNet(IJCV22)
2022华数杯数学建模思路分析交流
express hot-reload
hcip BGP 增强实验
无题四