[Paper Code] Reading Part of the SML Code
2022-07-06 05:46:00 【starbuling~】
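Boundary suppression and Gaussian smoothing in SML
The boundary suppression and smoothing class
import math
import numpy as np
import torch
import torch.nn as nn

class BoundarySuppressionWithSmoothing(nn.Module):
    """ Apply boundary suppression and dilated smoothing """
Initialization
    def __init__(self, boundary_suppression=True, boundary_width=4, boundary_iteration=4,
                 dilated_smoothing=True, kernel_size=7, dilation=6):
Define the parameters
        super(BoundarySuppressionWithSmoothing, self).__init__()
        self.kernel_size = kernel_size  # convolution kernel size
        self.dilation = dilation  # dilation rate of the smoothing convolution
        self.boundary_suppression = boundary_suppression  # whether to apply boundary suppression
        self.boundary_width = boundary_width  # boundary width
        self.boundary_iteration = boundary_iteration  # number of boundary suppression iterations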
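Create the Gaussian kernel
        sigma = 1.0
        size = 7
        # the function is the probability density function of a 2-D Gaussian distribution
        gaussian_kernel = np.fromfunction(lambda x, y:
            (1/(2*math.pi*sigma**2)) * math.e ** ((-1*((x-(size-1)/2)**2+(y-(size-1)/2)**2))/(2*sigma**2)),
            (size, size))  # build the (7, 7) Gaussian kernel; size 7 = 2 * (3 * sigma) + 1
        gaussian_kernel /= np.sum(gaussian_kernel)  # normalize so the weights sum to 1 (a weighted average that keeps the output in the input's value range)
        gaussian_kernel = torch.Tensor(gaussian_kernel).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 7, 7)
        self.dilated_smoothing = dilated_smoothing  # whether to apply dilated smoothing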
$$\mathrm{lambda}(x, y) = \frac{1}{2\pi\sigma^2} \exp\left(-\frac{\left(x - \frac{size-1}{2}\right)^2 + \left(y - \frac{size-1}{2}\right)^2}{2\sigma^2}\right)$$
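fromfunction in the numpy library: given a user-defined function, a shape, and a dtype, it computes the value at every position from the array indices (x, y) and assembles the results into an array.
Function parameters
np.fromfunction(function, shape, dtype)
function: a function that maps coordinates to a concrete value.
def function(x, y): inside the function, (x, y) are coordinates with the origin at the top-left corner; x is the row index and y is the column index, i.e. row x, column y. NumPy calls the function once with whole index arrays, not once per element.
shape (a, b): the size of the resulting array, a rows by b columns.
dtype: the data type of the index arrays passed to function (float by default).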
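To make the np.fromfunction call concrete, here is a minimal NumPy sketch that rebuilds the same 7x7 kernel with sigma = 1.0 and checks a few of its properties. The helper name make_gaussian_kernel is made up for this illustration, and np.exp stands in for math.e ** (...), which should give the same values.

```python
import numpy as np

def make_gaussian_kernel(size=7, sigma=1.0):
    """Build a normalized (size, size) Gaussian kernel centered at (size - 1) / 2."""
    center = (size - 1) / 2
    # np.fromfunction calls the lambda once, with x and y as (size, size) index arrays
    kernel = np.fromfunction(
        lambda x, y: np.exp(-((x - center) ** 2 + (y - center) ** 2) / (2 * sigma ** 2))
        / (2 * np.pi * sigma ** 2),
        (size, size),
    )
    return kernel / kernel.sum()  # after this the weights sum to 1

kernel = make_gaussian_kernel()
print(kernel.shape)                  # (7, 7)
print(np.allclose(kernel, kernel.T)) # True: the kernel is symmetric
```

Because of the final normalization, the constant factor 1/(2*pi*sigma^2) does not change the resulting weights; it only matters if the unnormalized density is needed.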
Define two convolution layers (in_channel, out_channel, k, s)
- (1, 1, 3, 1): the weight matrix is all ones
- (1, 1, 7, 1): the weight matrix is the Gaussian kernel
        self.first_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, bias=False)
        self.first_conv.weight = torch.nn.Parameter(torch.ones_like((self.first_conv.weight)))
        self.second_conv = nn.Conv2d(
            1, 1, kernel_size=self.kernel_size, stride=1, dilation=self.dilation, bias=False)
        self.second_conv.weight = torch.nn.Parameter(gaussian_kernel)
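As a rough check on second_conv, the sketch below works out how much area a 7x7 kernel with dilation 6 covers and why padding by dilation * 3 (the amount used in forward) keeps the spatial size unchanged. The function names and the input size 64 are illustrative assumptions, not part of the SML code.

```python
def effective_kernel_size(kernel_size, dilation):
    """Spatial extent covered by a dilated kernel: dilation * (k - 1) + 1."""
    return dilation * (kernel_size - 1) + 1

def conv_output_size(in_size, kernel_size, dilation, pad_each_side, stride=1):
    """Output length along one axis of a convolution applied after explicit padding."""
    extent = effective_kernel_size(kernel_size, dilation)
    return (in_size + 2 * pad_each_side - extent) // stride + 1

extent = effective_kernel_size(kernel_size=7, dilation=6)
pad = 6 * 3  # dilation * 3, which equals dilation * (7 - 1) / 2
print(extent)                           # 37
print(conv_output_size(64, 7, 6, pad))  # 64
```

Note that dilation * 3 is only size-preserving for a 7x7 kernel; with a different kernel_size the padding would have to be dilation * (kernel_size - 1) / 2.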
Forward pass
    def forward(self, x, prediction=None):
        if len(x.shape) == 3:
            x = x.unsqueeze(1)  # if the input is 3-D, add a channel dimension
        x_size = x.size()
        # B x 1 x H x W
        assert len(x.shape) == 4
        out = x
Branch 1: boundary suppression enabled
        if self.boundary_suppression:
            # obtain the boundary map of width 2 by default
            # this can be calculated by the difference of dilation and erosion
            boundaries = find_boundaries(prediction.unsqueeze(1))  # find the boundaries
            expanded_boundaries = None
            if self.boundary_iteration != 0:
                assert self.boundary_width % self.boundary_iteration == 0  # the boundary width must be divisible by the number of iterations
                diff = self.boundary_width // self.boundary_iteration  # change in expansion width per iteration
Main loop of boundary suppression
            for iteration in range(self.boundary_iteration):
                if len(out.shape) != 4:
                    out = out.unsqueeze(1)
                prev_out = out
Get the boundary
                # if it is the last iteration or boundary width is zero, stop expanding
                if self.boundary_width == 0 or iteration == self.boundary_iteration - 1:
                    expansion_width = 0
                # reduce the expansion width for each iteration
                else:
                    expansion_width = self.boundary_width - diff * iteration - 1
                # expand the boundary obtained from the prediction (width of 2) by expansion rate
                expanded_boundaries = expand_boundaries(boundaries, r=expansion_width)  # see the expand_boundaries section for details
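The expansion width shrinks from one iteration to the next, so the suppression works from the widest boundary band inward. A small standalone sketch of that schedule follows; expansion_schedule is a hypothetical helper that only replays the arithmetic from the loop.

```python
def expansion_schedule(boundary_width=4, boundary_iteration=4):
    """Return the expansion width used in each iteration of the suppression loop."""
    assert boundary_iteration != 0 and boundary_width % boundary_iteration == 0
    diff = boundary_width // boundary_iteration
    widths = []
    for iteration in range(boundary_iteration):
        if boundary_width == 0 or iteration == boundary_iteration - 1:
            widths.append(0)  # last iteration: use the unexpanded boundary
        else:
            widths.append(boundary_width - diff * iteration - 1)
    return widths

print(expansion_schedule())      # [3, 2, 1, 0]
print(expansion_schedule(8, 4))  # [7, 5, 3, 0]
```

With the default arguments the boundary band is expanded by 3, 2, 1 and finally 0 pixels, so each pass rewrites a narrower strip around the predicted class boundary.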
Invert the boundary to obtain the non-boundary mask
                # invert it so that we can obtain non-boundary mask
                non_boundary_mask = 1. * (expanded_boundaries == 0)  # non-boundary pixels are 1, boundary pixels are 0
Set the boundary regions to 0
                f_size = 1
                num_pad = f_size
                # making boundary regions to 0
                x_masked = out * non_boundary_mask  # input * non-boundary mask keeps only the non-boundary values
                x_padded = nn.ReplicationPad2d(num_pad)(x_masked)
                non_boundary_mask_padded = nn.ReplicationPad2d(num_pad)(non_boundary_mask)
class torch.nn.ReplicationPad2d(padding)
padding (int, tuple): the size of the padding. If it is an int, the same padding is used on all four sides.
If it is a 4-tuple, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).
Sum the values in the receptive field
                # sum up the values in the receptive field
                y = self.first_conv(x_padded)
                # count non-boundary elements in the receptive field
                num_calced_elements = self.first_conv(non_boundary_mask_padded)
                num_calced_elements = num_calced_elements.long()
Take the average
                # take an average by dividing y by count
                # if there is no non-boundary element in the receptive field,
                # keep the original value
                avg_y = torch.where((num_calced_elements == 0), prev_out, y / num_calced_elements)
                out = avg_y
Update only the boundary pixels
                # update boundaries only
                out = torch.where((non_boundary_mask == 0), out, prev_out)
                del expanded_boundaries, non_boundary_mask
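One iteration of the loop body can be reproduced without PyTorch. The following is a NumPy sketch under the assumption that a 3x3 all-ones convolution with replicate padding is equivalent to a 3x3 box sum over an edge-padded array; box_sum_3x3 and suppress_once are names invented for this illustration.

```python
import numpy as np

def box_sum_3x3(a):
    """Sum over each 3x3 neighbourhood, using replicate ('edge') padding."""
    p = np.pad(a, 1, mode="edge")
    h, w = a.shape
    return sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3))

def suppress_once(out, boundary):
    """Replace boundary pixels by the mean of their non-boundary 3x3 neighbours."""
    non_boundary = (boundary == 0).astype(out.dtype)
    total = box_sum_3x3(out * non_boundary)  # sum of non-boundary values
    count = box_sum_3x3(non_boundary)        # number of non-boundary neighbours
    safe = np.where(count == 0, 1, count)    # avoid dividing by zero
    avg = np.where(count == 0, out, total / safe)
    return np.where(non_boundary == 0, avg, out)  # update boundary pixels only

out = np.ones((5, 5))
out[:, 2] = 9.0                # an outlier column sitting on the boundary
boundary = np.zeros((5, 5))
boundary[:, 2] = 1
result = suppress_once(out, boundary)
print(result[:, 2])            # [1. 1. 1. 1. 1.]
```

The boundary column takes the average of its non-boundary neighbours, while every non-boundary pixel keeps its original value.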
Second stage: dilated smoothing
            # second stage; apply dilated smoothing
            if self.dilated_smoothing == True:
                out = nn.ReplicationPad2d(self.dilation * 3)(out)
                out = self.second_conv(out)
            return out.squeeze(1)
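The dilated smoothing stage can be imitated in NumPy to see that it preserves both the spatial size and the overall value level. This is a sketch, not the module's implementation: gaussian_kernel and dilated_smooth are hypothetical helpers, and the padding is written as dilation * (k - 1) // 2, which equals dilation * 3 for the 7x7 kernel.

```python
import numpy as np

def gaussian_kernel(size=7, sigma=1.0):
    """Normalized 2-D Gaussian weights centered in a (size, size) grid."""
    ax = np.arange(size) - (size - 1) / 2
    g = np.exp(-(ax[:, None] ** 2 + ax[None, :] ** 2) / (2 * sigma ** 2))
    return g / g.sum()

def dilated_smooth(x, kernel, dilation=6):
    """Cross-correlate x with a dilated kernel after replicate padding."""
    k = kernel.shape[0]
    pad = dilation * (k - 1) // 2  # dilation * 3 when k == 7
    padded = np.pad(x, pad, mode="edge")
    h, w = x.shape
    out = np.zeros((h, w), dtype=float)
    for i in range(k):
        for j in range(k):
            r, c = i * dilation, j * dilation  # taps are `dilation` pixels apart
            out += kernel[i, j] * padded[r:r + h, c:c + w]
    return out

flat = np.full((10, 12), 2.5)
smoothed = dilated_smooth(flat, gaussian_kernel())
print(smoothed.shape)  # (10, 12)
```

Since the kernel weights sum to 1, a constant map comes out unchanged, which is the property the normalization step in __init__ is there to guarantee.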
Branch 2: boundary suppression disabled
        else:
            if self.dilated_smoothing == True:  # dilated smoothing
                out = nn.ReplicationPad2d(self.dilation * 3)(out)
                out = self.second_conv(out)
            else:
                out = x
            return out.squeeze(1)
find_boundaries
def find_boundaries(label):
    """ Calculate boundary mask by getting diff of dilated and eroded prediction maps """
    assert len(label.shape) == 4
    boundaries = (dilation(label.float(), selem_dilation) != erosion(label.float(), selem)).float()
    ### save_image(boundaries, f'boundaries_{boundaries.float().mean():.2f}.png', normalize=True)
    return boundaries
selem = torch.ones((3, 3)).cuda()  # a (3, 3) all-ones tensor, the erosion kernel
selem_dilation = torch.FloatTensor(ndi.generate_binary_structure(2, 1)).cuda()  # the dilation kernel
[Figure: erosion]
[Figure: dilation]
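Dilation & erosion
These are the two basic morphological operations, mainly used to find the local maximum and local minimum regions of an image.
- Dilation is like "region growing": it expands the bright (white) parts of the image, so the result has larger bright regions than the original.
- Erosion is like "a region being eaten away": it shrinks and thins the bright (white) parts of the image, so the result has smaller bright regions than the original.
In practice: define a kernel (structuring element) and slide it over the image. Dilation performs an OR over the neighbourhood, growing the area covered by 1s; erosion performs an AND, reducing the number of 1s.
dilation(image, kernel)  # image, kernel
erosion(image, kernel)
After dilating and eroding the label map separately, the positions where the two results differ are the boundaries, marked with 1.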
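The idea can be tried on a tiny label map with plain NumPy. The sketch below treats dilation as a neighbourhood maximum and erosion as a neighbourhood minimum, using a cross for dilation and a 3x3 square for erosion as in selem_dilation and selem. The border handling (padding that never wins the max or min) is an assumption of this sketch, and morph and find_boundaries_np are made-up names.

```python
import numpy as np

CROSS = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=bool)  # dilation footprint
SQUARE = np.ones((3, 3), dtype=bool)                              # erosion footprint

def morph(a, footprint, reduce_fn, pad_value):
    """Apply reduce_fn (np.max or np.min) over the footprint around each pixel."""
    p = np.pad(a, 1, mode="constant", constant_values=pad_value)
    h, w = a.shape
    shifted = [p[i:i + h, j:j + w]
               for i in range(3) for j in range(3) if footprint[i, j]]
    return reduce_fn(np.stack(shifted), axis=0)

def find_boundaries_np(label):
    """Mark pixels where the dilated and eroded label maps disagree."""
    label = label.astype(float)
    dilated = morph(label, CROSS, np.max, -np.inf)
    eroded = morph(label, SQUARE, np.min, np.inf)
    return (dilated != eroded).astype(float)

label = np.zeros((4, 6))
label[:, 3:] = 1               # class 0 on the left, class 1 on the right
b = find_boundaries_np(label)
print(b[0])                    # [0. 0. 1. 1. 0. 0.]
```

The result is a band two pixels wide, one pixel on each side of the class change, which matches the "width 2 by default" remark in the forward pass.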
expand_boundaries
def expand_boundaries(boundaries, r=0):
    """ Expand boundary maps with the rate of r """
    if r == 0:
        return boundaries
    expanded_boundaries = dilation(boundaries, d_ks[r])  # apply dilation
    ### save_image(expanded_boundaries, f'expanded_boundaries_{r}_{boundaries.float().mean():.2f}.png', normalize=True)
    return expanded_boundaries
About d_ks:
d_k1 = torch.zeros((1, 1, 2 * 1 + 1, 2 * 1 + 1)).cuda()
d_k2 = torch.zeros((1, 1, 2 * 2 + 1, 2 * 2 + 1)).cuda()
d_k3 = torch.zeros((1, 1, 2 * 3 + 1, 2 * 3 + 1)).cuda()
d_k4 = torch.zeros((1, 1, 2 * 4 + 1, 2 * 4 + 1)).cuda()
d_k5 = torch.zeros((1, 1, 2 * 5 + 1, 2 * 5 + 1)).cuda()
d_k6 = torch.zeros((1, 1, 2 * 6 + 1, 2 * 6 + 1)).cuda()
d_k7 = torch.zeros((1, 1, 2 * 7 + 1, 2 * 7 + 1)).cuda()
d_k8 = torch.zeros((1, 1, 2 * 8 + 1, 2 * 8 + 1)).cuda()
d_k9 = torch.zeros((1, 1, 2 * 9 + 1, 2 * 9 + 1)).cuda()
d_ks = {
    1: d_k1, 2: d_k2, 3: d_k3, 4: d_k4,
    5: d_k5, 6: d_k6, 7: d_k7, 8: d_k8, 9: d_k9}
for k, v in d_ks.items():
    v[:, :, k, k] = 1  # set the center pixel of the (2k+1) x (2k+1) kernel
    for i in range(k):
        v = dilation(v, selem_dilation)  # grow the kernel by one pixel per step
    d_ks[k] = v.squeeze(0).squeeze(0)
    print(f'dilation kernel at {k}:\n\n{d_ks[k]}')
These kernels look roughly as follows, and the larger ones follow the same pattern.
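Starting from a single center pixel and dilating k times with the 4-connected cross should produce a diamond: every cell within Manhattan distance k of the center. The NumPy sketch below builds that shape directly. It assumes dilation is an ordinary binary dilation and that selem_dilation is the cross returned by ndi.generate_binary_structure(2, 1); diamond_kernel is a name invented for this illustration.

```python
import numpy as np

def diamond_kernel(k):
    """(2k+1) x (2k+1) kernel with 1s where |row - k| + |col - k| <= k."""
    size = 2 * k + 1
    rows, cols = np.indices((size, size))
    return ((np.abs(rows - k) + np.abs(cols - k)) <= k).astype(float)

print(diamond_kernel(1))
# [[0. 1. 0.]
#  [1. 1. 1.]
#  [0. 1. 0.]]
print(diamond_kernel(2))
# [[0. 0. 1. 0. 0.]
#  [0. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]
#  [0. 1. 1. 1. 0.]
#  [0. 0. 1. 0. 0.]]
```

Under these assumptions, d_ks[r] expands a boundary by r pixels measured in Manhattan distance, which is why expand_boundaries can be called with a single dilation per expansion width.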