当前位置:网站首页>【论文代码】SML部分代码阅读
【论文代码】SML部分代码阅读
2022-07-06 05:46:00 【starbuling~】
SML中的边界抑制以及高斯平滑
边界平滑抑制类
class BoundarySuppressionWithSmoothing(nn.Module):
    """Apply boundary suppression and dilated smoothing to a 1-channel score map.

    Stage 1 (boundary suppression): pixels on a band around the boundaries of
    ``prediction`` are iteratively overwritten with the mean of their
    non-boundary 3x3 neighbours. The band is expanded by a radius that shrinks
    by ``boundary_width // boundary_iteration`` each iteration.
    Stage 2 (dilated smoothing): the map is convolved with a normalized
    ``kernel_size x kernel_size`` Gaussian kernel using ``dilation``.

    Args:
        boundary_suppression: enable stage 1.
        boundary_width: controls the boundary band width; must be divisible by
            ``boundary_iteration``.
        boundary_iteration: number of suppression iterations (0 disables the loop).
        dilated_smoothing: enable stage 2.
        kernel_size: side length of the Gaussian kernel (odd sizes keep H x W).
        dilation: dilation of the Gaussian convolution.
        sigma: standard deviation of the Gaussian (new, defaults to the former
            hard-coded 1.0).
    """

    def __init__(self, boundary_suppression=True, boundary_width=4, boundary_iteration=4,
                 dilated_smoothing=True, kernel_size=7, dilation=6, sigma=1.0):
        super().__init__()
        self.kernel_size = kernel_size  # side length of the Gaussian kernel
        self.dilation = dilation  # dilation of the Gaussian convolution
        self.boundary_suppression = boundary_suppression
        self.boundary_width = boundary_width
        self.boundary_iteration = boundary_iteration
        self.dilated_smoothing = dilated_smoothing

        # Normalized 2-D Gaussian kernel of size kernel_size x kernel_size:
        #   G(x, y) = 1 / (2*pi*sigma^2) * exp(-((x - c)^2 + (y - c)^2) / (2*sigma^2)),
        #   c = (kernel_size - 1) / 2
        # np.fromfunction(function, shape) calls `function` once with index
        # arrays of the given shape (x = row index, y = column index, origin at
        # the top-left), so the lambda evaluates G at every cell.
        # The kernel is built from kernel_size; it was previously fixed at 7x7.
        size = kernel_size
        center = (size - 1) / 2
        gaussian_kernel = np.fromfunction(
            lambda x, y: (1 / (2 * math.pi * sigma ** 2))
            * np.exp(-((x - center) ** 2 + (y - center) ** 2) / (2 * sigma ** 2)),
            (size, size))
        # Divide by the sum so the weights add up to 1 (a weighted average that
        # keeps the overall value range of the input).
        gaussian_kernel /= np.sum(gaussian_kernel)
        gaussian_kernel = torch.tensor(gaussian_kernel, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, k, k)

        # first_conv: 3x3 all-ones kernel, i.e. it sums every 3x3 neighbourhood.
        self.first_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, bias=False)
        self.first_conv.weight = torch.nn.Parameter(torch.ones_like(self.first_conv.weight))
        # second_conv: dilated convolution with the Gaussian kernel above.
        self.second_conv = nn.Conv2d(
            1, 1, kernel_size=self.kernel_size, stride=1, dilation=self.dilation, bias=False)
        self.second_conv.weight = torch.nn.Parameter(gaussian_kernel)

    def forward(self, x, prediction=None):
        """Suppress prediction boundaries in ``x`` and/or smooth it.

        Args:
            x: score map of shape (B, H, W) or (B, 1, H, W); a channel axis is
                added to 3-D input.
            prediction: label map of shape (B, H, W) whose boundaries are
                suppressed; required when ``boundary_suppression`` is True.

        Returns:
            The processed map with axis 1 squeezed, i.e. (B, H, W) for the
            shapes above.

        Raises:
            ValueError: if boundary suppression is enabled but ``prediction``
                is None.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # B x H x W -> B x 1 x H x W
        assert x.dim() == 4
        out = x

        # Stage 1: boundary suppression.
        if self.boundary_suppression:
            if prediction is None:
                raise ValueError("prediction is required when boundary_suppression=True")
            out = self._suppress_boundaries(out, prediction)

        # Stage 2: dilated Gaussian smoothing.
        if self.dilated_smoothing:
            out = self._smooth(out)
        return out.squeeze(1)

    def _suppress_boundaries(self, out, prediction):
        """Iteratively overwrite boundary pixels of ``out`` with the mean of their non-boundary 3x3 neighbours."""
        # Boundary map (width 2 by default): where dilation and erosion of the
        # prediction differ.
        boundaries = find_boundaries(prediction.unsqueeze(1))
        if self.boundary_iteration == 0:
            return out
        # The boundary width must be divisible by the number of iterations;
        # `diff` is how much the expansion radius shrinks per iteration.
        assert self.boundary_width % self.boundary_iteration == 0
        diff = self.boundary_width // self.boundary_iteration
        # ReplicationPad2d(1) pads every side by one pixel by repeating the edge
        # values, so the unpadded 3x3 convolution keeps the H x W size.
        pad = nn.ReplicationPad2d(1)

        for iteration in range(self.boundary_iteration):
            prev_out = out

            # Stop expanding on the last iteration (or when the width is zero);
            # otherwise shrink the expansion radius every iteration.
            if self.boundary_width == 0 or iteration == self.boundary_iteration - 1:
                expansion_width = 0
            else:
                expansion_width = self.boundary_width - diff * iteration - 1
            expanded_boundaries = expand_boundaries(boundaries, r=expansion_width)

            # Invert: 1 off the boundary band, 0 on it.
            non_boundary_mask = 1.0 * (expanded_boundaries == 0)

            # Sum of non-boundary values, and count of non-boundary pixels, in
            # each 3x3 window (boundary pixels were zeroed by the mask).
            y = self.first_conv(pad(out * non_boundary_mask))
            num_calced_elements = self.first_conv(pad(non_boundary_mask)).long()

            # Mean of the non-boundary neighbours; keep the previous value where
            # the window has no non-boundary pixel. clamp(min=1) keeps the
            # discarded branch finite so torch.where cannot leak inf/NaN into
            # gradients; selected values are unchanged.
            avg_y = torch.where(num_calced_elements == 0, prev_out,
                                y / num_calced_elements.clamp(min=1))

            # Only pixels on the boundary band are updated.
            out = torch.where(non_boundary_mask == 0, avg_y, prev_out)
        return out

    def _smooth(self, out):
        """Convolve ``out`` with the dilated Gaussian kernel, keeping H x W for odd kernel sizes."""
        # A size-k kernel with dilation d spans d * (k - 1) + 1 pixels, so
        # d * (k - 1) // 2 replicated pixels per side preserve the spatial size.
        # The old hard-coded `dilation * 3` was only right for k = 7.
        pad = self.dilation * (self.kernel_size - 1) // 2
        return self.second_conv(nn.ReplicationPad2d(pad)(out))
find_boundaries
def find_boundaries(label):
    """Return a binary boundary mask for a label map.

    A pixel is marked 1.0 where the dilated label map (3x3 cross
    ``selem_dilation``) differs from the eroded label map (3x3 all-ones
    ``selem``), and 0.0 elsewhere.
    NOTE(review): assuming ``dilation``/``erosion`` are grayscale max/min
    filters (e.g. kornia.morphology), this marks pixels whose neighbourhood
    contains more than one label. Confirm against the import.

    Args:
        label: 4-D label tensor. The class above passes
            ``prediction.unsqueeze(1)``, i.e. (B, 1, H, W).

    Returns:
        Float tensor of the same shape as the dilation/erosion output.

    Raises:
        ValueError: if ``label`` is not 4-D.
    """
    if label.dim() != 4:
        raise ValueError(f"expected a 4-D label tensor, got shape {tuple(label.shape)}")
    label = label.float()  # convert once; morphology works on float maps
    boundaries = (dilation(label, selem_dilation) != erosion(label, selem)).float()
    # save_image(boundaries, f'boundaries_{boundaries.float().mean():.2f}.png', normalize=True)
    return boundaries
# Structuring elements for the morphology helpers. They go on the GPU when one
# is available and on the CPU otherwise, so importing this module no longer
# requires CUDA (an unconditional .cuda() fails on CPU-only hosts).
_MORPH_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
selem = torch.ones((3, 3), device=_MORPH_DEVICE)  # 3x3 all-ones kernel, used for erosion
selem_dilation = torch.tensor(
    ndi.generate_binary_structure(2, 1), dtype=torch.float32, device=_MORPH_DEVICE)  # 3x3 cross, used for dilation
腐蚀核 selem:3×3 全 1 矩阵。
膨胀核 selem_dilation:3×3 十字形结构元素(由 ndi.generate_binary_structure(2, 1) 生成)。
膨胀(dilation) & 腐蚀(erosion)
这是两种基本的形态学运算,主要用来寻找图像中的极大区域和极小区域。
- 膨胀类似于 ‘邻域扩张’ ,将图像的高亮区域或白色部分进行扩张,其运行结果图比原图的高亮区域更大。
- 腐蚀类似于 ‘邻域被蚕食’ ,将图像中的高亮区域或白色部分进行缩减细化,其运行结果图比原图的高亮区域更小。
具体过程:定义一个卷积核,对图片进行卷积。膨胀做“或”操作,扩大1的范围;腐蚀做“与”操作,减少1的数量
dilation(image, kernel) # 图像,卷积核
erosion(image, kernel)
对标签图分别做膨胀和腐蚀后,两者结果不一样的位置就是边界,用 1 表示。
expand_boundaries
def expand_boundaries(boundaries, r=0):
    """Dilate a boundary map by radius ``r`` using the precomputed kernel ``d_ks[r]``.

    Args:
        boundaries: boundary mask, as returned by ``find_boundaries``.
        r: expansion radius. 0 returns ``boundaries`` unchanged; otherwise it
            must be a key of the module-level ``d_ks`` (built for 1-9 in this
            file).

    Returns:
        The dilated boundary mask, or the input object itself when ``r == 0``.

    Raises:
        ValueError: if no kernel is precomputed for ``r`` (e.g. a
            ``boundary_width`` larger than 10 in the suppression loop).
    """
    if r == 0:
        return boundaries
    if r not in d_ks:
        raise ValueError(f"no dilation kernel for r={r}; available radii: {sorted(d_ks)}")
    expanded_boundaries = dilation(boundaries, d_ks[r])
    # save_image(expanded_boundaries, f'expanded_boundaries_{r}_{boundaries.float().mean():.2f}.png', normalize=True)
    return expanded_boundaries
关于 d_ks[]:
# Dilation kernels keyed by radius r = 1..9, used by expand_boundaries.
# Each starts as a (1, 1, 2r+1, 2r+1) map with a single 1 at the centre and is
# dilated r times with the 3x3 cross `selem_dilation`; the result is squeezed
# to (2r+1, 2r+1).
# NOTE(review): with a flat max-dilation this yields the diamond of cells
# within L1 distance r of the centre; confirm against the `dilation` import.
# Kernels are created on the same device as `selem_dilation` instead of an
# unconditional .cuda().
_kernel_device = selem_dilation.device
d_ks = {}
for k in range(1, 10):
    v = torch.zeros((1, 1, 2 * k + 1, 2 * k + 1), device=_kernel_device)
    v[:, :, k, k] = 1  # single seed at the centre
    for _ in range(k):
        v = dilation(v, selem_dilation)  # grow by one step per iteration
    d_ks[k] = v.squeeze(0).squeeze(0)
    print(f'dilation kernel at {k}:\n\n{d_ks[k]}')
这些卷积核的样子大致如下,以此类推
边栏推荐
- Remember an error in MySQL: the user specified as a definer ('mysql.infoschema '@' localhost ') does not exist
- js Array 列表 实战使用总结
- [experience] when ultralso makes a startup disk, there is an error: the disk / image capacity is too small
- [email protected] raspberry pie
- 03. Login of development blog project
- [JVM] [Chapter 17] [garbage collector]
- PDK工艺库安装-CSMC
- P2802 回家
- 大型网站如何选择比较好的云主机服务商?
- Game push: image / table /cv/nlp, multi-threaded start!
猜你喜欢
实践分享:如何安全快速地从 Centos迁移到openEuler
[SQL Server Express Way] - authentification et création et gestion de comptes utilisateurs
C language learning notes (mind map)
Station B, Master Liu Er - dataset and data loading
Clear floating mode
无代码六月大事件|2022无代码探索者大会即将召开;AI增强型无代码工具推出...
嵌入式面试题(四、常见算法)
Redis message queue
Redis消息队列
进程和线程
随机推荐
ArcGIS application foundation 4 thematic map making
华为路由器忘记密码怎么恢复
巨杉数据库再次亮相金交会,共建数字经济新时代
自建DNS服务器,客户端打开网页慢,解决办法
What preparations should be made for website server migration?
Game push image / table /cv/nlp, multi-threaded start
Summary of data sets in intrusion detection field
[detailed explanation of Huawei machine test] check whether there is a digital combination that meets the conditions
Web服务连接器:Servlet
First knowledge database
Li Chuang EDA learning notes 12: common PCB board layout constraint principles
Sequoiadb Lake warehouse integrated distributed database, June 2022 issue
Embedded interview questions (IV. common algorithms)
How to get list length
PDK工艺库安装-CSMC
Mysql database master-slave cluster construction
Luogu [Beginner Level 4] array p1427 number game of small fish
01. Project introduction of blog development project
Station B Liu Erden softmx classifier and MNIST implementation -structure 9
[email protected] raspberry pie