当前位置:网站首页>Complete image segmentation efficiently based on MindSpore and realize Dice!
Complete image segmentation efficiently based on MindSpore and realize Dice!
2022-08-05 10:01:00 【Ascension MindSpore】
Dice Introduction and implementation of coefficients
DiceCoefficient principle
DiceIt is the most frequently used metric in medical image competitions,It is an ensemble similarity measure,通常用于计算两个样本的相似度,The value threshold is[0, 1].Often used for image segmentation in medical images,The best result of segmentation is 1,The worst time result is 0.
Dice系数计算公式如下:
当然DiceThere is also another expression,is used in the confusion matrixTP,FP,FN来表达:
The principle of this formula is shown in the figure below:
MindSpore代码实现
先简单介绍一下MindSpore——新一代AI开源计算框架.创新编程范式,AIScientists and engineers more易使用,便于开放式创新;This computational framework satisfies终端、边缘计算、云全场景需求,能更好保护数据隐私;可开源,形成广阔应用生态.
2020年3月28日,华为在开发者大会2020上宣布,全场景AI计算框架MindSpore在码云正式开源.MindSpore着重提升易用性并降低AI开发者的开发门槛,MindSpore原生适应每个场景包括端、边缘和云,并能够在按需协同的基础上,通过实现AI算法即代码,使开发态变得更加友好,显著减少模型开发时间,降低模型开发门槛.
通过MindSpore自身的技术创新及MindSpore与华为昇腾AI处理器的协同优化,实现了Efficient operation,大大提高了计算性能;MindSpore也支持GPU、CPU等其它处理器.
"""Dice"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class Dice(Metric):
def __init__(self, smooth=1e-5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self._dice_coeff_sum = 0
self._samples_num = 0
self.clear()
def clear(self):
# Yes to clear historical data
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
# 更新输入数据,y_pred和y,The data entry type can beTensor,lisy或numpy,维度必须相等
if len(inputs) != 2:
raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
# 将数据进行转换,统一转换为numpy
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
# Seek the intersection first,利用dotThe corresponding points are multiplied and added together
intersection = np.dot(y_pred.flatten(), y.flatten())
# 求并集,先将输入shapeAll pulled to one dimension,Then do point multiplication respectively,The two inputs are then added together
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
# 利用公式进行计算,加smooth是为了防止分母为0,避免当pred和true都为0时,分子被0除的问题,同时减少过拟合
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
# The coefficients for each batch are accumulated
self._dice_coeff_sum += single_dice_coeff
def eval(self):
# 进行计算
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
使用方法如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
dice = metric.eval()
print(dice)
0.20467791371802546
每个batch(两组数据)进行计算的时候如下:
import numpy as np
from mindspore import Tensor
from mindspore.nn.metrics Dice
metric = Dice(smooth=1e-5)
metric.clear()
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric.update(x, y)
x1= Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y1 = Tensor(np.array([[1, 0], [1, 1], [1, 0]]))
metric.update(x1, y1)
avg_dice = metric.eval()
print(dice)
Dice Loss 介绍及实现
Dice Loss原理
Dice Loss 原理是在 Dice Calculated on the basis of coefficients,用1去减Dice系数
This is the case where there is only one image per batch in the binary classification,当一个批次有N张图片时,可以将图片压缩为一维向量,如下图:
对应的label也会相应变化,最后一起计算N张图片的Dice系数和Dice Loss.
MindSpore 二分类 DiceLoss 代码实现
class DiceLoss(_Loss):
def __init__(self, smooth=1e-5):
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 求交集,和dicecoefficients in the same way
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
# 求并集,和dicecoefficients in the same way
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
# 利用公式进行计算
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0]
return dice_loss.mean()
@constexpr
def _check_shape(logits_shape, label_shape):
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.DiceLoss(smooth=1e-5)
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7953220862819745]
MindSpore 多分类 MultiClassDiceLoss 代码实现
在MindSporeThere are various loss functions to choose from in semantic segmentation,However, the most commonly used loss function is to use cross entropy.
class MultiClassDiceLoss(_Loss):
def __init__(self, weights=None, ignore_indiex=None, activation=A.Softmax(axis=1)):
super(MultiClassDiceLoss, self).__init__()
# 利用Dice系数
self.binarydiceloss = DiceLoss(smooth=1e-5)
# 权重是一个Tensor,Should be the same dimension as the number of categories:Tensor of shape `[num_classes, dim]`.
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
# The ordinal number of the category to ignore
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
# 使用激活函数
self.activation = A.get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, Cell):
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
self.activation_flag = self.activation is not None
self.reshape = P.Reshape()
def construct(self, logits, label):
# Dimension check,维度必须相等.(输入必须是tensor)
_check_shape(logits.shape, label.shape)
# 先定义一个loss,初始值为0
total_loss = 0
# 如果使用激活函数
if self.activation_flag:
logits = self.activation(logits)
# Iterates by the first number of the dimension of the label
for i in range(label.shape[1]):
if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None:
_check_weights(self.weights, label)
dice_loss *= self.weights[i]
total_loss += dice_loss
return total_loss/label.shape[1]
使用方法如下:
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
output = loss(y_pred, y)
print(output)
[0.7761003]
Dice Loss 存在的问题
训练误差曲线非常混乱,很难看出关于收敛的信息.尽管可以检查在验证集上的误差来避开此问题.
边栏推荐
- 科普大佬说 | 港大黄凯斌老师带你解锁黑客帝国与6G的关系
- Hundred lines of code launch red hearts, why programmers lose their girlfriends!
- [Unity] [UGUI] [Display text on the screen]
- 华为轻量级神经网络架构GhostNet再升级,GPU上大显身手的G-GhostNet(IJCV22)
- 19. Server-side session technology Session
- Pycharm 常用外部工具
- js 图形操作一(兼容pc、移动端实现 draggable属性 拖放效果)
- 阿里顶级架构师多年总结的JVM宝典,哪里不会查哪里!
- DFINITY 基金会创始人谈熊市沉浮,DeFi 项目该何去何从
- PHP operation mangoDb
猜你喜欢
CPU的亲缘性affinity
首次去中心化抢劫?近2亿美元损失:跨链桥Nomad 被攻击事件分析
2022.8.3
Pytorch Deep Learning Quick Start Tutorial -- Mound Tutorial Notes (3)
数据中台建设(十):数据安全管理
MySQL内部函数介绍
阿里顶级架构师多年总结的JVM宝典,哪里不会查哪里!
Qiu Jun, CEO of Eggplant Technology: Focus on users and make products that users really need
three物体围绕一周呈球形排列
告白数字化转型时代:麦聪软件以最简单的方式让企业把数据用起来
随机推荐
韦东山 数码相框 项目学习(六)tslib的移植
还在找网盘资源吗?快点收藏如下几个值得收藏的网盘资源搜索神器吧!
Jenkins使用手册(2) —— 软件配置
Seata source code analysis: initialization process of TM RM client
数据中台建设(十):数据安全管理
Pycharm 常用外部工具
首次去中心化抢劫?近2亿美元损失:跨链桥Nomad 被攻击事件分析
使用工具类把对象中的null值转换为空字符串(集合也可以使用)
Pytorch深度学习快速入门教程 -- 土堆教程笔记(三)
C语言的高级用法
正则表达式replaceAll()方法具有什么功能呢?
Four years of weight loss record
无题五
The technological achievements of Shanghai Konan were selected into the "2021 Shanghai Network Security Industry Innovation Research Achievement Catalog" by the Municipal Commission of Economy and Inf
Handwriting Currying - toString Comprehension
干货!生成模型的评价与诊断
IO stream articles -- based on io stream to realize folder copy (copy subfolders and files in subfolders) full of dry goods
微服务 技术栈
茄子科技CEO仇俊:以用户为中心,做用户真正需要的产品
2022华数杯数学建模A题环形振荡器的优化设计思路思路代码分享