U-Net: Convolutional Networks for Biomedical Image Segmentation
2022-07-05 18:23:00 【00000cj】
paper: U-Net: Convolutional Networks for Biomedical Image Segmentation
Innovation points
- Proposes a U-shaped encoder-decoder network structure that uses skip connections to better fuse shallow localization information with deep semantic information. U-Net borrows the fully convolutional structure of FCN; an important change compared with FCN is that the upsampling path also has a large number of feature channels, which lets the network propagate context information to higher-resolution layers.
- Medical image segmentation tasks have very little training data, so the authors perform extensive data augmentation, in particular elastic deformation.
- Proposes a weighted loss, using a pre-computed per-pixel weight map that emphasizes the separation borders between touching cells.
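The weighted loss amounts to applying a per-pixel weight map to an unreduced cross-entropy. A minimal PyTorch sketch (the weight map here is a uniform placeholder, not the paper's morphology-based computation):

```python
import torch
import torch.nn.functional as F

def weighted_pixel_ce(logits, target, weight_map):
    """Cross-entropy weighted by a per-pixel weight map.

    logits:     (N, C, H, W) raw class scores
    target:     (N, H, W) integer class labels
    weight_map: (N, H, W) per-pixel weights (larger near cell borders in the paper)
    """
    per_pixel = F.cross_entropy(logits, target, reduction='none')  # (N, H, W)
    return (per_pixel * weight_map).mean()

logits = torch.randn(4, 2, 32, 32)
target = torch.randint(0, 2, (4, 32, 32))
weight_map = torch.ones(4, 32, 32)  # uniform weights reduce to plain mean CE

loss = weighted_pixel_ce(logits, target, weight_map)
```

With a uniform weight map this reduces exactly to the ordinary mean cross-entropy; the paper's contribution is how the weight map is computed, not the weighting mechanism itself.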
Some implementation details that need attention
- The original paper does not use padding, so the resolution of the feature maps gradually shrinks. The mmsegmentation implementation described below does use padding, so with stride=1 the output feature map keeps the same resolution.
- In FCN, skip connections fuse shallow and deep information by element-wise addition, whereas U-Net fuses them by channel-wise concatenation.
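The difference between the two fusion styles can be shown directly (a minimal sketch with random tensors standing in for real features):

```python
import torch

shallow = torch.randn(4, 64, 60, 60)  # shallow (encoder) feature
deep = torch.randn(4, 64, 60, 60)     # deep feature upsampled to the same resolution

fused_add = shallow + deep                     # FCN-style: element-wise add, channels unchanged
fused_cat = torch.cat([shallow, deep], dim=1)  # U-Net-style: channel concat, channels doubled

print(tuple(fused_add.shape))  # (4, 64, 60, 60)
print(tuple(fused_cat.shape))  # (4, 128, 60, 60)
```

Concatenation preserves both inputs and lets the following convolution learn how to mix them, at the cost of doubling the channel count.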
Implementation details analysis
Take the unet in MMSegmentation as an example, and assume batch_size=4, so the input shape is (4, 3, 480, 480).
Backbone
- The encode path has 5 stages in total, and each stage contains one ConvBlock consisting of 2 Conv-BN-ReLU modules. Except for the 1st stage, the remaining 4 stages each have a 2x2-s2 maxpool before the ConvBlock. The 1st conv of each stage doubles the number of output channels. Therefore the output shapes of the encode stages are (4, 64, 480, 480), (4, 128, 240, 240), (4, 256, 120, 120), (4, 512, 60, 60), (4, 1024, 30, 30).
- The decode path has 4 stages, corresponding to the 4 downsampled encode stages. Each stage consists of three steps: upsample, concatenate, and conv. The upsample step is a bilinear interpolation with scale_factor=2 followed by one Conv-BN-ReLU, where the conv is a 1x1-s1 convolution that halves the number of channels. The second step concatenates the upsample output with the encoder output of the same resolution along the channel dimension. The third step is a ConvBlock, the same as in the encoder, also consisting of two Conv-BN-ReLU modules; because the upsample step halves the channel count but the concatenation with the corresponding encoder output restores it, the first conv of this ConvBlock again halves the number of output channels. The backbone outputs are therefore (4, 1024, 30, 30), (4, 512, 60, 60), (4, 256, 120, 120), (4, 128, 240, 240), (4, 64, 480, 480). Note that the decoder has only 4 stages, so only the last 4 of these are actual decoder outputs; the first is the output of the last encode stage.
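The shape flow above can be verified with a minimal pure-PyTorch sketch of the same structure (a simplification of mmsegmentation's BasicConvBlock/UpConvBlock, not its actual code, run on a small input for speed):

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    """ConvBlock: two Conv-BN-ReLU modules; the first conv sets the channel count."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
    )

class UpStage(nn.Module):
    """One decode stage: bilinear upsample + 1x1 conv halving channels,
    concat with the encoder skip, then a ConvBlock halving channels again."""
    def __init__(self, in_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_ch, in_ch // 2, 1, bias=False),
            nn.BatchNorm2d(in_ch // 2), nn.ReLU(inplace=True),
        )
        self.conv = conv_block(in_ch, in_ch // 2)  # concat restores in_ch

    def forward(self, x, skip):
        x = self.up(x)                    # channels halved, resolution doubled
        x = torch.cat([skip, x], dim=1)   # channels back to in_ch
        return self.conv(x)

# Encoder: 5 stages; a 2x2-s2 maxpool before the ConvBlock in all but the first.
channels = [64, 128, 256, 512, 1024]
enc = nn.ModuleList()
in_ch = 3
for i, out_ch in enumerate(channels):
    layers = ([] if i == 0 else [nn.MaxPool2d(2)]) + [conv_block(in_ch, out_ch)]
    enc.append(nn.Sequential(*layers))
    in_ch = out_ch

# 96x96 instead of 480x480 for speed; the channel progression is identical.
x = torch.randn(1, 3, 96, 96)
enc_outs = []
with torch.no_grad():
    for stage in enc:
        x = stage(x)
        enc_outs.append(x)

    # Decoder: 4 stages, deepest first, each fused with the matching encoder output.
    dec_outs = [enc_outs[-1]]
    for skip, c in zip(enc_outs[-2::-1], channels[:0:-1]):
        dec_outs.append(UpStage(c)(dec_outs[-1], skip))

enc_shapes = [tuple(o.shape) for o in enc_outs]
dec_shapes = [tuple(o.shape) for o in dec_outs]
```

Printing `enc_shapes` and `dec_shapes` reproduces the article's progression (scaled down spatially): the channel counts double down the encoder and halve back up the decoder, and the first "decoder" output is just the last encoder output.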
FCN Head
- The output of the backbone's last decode stage, (4, 64, 480, 480), is the input of the head. It first passes through a 3x3-s1 conv-bn-relu that keeps the number of channels unchanged, then through a dropout with ratio=0.1, and finally a 1x1 conv produces the final model output, whose number of output channels equals the number of classes (including background).
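The head's three steps (3x3 conv-bn-relu keeping 64 channels, dropout, 1x1 conv to the class count) can be sketched as follows (a simplification of mmsegmentation's FCNHead, with a small spatial size for brevity):

```python
import torch
import torch.nn as nn

num_classes = 2
head = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1, bias=False),  # 3x3-s1, channels unchanged
    nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.Dropout2d(p=0.1),
    nn.Conv2d(64, num_classes, 1),                # 1x1 conv to per-class scores
)

feat = torch.randn(4, 64, 48, 48)  # stand-in for the (4, 64, 480, 480) backbone output
with torch.no_grad():
    seg_logits = head(feat)
print(tuple(seg_logits.shape))  # (4, 2, 48, 48)
```

Only the final 1x1 conv depends on the number of classes, which is why the same head definition works for any dataset.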
Loss
- The loss is the cross-entropy loss.
Auxiliary Head
- The output of the backbone's second-to-last decode stage, (4, 128, 240, 240), is the input of the auxiliary head. It passes through a 3x3-s1 conv-bn-relu that halves the number of channels to 64, then a dropout with ratio=0.1, and finally a 1x1 conv produces the final output of this branch, whose number of output channels equals the number of classes (including background).
- The auxiliary branch also uses the cross-entropy loss. Note that the final output resolution of this branch is half that of the raw gt, so it must first be upsampled with bilinear interpolation before computing the loss.
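The auxiliary loss computation (upsample the half-resolution logits to the ground-truth size, then ordinary cross-entropy) can be sketched as:

```python
import torch
import torch.nn.functional as F

num_classes = 2
aux_logits = torch.randn(4, num_classes, 240, 240)  # auxiliary head output (half resolution)
gt = torch.randint(0, num_classes, (4, 480, 480))   # ground truth at full resolution

# Upsample the logits to the gt resolution with bilinear interpolation,
# then compute the usual cross-entropy loss.
logits_up = F.interpolate(aux_logits, size=gt.shape[-2:],
                          mode='bilinear', align_corners=False)
aux_loss = F.cross_entropy(logits_up, gt)
```

Interpolating logits (rather than downsampling the gt) keeps the label map exact; mmsegmentation applies a configurable `loss_weight` to this term, not shown here.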
The complete structure of the model
EncoderDecoder(
(backbone): UNet(
(encoder): ModuleList(
(0): Sequential(
(0): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(1): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(2): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(3): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(4): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
)
(decoder): ModuleList(
(0): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(1): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(2): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(3): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
)
)
init_cfg=[{'type': 'Kaiming', 'layer': 'Conv2d'}, {'type': 'Constant', 'val': 1, 'layer': ['_BatchNorm', 'GroupNorm']}]
(decode_head): FCNHead(
input_transform=None, ignore_index=255, align_corners=False
(loss_decode): CrossEntropyLoss(avg_non_ignore=False)
(conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
(dropout): Dropout2d(p=0.1, inplace=False)
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
(auxiliary_head): FCNHead(
input_transform=None, ignore_index=255, align_corners=False
(loss_decode): CrossEntropyLoss(avg_non_ignore=False)
(conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
(dropout): Dropout2d(p=0.1, inplace=False)
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
)