U-Net: Convolutional Networks for Biomedical Image Segmentation
2022-07-05 17:51:00 [00000cj]
paper: U-Net: Convolutional Networks for Biomedical Image Segmentation
Key ideas
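- Proposes a U-shaped encoder-decoder architecture in which skip connections fuse the localization information of shallow layers with the semantic information of deep layers. Like FCN, U-Net is fully convolutional; an important difference from FCN is that the upsampling path also has a large number of feature channels, which allows the network to propagate context information to higher-resolution layers.
- Training data for medical image segmentation is very scarce, so the authors apply elastic deformations for extensive data augmentation.
- Proposes a weighted loss: a per-pixel weight map that compensates for differing class frequencies and forces the network to learn the thin separation borders between touching cells.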
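The paper defines the weight map as w(x) = w_c(x) + w_0 · exp(−(d_1(x) + d_2(x))² / (2σ²)). Here w_c balances class frequencies, and d_1 and d_2 are the distances to the borders of the nearest and second-nearest cell. The paper uses w_0 = 10 and σ ≈ 5 pixels.

The minimal pure-Python sketch below evaluates this formula for individual pixels and feeds the result into a weighted cross-entropy. The function names are hypothetical. Computing w_c, d_1 and d_2 (for example with a distance transform over instance masks) is out of scope.

```python
import math


def unet_pixel_weight(w_c: float, d1: float, d2: float,
                      w0: float = 10.0, sigma: float = 5.0) -> float:
    """Weight w(x) for one pixel (hypothetical helper).

    w_c: class-frequency balancing weight at this pixel
    d1, d2: distances (pixels) to the border of the nearest and
            second-nearest cell
    """
    # Border term: large when the pixel sits in a narrow gap between two cells
    border_term = w0 * math.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))
    return w_c + border_term


def weighted_log_loss(p_true, weights):
    """Weighted cross-entropy: sum over pixels of -w(x) * log p_l(x)(x).

    p_true: softmax probability assigned to the ground-truth class per pixel
    weights: w(x) for the same pixels
    """
    return -sum(w * math.log(p) for p, w in zip(p_true, weights))
```

Consider a background pixel squeezed between two touching cells (d_1 = d_2 = 1). It gets a weight of roughly w_c + 9.2. A pixel 30 pixels away from both borders keeps essentially w_c.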

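Implementation details to note
- The original paper uses unpadded convolutions, so the feature map resolution shrinks after every 3x3 conv, and the encoder feature maps must be cropped before concatenation. The MMSegmentation implementation described below uses padding, so with stride=1 each conv keeps the feature map resolution unchanged.
- In FCN, skip connections fuse shallow and deep features by element-wise addition; in U-Net, they concatenate them along the channel dimension.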
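Unpadded convolutions shrink the feature maps, which is why crops are needed before concatenation. A little arithmetic makes this concrete.

The sketch below assumes the paper's layout:
- four downsampling steps;
- two unpadded 3x3 convs per block;
- 2x2 stride-2 max-pooling;
- a 2x2 up-convolution that doubles the size.

`unet_valid_trace` is a hypothetical helper name.

```python
def unet_valid_trace(size: int, depth: int = 4):
    """Trace the spatial size through an unpadded U-Net (paper setting).

    Returns (skips, crops, output_size):
      skips: encoder map sizes fed to the skip connections (shallow first)
      crops: pixels to crop from EACH side of those maps, deepest skip first
    """
    skips = []
    for _ in range(depth):
        size -= 4                          # two unpadded 3x3 convs: -2 px each
        if size % 2:
            raise ValueError(f"odd size {size} before 2x2 max-pooling")
        skips.append(size)                 # map copied to the decoder
        size //= 2                         # 2x2 max-pool, stride 2
    size -= 4                              # bottleneck ConvBlock
    crops = []
    for skip in reversed(skips):
        size *= 2                          # 2x2 up-convolution doubles the size
        crops.append((skip - size) // 2)   # center-crop margin per side
        size -= 4                          # ConvBlock after concatenation
    return skips, crops, size


skips, crops, out = unet_valid_trace(572)
print(skips, crops, out)
```

For the paper's 572x572 input tile, this gives a 388x388 output. The per-side crops for the four skip connections, deepest first, are 4, 16, 40 and 88 pixels.

With padding=1, as in MMSegmentation, every conv preserves the size. The skip and decoder maps then match exactly and no cropping is needed.

At the concatenation itself, channels add up: 512 + 512 = 1024 at the deepest skip. FCN-style addition would keep 512.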
Implementation walkthrough
Using the UNet implementation in MMSegmentation as an example, assume batch_size=4 and an input of shape (4, 3, 480, 480).
Backbone
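- The encoder has 5 stages, each containing one ConvBlock made of two Conv-BN-ReLU layers. Each of the last 4 stages has a 2x2 max-pool with stride 2 before its ConvBlock; the first stage has none. The first conv of each stage doubles the number of channels (in the first stage, it maps the 3 input channels to 64). The encoder stage outputs therefore have shapes (4, 64, 480, 480), (4, 128, 240, 240), (4, 256, 120, 120), (4, 512, 60, 60) and (4, 1024, 30, 30).
- The decoder has 4 stages, matching the 4 downsampling stages of the encoder. Each stage performs three steps: upsample, concatenate, conv.
  - Upsample: a bilinear interpolation with scale_factor=2, followed by a Conv-BN-ReLU whose conv is a 1x1, stride-1 convolution that halves the channel count.
  - Concatenate: the upsampled output is joined, along the channel dimension, with the encoder output of the same resolution.
  - Conv: a ConvBlock which, as in the encoder, consists of two Conv-BN-ReLU layers. Upsampling halved the channel count and concatenation restored it, so the first conv of this ConvBlock halves it again.

  The decoder outputs therefore have shapes (4, 1024, 30, 30), (4, 512, 60, 60), (4, 256, 120, 120), (4, 128, 240, 240) and (4, 64, 480, 480). Because the decoder has only 4 stages, only the last 4 of these are actually produced by the decoder; the first is simply the output of the last encoder stage.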
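The shape bookkeeping in the two bullets above can be checked with a short pure-Python trace. It assumes exactly the configuration described:
- padded 3x3 convs;
- 2x2 stride-2 max-pooling;
- bilinear x2 upsampling followed by a channel-halving 1x1 conv;
- concatenation along channels.

`unet_backbone_shapes` is a hypothetical helper. `base_channels` and `num_stages` are meant to echo argument names of MMSegmentation's UNet, with defaults chosen to reproduce the shapes above. The divisibility check is a choice of this sketch, so that every upsampled map lines up with its skip.

```python
def unet_backbone_shapes(n, h, w, base_channels=64, num_stages=5):
    """Return (encoder_shapes, decoder_shapes) as lists of (N, C, H, W)."""
    factor = 2 ** (num_stages - 1)
    if h % factor or w % factor:
        raise ValueError(f"H and W must be divisible by {factor}")

    # Encoder: stage 0 has no pooling; later stages max-pool first, then the
    # first conv doubles the channels. Padded 3x3 convs keep H and W.
    enc = []
    c = base_channels
    for stage in range(num_stages):
        if stage > 0:
            h, w, c = h // 2, w // 2, c * 2
        enc.append((n, c, h, w))

    # Decoder: the list starts with the deepest encoder output.
    dec = [enc[-1]]
    for skip in reversed(enc[:-1]):
        h, w, c = h * 2, w * 2, c // 2   # bilinear x2, then 1x1 conv halves C
        if (h, w) != skip[2:]:
            raise ValueError("skip resolution mismatch")
        c += skip[1]                     # concatenate with the encoder output
        c //= 2                          # first conv of the ConvBlock halves C
        dec.append((n, c, h, w))
    return enc, dec


enc, dec = unet_backbone_shapes(4, 480, 480)
print(enc)
print(dec)
```

Suppose FCN-style addition replaced the concatenation. The channel count after merging would then stay at the upsampled width (512 at 60x60) instead of doubling to 1024, so the first conv of each decoder ConvBlock would not need to halve it.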
FCN Head
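- The output of the last decoder stage, (4, 64, 480, 480), is the input to the head. It first passes through a 3x3, stride-1 Conv-BN-ReLU that keeps the channel count, then through Dropout2d with p=0.1. Finally, a 1x1 conv produces the model's final output, whose channel count equals the number of classes (including background).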
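The head maps (4, 64, 480, 480) to (4, num_classes, 480, 480); the printout below has 2 classes. To turn these logits into a segmentation map, the class with the highest score is taken at every pixel.

This is a minimal pure-Python sketch for one image, using nested lists instead of tensors. `logits_to_label_map` is a hypothetical name. With tensors, this would be a single argmax over the channel dimension.

```python
def logits_to_label_map(logits):
    """Per-pixel argmax over the class channels.

    logits: nested lists indexed [class][row][col] for one image,
            i.e. one (num_classes, H, W) slice of the head output.
    Returns an H x W nested list of predicted class ids (ties -> lowest id).
    """
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    return [
        [max(range(num_classes), key=lambda k: logits[k][y][x])
         for x in range(width)]
        for y in range(height)
    ]


# 2 classes (background, foreground), a 1x2 image
print(logits_to_label_map([[[2.0, 0.1]], [[1.0, 0.5]]]))
```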
Loss
- The loss is pixel-wise cross-entropy loss.
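The printout below shows ignore_index=255 and CrossEntropyLoss(avg_non_ignore=False).

As understood here (to be checked against the installed MMSegmentation version):
- with avg_non_ignore=False, the mean is taken over all pixels, and ignored pixels contribute zero loss but still count in the denominator;
- with avg_non_ignore=True, the mean is taken over non-ignored pixels only.

The pure-Python sketch below models both behaviors on flattened pixels. `pixel_cross_entropy` is a hypothetical name, and class weights are left out.

```python
import math


def pixel_cross_entropy(logits, labels, ignore_index=255, avg_non_ignore=False):
    """Mean softmax cross-entropy over pixels.

    logits: list of per-pixel score lists (length = num_classes)
    labels: list of per-pixel ground-truth class ids
    """
    total, valid = 0.0, 0
    for scores, y in zip(logits, labels):
        if y == ignore_index:
            continue                       # ignored pixels contribute no loss
        m = max(scores)                    # subtract max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[y]   # -log softmax(scores)[y]
        valid += 1
    denom = valid if avg_non_ignore else len(labels)
    return total / denom if denom else 0.0
```

Take two pixels with uniform scores over 2 classes, one of them labeled 255. The loss is log 2 when averaging over non-ignored pixels, and log 2 / 2 when averaging over all pixels.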
Auxiliary Head
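- The output of the second-to-last decoder stage, (4, 128, 240, 240), is the input to the auxiliary head. It passes through a 3x3, stride-1 Conv-BN-ReLU that halves the channel count to 64, then through Dropout2d with p=0.1. Finally, a 1x1 conv produces the auxiliary output, whose channel count equals the number of classes (including background).
- The auxiliary branch also uses cross-entropy loss. Its output has only half the resolution of the ground truth, so it is first upsampled with bilinear interpolation before the loss is computed.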
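The auxiliary logits are 240x240 while the ground truth is 480x480. The sketch below resizes one channel with bilinear interpolation using half-pixel centers. This is the align_corners=False convention shown in the printout, which PyTorch's `F.interpolate(..., mode='bilinear', align_corners=False)` also follows.

`resize_bilinear` is a hypothetical helper on nested lists. It is meant to show the coordinate mapping, not to replace the library call.

```python
import math


def resize_bilinear(grid, out_h, out_w):
    """Bilinearly resize a 2-D nested list (align_corners=False convention)."""
    in_h, in_w = len(grid), len(grid[0])

    def source_index(dst, in_size, out_size):
        # Map the center of output pixel `dst` into input coordinates;
        # negative coordinates are clamped to 0, as PyTorch does for bilinear.
        src = max((dst + 0.5) * in_size / out_size - 0.5, 0.0)
        i0 = min(int(math.floor(src)), in_size - 1)
        i1 = min(i0 + 1, in_size - 1)       # clamp the right/bottom neighbor
        return i0, i1, src - i0             # indices and interpolation weight

    out = []
    for y in range(out_h):
        y0, y1, ly = source_index(y, in_h, out_h)
        row = []
        for x in range(out_w):
            x0, x1, lx = source_index(x, in_w, out_w)
            top = grid[y0][x0] * (1 - lx) + grid[y0][x1] * lx
            bottom = grid[y1][x0] * (1 - lx) + grid[y1][x1] * lx
            row.append(top * (1 - ly) + bottom * ly)
        out.append(row)
    return out


# 2x upsampling of a 2x2 map, as for the 240 -> 480 auxiliary output
up = resize_bilinear([[0.0, 1.0], [2.0, 3.0]], 4, 4)
```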
Full model structure (as printed by MMSegmentation; here the number of classes is 2)
EncoderDecoder(
  (backbone): UNet(
    (encoder): ModuleList(
      (0): Sequential(
        (0): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (4): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
    )
    (decoder): ModuleList(
      (0): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (1): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (2): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (3): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  init_cfg=[{'type': 'Kaiming', 'layer': 'Conv2d'}, {'type': 'Constant', 'val': 1, 'layer': ['_BatchNorm', 'GroupNorm']}]
  (decode_head): FCNHead(
    input_transform=None, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
    (conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    (dropout): Dropout2d(p=0.1, inplace=False)
    (convs): Sequential(
      (0): ConvModule(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
  (auxiliary_head): FCNHead(
    input_transform=None, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
    (conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    (dropout): Dropout2d(p=0.1, inplace=False)
    (convs): Sequential(
      (0): ConvModule(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
)
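The printout can also be read quantitatively:
- A Conv2d with a k x k kernel has k·k·C_in·C_out weights, plus C_out biases if bias is enabled. PyTorch's repr prints bias=False only when the bias is disabled.
- An affine BatchNorm adds 2·C_out learnable parameters. Its running mean and variance are buffers, not parameters.
- Max-pooling has no parameters.

The sketch below applies these rules to the encoder entries listed above. The helper names are hypothetical.

```python
def conv_module_params(c_in, c_out, k, bias=False, bn=True):
    """Learnable parameters of a Conv2d plus an optional affine BatchNorm."""
    params = k * k * c_in * c_out + (c_out if bias else 0)
    if bn:
        params += 2 * c_out    # BN weight and bias; running stats are buffers
    return params


def encoder_params(in_channels=3, channels=(64, 128, 256, 512, 1024)):
    """Parameters of the five encoder BasicConvBlocks (max-pools have none)."""
    total, c_in = 0, in_channels
    for c_out in channels:
        total += conv_module_params(c_in, c_out, 3)   # first conv changes width
        total += conv_module_params(c_out, c_out, 3)  # second conv keeps it
        c_in = c_out
    return total


print(encoder_params())
```

For example, the decode head's conv_seg, a 1x1 Conv2d(64, 2) with bias and no BN, has 64·2 + 2 = 130 parameters. Under the same rule, the five encoder stages together come to 18,847,168 parameters.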