Paper Reading: Duplex Contextual Relation Network for Polyp Segmentation
2022-06-28 19:09:00 【sigmoidAndRELU】
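Paper reading: colonoscopy image segmentation
Paper overview
Paper: Duplex Contextual Relation Network for Polyp Segmentation (ISBI 2022)
Affiliation: Beijing University of Posts and Telecommunications
Authors: Zijin Yin (尹子衿) et al.
Code: https://github.com/PRIS-CV/DCRNet/blob/master/lib/DCRNet.py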
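Abstract
Automatic polyp segmentation in colonoscopy plays a key role in the early diagnosis of colorectal cancer (CRC). However, the diversity of polyp images makes accurate segmentation much harder. Existing studies mainly focus on learning contextual information within a single image and fail to exploit the visual patterns that polyps share across images. This paper explores contextual correlation from the perspective of the whole dataset and proposes a Duplex Contextual Relation Network (DCRNet) that captures both within-image and cross-image contextual relations. Based on these two kinds of similarity, the features of each input region are enhanced by embedding its contextual regions. To store the region features embedded from previous images during training, an episodic memory is designed and operated as a queue. The method is evaluated on EndoScene, Kvasir-SEG and the recently released large-scale PICCOLO dataset. The experimental results show that DCRNet outperforms state-of-the-art methods on widely used evaluation metrics.
Contributions:
1. Enhancing each region's features by embedding contextual regions;
2. An episodic memory that is operated as a queue;
3. The DCRNet architecture;
4. Good results on several colonoscopy polyp datasets.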
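Introduction
In the diagnosis and treatment of colorectal cancer, analysing the polyp region is a key step, and removing polyps is a direct way to prevent and treat early-stage colorectal cancer. Colonoscopy images show the patient's colon clearly, but locating and segmenting polyps remains difficult for two reasons: 1. polyps vary widely in appearance; 2. the boundary between a polyp and the surrounding colonic mucosa is often blurry. As shown in the figure:
In the figure, some polyps are fairly obvious: in (a) and (b) the raised area is the polyp. The one in (d) is an extreme case, and the one in (c) is so subtle that it is easy to miss without looking closely.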
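Related work
A brief summary of existing work:
1. Multi-scale feature extraction networks: ACSNet (MICCAI 2020) combines contextual information with local detail to cope with the diversity of polyp features.
PraNet uses multi-scale feature aggregation: it derives a contour map from local features and progressively refines the segmentation map through upsampling.
2. Using auxiliary information to constrain the segmentation result: SFANet (MICCAI 2019) uses area-boundary constraints to guide selective feature aggregation and improve segmentation accuracy.
Key point: these works, hmm, all seem to look for features within a single image. Doesn't that implicitly involve a notion of lesion similarity, according to which the matching segmentation parameters are selected? If so, what a single model does is, on top of segmenting obvious lesions, implicitly classify the different types of polyp images: easy images are segmented the easy way, while complex or subtle images get special treatment. That makes a lot of sense!
So the paper introduces a mechanism called episodic memory!
Supporting evidence: "Content-based medical image retrieval of CT images of liver lesions using manifold learning" has already shown the value of retrieving from other images when treating radiological lesions.
Related results: the idea has already been used in metric learning.
So this paper adopts the idea and explores cross-image and within-image feature correlation from the perspective of the whole dataset.
Summary of the work:
1. Intra-image contextual relation module
2. Inter-image contextual relation module
Both modules are plug-and-play.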
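Model architecture
Figure first:
![Network architecture diagram](/img/71/982100d3f0abf1e1ab6c0615c5cf1f.png)
Start with the network diagram. It has three parts: an encoder, a decoder, and the information-processing modules at the bottleneck.
The encoder and decoder form a ResNet34-based UNet, which I won't go over here. On to the main event!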
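To make the resolutions in the code that follows easier to track, here is a small dependency-free sketch that traces feature-map sizes through a standard ResNet34 encoder. The 256×256 input size and the helper name `trace_encoder_shapes` are assumptions made for illustration; the strides and channel widths are those of a standard ResNet34, which matches the `scale_factor=32` upsampling applied to the auxiliary output in the post's code.

```python
def trace_encoder_shapes(height, width):
    """Trace (channels, H, W) through a standard ResNet34 encoder.

    Assumes the input height and width are divisible by 32.
    """
    # Stem: 7x7 conv with stride 2, followed by a max-pool with stride 2.
    h, w = height // 2, width // 2          # after firstconv / firstrelu
    shapes = {"stem": (64, h, w)}
    h, w = h // 2, w // 2                   # after firstmaxpool
    # (stage name, output channels, stride) for the four residual stages.
    stages = [("e1", 64, 1), ("e2", 128, 2), ("e3", 256, 2), ("e4", 512, 2)]
    for name, channels, stride in stages:
        h, w = h // stride, w // stride
        shapes[name] = (channels, h, w)
    return shapes


shapes = trace_encoder_shapes(256, 256)
for name, shape in shapes.items():
    print(name, shape)

# Number of spatial positions that position attention relates to each other.
_, h4, w4 = shapes["e4"]
positions = h4 * w4
print("attention map size:", (positions, positions))
```

Under these assumptions the deepest feature map `e4` has 512 channels at 1/32 of the input resolution, so the attention map computed on it stays small.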
Intra-image contextual relation
import torch
from torch.nn import Module, Conv2d, Parameter, Softmax


class PAM_Module(Module):
    """Position attention module."""
    # Ref from SAGAN

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim
        # 1x1 convolutions: query and key are reduced to C/8 channels, value keeps C
        self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable weight of the attention branch, initialised to 0
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x: input feature maps (B x C x H x W)
        returns:
            out: attention value + input feature
            attention: B x (HxW) x (HxW)
        """
        m_batchsize, C, height, width = x.size()
        # B x HW x C/8
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        # B x C/8 x HW
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        # pairwise similarity between all positions: B x HW x HW
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
        # weighted sum of values for every position
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        # residual connection
        out = self.gamma*out + x
        return out
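The author's comments in this code are quite detailed. The module builds the relation between every pair of pixel positions in the current feature map and then multiplies this relation with the value projection of the input, which gives a weighted result. And of course the residual connection is kept, as always.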
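Reading the `forward` method above line by line, the module can be written compactly as below. The symbols $Q$, $K$, $V$, $A$ and $N = H \times W$ are notation introduced here for illustration, not taken from the paper; $x$ is the input flattened to $C \times N$, $W_q$, $W_k$, $W_v$ stand for the three 1×1 convolutions (biases omitted), and $q_i$, $k_j$, $v_j$, $x_i$ denote columns.

```latex
\begin{aligned}
Q &= W_q x \in \mathbb{R}^{C/8 \times N}, \qquad
K = W_k x \in \mathbb{R}^{C/8 \times N}, \qquad
V = W_v x \in \mathbb{R}^{C \times N}, \\
A_{ij} &= \frac{\exp\!\left(q_i^{\top} k_j\right)}{\sum_{j'=1}^{N} \exp\!\left(q_i^{\top} k_{j'}\right)}, \\
\mathrm{out}_i &= \gamma \sum_{j=1}^{N} A_{ij}\, v_j + x_i .
\end{aligned}
```

Because `gamma` is initialised to zero, the module starts out as an identity mapping, and the weight of the attention branch is learned during training.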
Inter-image contextual relation (the first time I have come across this; worth a close look)
import torch
import torch.nn as nn
import torch.nn.functional as F

# ResNet34Unet and the conv1d / conv2d helpers are assumed to come from the
# DCRNet repository (not shown here); PAM_Module is the class above.


class DCRNet(ResNet34Unet):
    def __init__(self,
                 bank_size=20,
                 num_classes=1,
                 num_channels=3,
                 is_deconv=False,
                 decoder_kernel_size=3,
                 pretrained=True,
                 feat_channels=512
                 ):
        super().__init__(num_classes=1,
                         num_channels=3,
                         is_deconv=False,
                         decoder_kernel_size=3,
                         pretrained=True)

        self.bank_size = bank_size
        self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long))  # memory bank pointer
        self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes))  # memory bank
        self.bank_full = False

        # =====Attentive Cross Image Interaction==== #
        self.feat_channels = feat_channels
        self.L = nn.Conv2d(feat_channels, num_classes, 1)
        self.X = conv2d(feat_channels, 512, 3)
        self.phi = conv1d(512, 256)
        self.psi = conv1d(512, 256)
        self.delta = conv1d(512, 256)
        self.rho = conv1d(256, 512)
        self.g = conv2d(512 + 512, 512, 1)

        # =========Dual Attention========== #
        self.sa_head = PAM_Module(feat_channels)

        #=========Attention Fusion=========#
        self.fusion = nn.Conv2d(feat_channels, feat_channels, 1)

    #==Initiate the pointer of bank buffer==#
    def init(self):
        self.bank_ptr[0] = 0
        self.bank_full = False

    @torch.no_grad()  # important: the bank update must not be tracked by autograd
    def update_bank(self, x):
        ptr = int(self.bank_ptr)
        batch_size = x.shape[0]
        vacancy = self.bank_size - ptr
        if batch_size >= vacancy:
            self.bank_full = True
        pos = min(batch_size, vacancy)
        self.bank[ptr:ptr+pos] = x[0:pos].clone()
        # update pointer
        ptr = (ptr + pos) % self.bank_size
        self.bank_ptr[0] = ptr

    def down(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        return e4, e3, e2, e1

    def up(self, feat, e3, e2, e1, x):
        center = self.center(feat)
        d4 = self.decoder4(torch.cat([center, e3], 1))
        d3 = self.decoder3(torch.cat([d4, e2], 1))
        d2 = self.decoder2(torch.cat([d3, e1], 1))
        d1 = self.decoder1(torch.cat([d2, x], 1))

        f1 = self.finalconv1(d1)
        f2 = self.finalconv2(d2)
        f3 = self.finalconv3(d3)
        f4 = self.finalconv4(d4)

        f4 = F.interpolate(f4, scale_factor=8, mode='bilinear', align_corners=True)
        f3 = F.interpolate(f3, scale_factor=4, mode='bilinear', align_corners=True)
        f2 = F.interpolate(f2, scale_factor=2, mode='bilinear', align_corners=True)
        return f4, f3, f2, f1

    def region_representation(self, input):
        X = self.X(input)
        L = self.L(input)
        aux_out = L
        batch, n_class, height, width = L.shape
        l_flat = L.view(batch, n_class, -1)
        # M = B * N * HW
        M = torch.softmax(l_flat, -1)
        channel = X.shape[1]
        # X_flat = B * C * HW
        X_flat = X.view(batch, channel, -1)
        # f_k = B * C * N
        f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
        return aux_out, f_k, X_flat, X

    def attentive_interaction(self, bank, X_flat, X):
        batch, n_class, height, width = X.shape
        # query = S * C
        query = self.phi(bank).squeeze(dim=2)
        # key: = B * C * HW
        key = self.psi(X_flat)
        # logit = HW * S * B (cross image relation)
        logit = torch.matmul(query, key).transpose(0,2)
        # attn = HW * S * B
        attn = torch.softmax(logit, 2)  # the softmax dimension must be chosen correctly
        # delta = S * C
        delta = self.delta(bank).squeeze(dim=2)
        # attn_sum = B * C * HW
        attn_sum = torch.matmul(attn.transpose(1,2), delta).transpose(1,2)
        # x_obj = B * C * H * W
        X_obj = self.rho(attn_sum).view(batch, -1, height, width)

        concat = torch.cat([X, X_obj], 1)
        out = self.g(concat)
        return out

    def forward(self, x, flag='train'):
        batch_size = x.shape[0]
        #=== Stem ===#
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x_ = self.firstmaxpool(x)

        #=== Encoder ===#
        e4, e3, e2, e1 = self.down(x_)

        #=== Attentive Cross Image Interaction ===#
        aux_out, patch, feats_flat, feats = self.region_representation(e4)
        if flag == 'train':
            self.update_bank(patch)
            ptr = int(self.bank_ptr)
            if self.bank_full == True:
                feature_aug = self.attentive_interaction(self.bank, feats_flat, feats)
            else:
                feature_aug = self.attentive_interaction(self.bank[0:ptr], feats_flat, feats)
        elif flag == 'test':
            feature_aug = self.attentive_interaction(patch, feats_flat, feats)

        #=== Dual Attention ===#
        sa_feat = self.sa_head(e4)

        #=== Fusion ===#
        feats = sa_feat + feature_aug

        #=== Decoder ===#
        f4, f3, f2, f1 = self.up(feats, e3, e2, e1, x)
        aux_out = F.interpolate(aux_out, scale_factor=32, mode='bilinear', align_corners=True)
        return aux_out, f4, f3, f2, f1
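The queue behaviour of `update_bank` is easier to see in isolation. Below is a dependency-free sketch that mirrors its pointer arithmetic, with plain Python lists standing in for feature tensors; the class name `FeatureBank` and the string placeholders are hypothetical and only for illustration. It reproduces one detail of the code above: when a batch is larger than the remaining vacancy, only the first `vacancy` items are written and the rest of that batch is skipped.

```python
class FeatureBank:
    """Fixed-size circular memory that mirrors the pointer logic of update_bank."""

    def __init__(self, bank_size):
        self.bank_size = bank_size
        self.slots = [None] * bank_size   # stands in for the (S, C, N) buffer
        self.ptr = 0                      # next slot to write
        self.full = False                 # True once the end of the bank was reached

    def update(self, batch):
        vacancy = self.bank_size - self.ptr
        if len(batch) >= vacancy:
            self.full = True
        pos = min(len(batch), vacancy)    # items beyond the vacancy are dropped
        self.slots[self.ptr:self.ptr + pos] = batch[:pos]
        self.ptr = (self.ptr + pos) % self.bank_size

    def valid(self):
        """Entries available for cross-image attention."""
        return self.slots if self.full else self.slots[:self.ptr]


bank = FeatureBank(bank_size=6)
bank.update(["a0", "a1", "a2", "a3"])   # bank not full yet: 4 valid entries
first = list(bank.valid())
bank.update(["b0", "b1", "b2", "b3"])   # only b0 and b1 fit; the pointer wraps to 0
second = list(bank.valid())
bank.update(["c0", "c1", "c2", "c3"])   # overwrites the oldest entries a0..a3
third = list(bank.valid())
print(first, second, third, sep="\n")
```

In the training branch of `forward`, the model attends over exactly these valid entries. In the test branch it passes the current batch's own region features in place of the bank.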
Experimental analysis
The experiments mainly cover the following aspects:
| Dataset | Images | Train | Valid | Test |
|---|---|---|---|---|
| EndoScene | 912 | 548 | 182 | 182 |
| Kvasir-SEG | 1000 | 600 | 200 | 200 |
| PICCOLO | 3433 | 2203 | 897 | 333 |

| Device | Learning rate | Epochs | Batch size | Memory bank size |
|---|---|---|---|---|
| NVIDIA RTX 2080Ti | 1e-4 | 150 | 4 | 20 (Kvasir-SEG) / 40 (EndoScene & PICCOLO) |
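For reproduction it can help to keep the numbers from the two tables in one place. The sketch below records them as plain Python data and checks that each split adds up to the stated image count. The names `DATASET_SPLITS`, `TRAIN_CONFIG` and `split_ratios` are hypothetical, not from the DCRNet repository, and the per-dataset bank sizes follow my reading of the table's "E & P" as EndoScene and PICCOLO.

```python
# Image counts per dataset, as listed in the first table.
DATASET_SPLITS = {
    "EndoScene":  {"total": 912,  "train": 548,  "valid": 182, "test": 182},
    "Kvasir-SEG": {"total": 1000, "train": 600,  "valid": 200, "test": 200},
    "PICCOLO":    {"total": 3433, "train": 2203, "valid": 897, "test": 333},
}

# Training settings from the second table; the bank size depends on the dataset.
TRAIN_CONFIG = {
    "learning_rate": 1e-4,
    "epochs": 150,
    "batch_size": 4,
    "bank_size": {"Kvasir-SEG": 20, "EndoScene": 40, "PICCOLO": 40},
}


def split_ratios(name):
    """Return the train/valid/test fractions of one dataset, rounded to 3 places."""
    split = DATASET_SPLITS[name]
    return tuple(round(split[k] / split["total"], 3) for k in ("train", "valid", "test"))


for name, split in DATASET_SPLITS.items():
    # Sanity check: the three subsets must cover the whole dataset.
    assert split["train"] + split["valid"] + split["test"] == split["total"]
    print(name, split_ratios(name), "bank size:", TRAIN_CONFIG["bank_size"][name])
```

EndoScene and Kvasir-SEG follow a roughly 60/20/20 split, while PICCOLO uses a different partition with a smaller test set.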


从可视化和表格数据上,我们能够看出本文模型的有效性!

对于这两个经典模型,有着不错的提高,说明了本模型的设计和内外上下文推理体系的合理性。
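Discussion
The biggest highlight of the paper is probably the external memory. Looking at the architecture as a whole, the idea to take away is this kind of implicit, internal classification; the inter-image contextual relation module works on the same principle!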
If you found this useful, a like or a bookmark would be much appreciated. Thanks for the support!