[Paper Code] Reading Part of the SML Code
2022-07-06 05:46:00 【starbuling~】
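Boundary suppression and Gaussian smoothing in SML
The boundary suppression and smoothing class
```python
import math

import numpy as np
import torch
import torch.nn as nn


class BoundarySuppressionWithSmoothing(nn.Module):
    """ Apply boundary suppression and dilated smoothing """
```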
Initialization
```python
    def __init__(self, boundary_suppression=True, boundary_width=4, boundary_iteration=4,
                 dilated_smoothing=True, kernel_size=7, dilation=6):
```
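Store the hyperparameters
```python
        super(BoundarySuppressionWithSmoothing, self).__init__()
        self.kernel_size = kernel_size  # kernel size of the smoothing conv
        self.dilation = dilation  # dilation rate of the smoothing conv
        self.boundary_suppression = boundary_suppression  # whether to apply boundary suppression
        self.boundary_width = boundary_width  # total boundary width to suppress
        self.boundary_iteration = boundary_iteration  # number of suppression iterations
```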
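Create the Gaussian kernel
```python
        sigma = 1.0
        size = 7
        # the lambda is the probability density function of a 2-D Gaussian
        gaussian_kernel = np.fromfunction(lambda x, y:
            (1/(2*math.pi*sigma**2)) * math.e ** ((-1*((x-(size-1)/2)**2+(y-(size-1)/2)**2))/(2*sigma**2)),
            (size, size))  # 7x7 Gaussian kernel: 2 * 3 * sigma + 1 = 7 covers +/-3 sigma
        gaussian_kernel /= np.sum(gaussian_kernel)  # normalize to sum 1 (a weighted average, so pixel values cannot overflow)
        gaussian_kernel = torch.Tensor(gaussian_kernel).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 7, 7), the conv weight layout
        self.dilated_smoothing = dilated_smoothing  # whether to apply dilated smoothing
```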
$$\mathrm{lambda}(x,y)=\frac{1}{2\pi\sigma^{2}}\exp\left(-\frac{\left(x-\frac{size-1}{2}\right)^{2}+\left(y-\frac{size-1}{2}\right)^{2}}{2\sigma^{2}}\right)$$
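fromfunction, from the NumPy library, builds an array from three inputs: a user-defined function, a shape and a dtype. The value at each position is computed from that position's index (x, y).
Signature:
np.fromfunction(function, shape, dtype=float)
- function: maps coordinates to a value. In def function(x, y), (x, y) are coordinates with the origin at the top-left corner. x is the row index and y is the column index, so (x, y) means row x, column y. NumPy calls the function only once, passing whole arrays of row and column indices rather than single numbers. The function must therefore work element-wise on arrays. The lambda above does, because it uses only arithmetic operators.
- shape: (a, b) gives the size of the array, with a rows and b columns.
- dtype: the data type of the coordinate arrays passed to function (float by default).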
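The sketch below builds the same normalized 7×7 kernel in plain Python, to make the formula and the normalization step concrete. The nested list comprehensions play the role of np.fromfunction. The function name gaussian_kernel is only illustrative. With σ = 1, the kernel extends 3 pixels (3σ) on each side of the center, which is why size = 7 = 2·3σ + 1.

```python
import math


def gaussian_kernel(size=7, sigma=1.0):
    """Build a normalized size x size 2-D Gaussian kernel as nested lists.

    Mirrors the np.fromfunction call: the value at row x, column y is the
    2-D Gaussian density centred at ((size-1)/2, (size-1)/2).
    """
    c = (size - 1) / 2
    kernel = [
        [
            (1 / (2 * math.pi * sigma ** 2))
            * math.exp(-((x - c) ** 2 + (y - c) ** 2) / (2 * sigma ** 2))
            for y in range(size)
        ]
        for x in range(size)
    ]
    # Normalize so the weights sum to 1, i.e. a weighted average.
    total = sum(sum(row) for row in kernel)
    return [[v / total for v in row] for row in kernel]


k = gaussian_kernel()
print(len(k), len(k[0]))                      # 7 7
print(round(sum(map(sum, k)), 6))             # 1.0
print(k[3][3] == max(max(row) for row in k))  # True: the peak is at the centre
```

After normalization, the constant factor 1/(2πσ²) cancels out. It is kept here only to match the formula.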
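Define two convolution layers (in_channels, out_channels, kernel_size, stride):
- (1, 1, 3, 1): the weight matrix is all ones
- (1, 1, 7, 1) with dilation = self.dilation: the weight matrix is the Gaussian kernel
```python
        # 3x3 all-ones kernel: sums the values in each 3x3 neighbourhood
        self.first_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, bias=False)
        self.first_conv.weight = torch.nn.Parameter(torch.ones_like((self.first_conv.weight)))
        # 7x7 Gaussian kernel with dilation: smooths over a wide neighbourhood
        self.second_conv = nn.Conv2d(
            1, 1, kernel_size=self.kernel_size, stride=1, dilation=self.dilation, bias=False)
        self.second_conv.weight = torch.nn.Parameter(gaussian_kernel)
```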
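Neither conv layer pads its input itself. The forward pass pads beforehand with ReplicationPad2d. The arithmetic below shows why the chosen padding amounts (1 for the 3×3 conv, dilation * 3 for the dilated 7×7 conv) keep the spatial size unchanged. It uses the standard Conv2d output-size formula. The helper name conv_out_size is hypothetical, and its padding argument stands in for the external replication padding, which has the same effect on size.

```python
def conv_out_size(n, kernel_size, dilation=1, padding=0, stride=1):
    """Output length along one axis for a Conv2d-style layer."""
    effective = dilation * (kernel_size - 1) + 1  # span covered by the dilated kernel
    return (n + 2 * padding - effective) // stride + 1


H = 64
print(6 * (7 - 1) + 1)                                  # 37: span of a 7x7 kernel with dilation 6
print(conv_out_size(H, 7, dilation=6))                  # 28: the map shrinks without padding
print(conv_out_size(H, 7, dilation=6, padding=6 * 3))   # 64: pad = dilation * (7 - 1) / 2
print(conv_out_size(H, 3, padding=1))                   # 64: the 3x3 sum needs a pad of 1
```

For a 7×7 kernel, a padding of dilation * 3 is exactly half of the dilated span minus one. So the output has the same H × W as the input for any dilation.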
Forward pass
```python
    def forward(self, x, prediction=None):
        if len(x.shape) == 3:
            x = x.unsqueeze(1)  # 3-D input (B x H x W): add a channel dimension
        x_size = x.size()
        # B x 1 x H x W
        assert len(x.shape) == 4
        out = x
```
Branch 1: with boundary suppression
```python
        if self.boundary_suppression:
            # obtain the boundary map of width 2 by default
            # this can be calculated by the difference of dilation and erosion
            boundaries = find_boundaries(prediction.unsqueeze(1))  # prediction: B x H x W label map
            expanded_boundaries = None
            if self.boundary_iteration != 0:
                assert self.boundary_width % self.boundary_iteration == 0  # width must be divisible by the number of iterations
                diff = self.boundary_width // self.boundary_iteration  # how much the expansion width shrinks per iteration
```
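The loop in forward() sets the expansion width at iteration i to boundary_width - diff * i - 1. The last iteration, or a boundary width of 0, uses 0 instead. This plain-Python sketch lists the resulting schedule. The function name expansion_schedule is hypothetical.

```python
def expansion_schedule(boundary_width=4, boundary_iteration=4):
    """Expansion width used in each iteration of the boundary-suppression loop."""
    assert boundary_width % boundary_iteration == 0
    diff = boundary_width // boundary_iteration
    widths = []
    for iteration in range(boundary_iteration):
        if boundary_width == 0 or iteration == boundary_iteration - 1:
            widths.append(0)  # last iteration: use the raw width-2 boundary
        else:
            widths.append(boundary_width - diff * iteration - 1)
    return widths


print(expansion_schedule())      # [3, 2, 1, 0]
print(expansion_schedule(8, 4))  # [7, 5, 3, 0]
```

The suppressed band therefore narrows each iteration. Pixels refilled in one pass lie outside the next, narrower band, so they become non-boundary sources for the pixels further inside.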
Main loop of boundary suppression
```python
            for iteration in range(self.boundary_iteration):
                if len(out.shape) != 4:
                    out = out.unsqueeze(1)
                prev_out = out
```
Get the boundaries
```python
                # if it is the last iteration or boundary width is zero, do not expand
                if self.boundary_width == 0 or iteration == self.boundary_iteration - 1:
                    expansion_width = 0
                # otherwise reduce the expansion width at each iteration
                else:
                    expansion_width = self.boundary_width - diff * iteration - 1

                # expand the boundary obtained from the prediction (width of 2) by expansion rate
                # (expand_boundaries is explained below)
                expanded_boundaries = expand_boundaries(boundaries, r=expansion_width)
```
Invert the boundaries -> obtain the non-boundary mask
```python
                # invert it so that we can obtain non-boundary mask (non-boundary = 1, boundary = 0)
                non_boundary_mask = 1. * (expanded_boundaries == 0)
```
Set the boundary regions to 0
```python
                f_size = 1
                num_pad = f_size

                # making boundary regions to 0: input * non-boundary mask keeps only non-boundary values
                x_masked = out * non_boundary_mask
                x_padded = nn.ReplicationPad2d(num_pad)(x_masked)

                non_boundary_mask_padded = nn.ReplicationPad2d(num_pad)(non_boundary_mask)
```
class torch.nn.ReplicationPad2d(padding) pads the input by replicating its border values.
- padding (int or tuple): the size of the padding. If it is an int, the same padding is used on all borders.
- If it is a 4-tuple, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).

Sum up the values in the receptive field
```python
                # sum up the values in the receptive field
                y = self.first_conv(x_padded)
                # count non-boundary elements in the receptive field
                num_calced_elements = self.first_conv(non_boundary_mask_padded)
                num_calced_elements = num_calced_elements.long()
```
Take the average
```python
                # take an average by dividing y by count
                # if there is no non-boundary element in the receptive field,
                # keep the original value
                avg_y = torch.where((num_calced_elements == 0), prev_out, y / num_calced_elements)
                out = avg_y
```
Update the boundary pixels only
```python
                # update boundaries only
                out = torch.where((non_boundary_mask == 0), out, prev_out)
                del expanded_boundaries, non_boundary_mask
```
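To make one pass of this loop body concrete, here is a plain-Python sketch on a single-channel 2-D list. Batch and channel dimensions are dropped, and the mask is assumed to hold 0/1 values. The sketch performs each step of the loop in order:
- replication padding by 1;
- an all-ones 3×3 window sum over the masked values and over the mask;
- division by the count, falling back to the previous value when the count is 0;
- an update restricted to boundary pixels.

All helper names are hypothetical and do not come from the SML code.

```python
def replication_pad(img, p):
    """Pad a 2-D list by repeating its border values, like nn.ReplicationPad2d(p)."""
    h, w = len(img), len(img[0])
    return [[img[min(max(i - p, 0), h - 1)][min(max(j - p, 0), w - 1)]
             for j in range(w + 2 * p)]
            for i in range(h + 2 * p)]


def box_sum3(img):
    """'Valid' 3x3 convolution with an all-ones kernel (what first_conv computes)."""
    h, w = len(img), len(img[0])
    return [[sum(img[i + a][j + b] for a in range(3) for b in range(3))
             for j in range(w - 2)]
            for i in range(h - 2)]


def suppress_once(out, non_boundary_mask):
    """One pass of the loop body on a single-channel 2-D map."""
    h, w = len(out), len(out[0])
    prev_out = out
    # making boundary regions to 0
    x_masked = [[out[i][j] * non_boundary_mask[i][j] for j in range(w)] for i in range(h)]
    # sum of non-boundary values and number of non-boundary pixels per 3x3 window
    y = box_sum3(replication_pad(x_masked, 1))
    count = box_sum3(replication_pad(non_boundary_mask, 1))
    # average; keep the previous value where the window has no non-boundary pixel
    avg_y = [[prev_out[i][j] if int(count[i][j]) == 0 else y[i][j] / int(count[i][j])
              for j in range(w)] for i in range(h)]
    # update boundary pixels only
    return [[avg_y[i][j] if non_boundary_mask[i][j] == 0 else prev_out[i][j]
             for j in range(w)] for i in range(h)]


out = [[1.0, 1.0, 9.0, 5.0] for _ in range(3)]
mask = [[1, 1, 0, 1] for _ in range(3)]  # column 2 is the boundary band
print(suppress_once(out, mask))
# [[1.0, 1.0, 3.0, 5.0], [1.0, 1.0, 3.0, 5.0], [1.0, 1.0, 3.0, 5.0]]
```

In this example, the boundary value 9.0 is replaced by the mean of its non-boundary neighbours, which are 1.0 and 5.0. Non-boundary pixels keep their original values.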
Stage 2: dilated smoothing
```python
            # second stage; apply dilated smoothing
            if self.dilated_smoothing == True:
                # pad by dilation * (7 - 1) / 2 so the dilated 7x7 conv keeps H x W
                out = nn.ReplicationPad2d(self.dilation * 3)(out)
                out = self.second_conv(out)
            return out.squeeze(1)
```
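The second stage can be sketched in plain Python as follows:
1. Replication-pad the map by dilation * 3.
2. Cross-correlate it with the normalized 7×7 Gaussian, sampling taps dilation pixels apart. This is the operation nn.Conv2d computes.

The output keeps the input size. Because the weights sum to 1, a constant map is left unchanged. The helper names are hypothetical, and the example uses dilation = 2 only to keep the toy image small.

```python
import math


def gaussian(size=7, sigma=1.0):
    """Normalized size x size Gaussian kernel (the 1/(2*pi*sigma^2) factor cancels)."""
    c = (size - 1) / 2
    k = [[math.exp(-((x - c) ** 2 + (y - c) ** 2) / (2 * sigma ** 2)) for y in range(size)]
         for x in range(size)]
    total = sum(map(sum, k))
    return [[v / total for v in row] for row in k]


def replication_pad(img, p):
    h, w = len(img), len(img[0])
    return [[img[min(max(i - p, 0), h - 1)][min(max(j - p, 0), w - 1)]
             for j in range(w + 2 * p)] for i in range(h + 2 * p)]


def dilated_conv(img, kernel, dilation):
    """'Valid' cross-correlation with a dilated kernel, stride 1, no padding."""
    k = len(kernel)
    span = dilation * (k - 1)  # distance between the first and last tap
    h, w = len(img) - span, len(img[0]) - span
    return [[sum(kernel[a][b] * img[i + a * dilation][j + b * dilation]
                 for a in range(k) for b in range(k))
             for j in range(w)] for i in range(h)]


def dilated_smoothing(img, dilation=6, kernel_size=7):
    pad = dilation * (kernel_size - 1) // 2  # = dilation * 3 when kernel_size = 7
    return dilated_conv(replication_pad(img, pad), gaussian(kernel_size), dilation)


img = [[float(4 * i + j) for j in range(4)] for i in range(4)]
smoothed = dilated_smoothing(img, dilation=2)
print(len(smoothed), len(smoothed[0]))  # 4 4
```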
Branch 2: without boundary suppression
```python
        else:
            if self.dilated_smoothing == True:  # dilated smoothing only
                out = nn.ReplicationPad2d(self.dilation * 3)(out)
                out = self.second_conv(out)
            else:
                out = x

        return out.squeeze(1)
```
find_boundaries
```python
def find_boundaries(label):
    """ Calculate boundary mask by getting diff of dilated and eroded prediction maps """
    assert len(label.shape) == 4
    boundaries = (dilation(label.float(), selem_dilation) != erosion(label.float(), selem)).float()
    ### save_image(boundaries, f'boundaries_{boundaries.float().mean():.2f}.png', normalize=True)
    return boundaries
```
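The two structuring elements it uses are defined at module level:
```python
import torch
from scipy import ndimage as ndi
selem = torch.ones((3, 3)).cuda()  # 3x3 all-ones structuring element, used for erosion
selem_dilation = torch.FloatTensor(ndi.generate_binary_structure(2, 1)).cuda()  # 3x3 cross-shaped structuring element, used for dilation
```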
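Dilation & erosion
These are the two basic morphological operations. They are mainly used to find the local-maximum and local-minimum regions of an image.
- Dilation works like "region growing": it expands the bright (white) regions of the image, so the result has larger bright regions than the original.
- Erosion works like "the region being eaten away": it shrinks and thins the bright (white) regions, so the result has smaller bright regions than the original.
Procedure: slide a structuring element (kernel) over the image. On a binary image, dilation takes the OR over the neighbourhood, enlarging the area of 1s. Erosion takes the AND, reducing the number of 1s. On grayscale or label images, they become the local maximum and the local minimum.
```python
dilation(image, kernel)  # image, structuring element
erosion(image, kernel)
```
After the label map is dilated and eroded separately, the positions where the two results differ are the boundaries, marked with 1.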
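The plain-Python sketch below illustrates find_boundaries on a tiny label map. It treats dilation as a local maximum and erosion as a local minimum over the neighbours selected by the structuring element. It assumes that neighbours outside the image are simply ignored. The post does not show where dilation and erosion are imported from, so this is not their actual implementation. It only illustrates the same semantics.

```python
CROSS = [[0, 1, 0],
         [1, 1, 1],
         [0, 1, 0]]  # like ndi.generate_binary_structure(2, 1)
SQUARE = [[1, 1, 1],
          [1, 1, 1],
          [1, 1, 1]]  # like torch.ones((3, 3))


def _morph(img, selem, reduce_fn):
    """Apply reduce_fn (max or min) over the neighbours selected by selem.

    Neighbours outside the image are ignored (a border-handling assumption).
    """
    h, w = len(img), len(img[0])
    r = len(selem) // 2
    result = []
    for i in range(h):
        row = []
        for j in range(w):
            vals = [img[i + a - r][j + b - r]
                    for a in range(len(selem)) for b in range(len(selem))
                    if selem[a][b] and 0 <= i + a - r < h and 0 <= j + b - r < w]
            row.append(reduce_fn(vals))
        result.append(row)
    return result


def dilation(img, selem):
    return _morph(img, selem, max)  # local maximum (OR for 0/1 images)


def erosion(img, selem):
    return _morph(img, selem, min)  # local minimum (AND for 0/1 images)


def find_boundaries(label):
    """1 where the dilated and eroded label maps differ, 0 elsewhere."""
    dil, ero = dilation(label, CROSS), erosion(label, SQUARE)
    return [[int(a != b) for a, b in zip(dr, er)] for dr, er in zip(dil, ero)]


label = [[0, 0, 0, 2, 2, 2] for _ in range(4)]  # class 0 on the left, class 2 on the right
for row in find_boundaries(label):
    print(row)  # [0, 0, 1, 1, 0, 0] for every row
```

The marked band is two pixels wide, one on each side of the class change. This matches the "boundary map of width 2" comment in forward().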
expand_boundaries
```python
def expand_boundaries(boundaries, r=0):
    """ Expand boundary maps with the rate of r """
    if r == 0:
        return boundaries
    expanded_boundaries = dilation(boundaries, d_ks[r])  # dilate with the radius-r kernel (d_ks defines r = 1..9)
    ### save_image(expanded_boundaries, f'expanded_boundaries_{r}_{boundaries.float().mean():.2f}.png', normalize=True)
    return expanded_boundaries
```
About d_ks[]:
```python
# one zero-initialized (1, 1, 2k+1, 2k+1) tensor per radius k = 1..9
d_k1 = torch.zeros((1, 1, 2 * 1 + 1, 2 * 1 + 1)).cuda()
d_k2 = torch.zeros((1, 1, 2 * 2 + 1, 2 * 2 + 1)).cuda()
d_k3 = torch.zeros((1, 1, 2 * 3 + 1, 2 * 3 + 1)).cuda()
d_k4 = torch.zeros((1, 1, 2 * 4 + 1, 2 * 4 + 1)).cuda()
d_k5 = torch.zeros((1, 1, 2 * 5 + 1, 2 * 5 + 1)).cuda()
d_k6 = torch.zeros((1, 1, 2 * 6 + 1, 2 * 6 + 1)).cuda()
d_k7 = torch.zeros((1, 1, 2 * 7 + 1, 2 * 7 + 1)).cuda()
d_k8 = torch.zeros((1, 1, 2 * 8 + 1, 2 * 8 + 1)).cuda()
d_k9 = torch.zeros((1, 1, 2 * 9 + 1, 2 * 9 + 1)).cuda()

d_ks = {
    1: d_k1, 2: d_k2, 3: d_k3, 4: d_k4,
    5: d_k5, 6: d_k6, 7: d_k7, 8: d_k8, 9: d_k9}

for k, v in d_ks.items():
    v[:, :, k, k] = 1  # a single 1 at the centre of the (2k+1)x(2k+1) grid
    for i in range(k):
        v = dilation(v, selem_dilation)  # grow it by one pixel with the cross element
    d_ks[k] = v.squeeze(0).squeeze(0)  # store as a 2-D (2k+1)x(2k+1) kernel
    print(f'dilation kernel at {k}:\n\n{d_ks[k]}')
```
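These kernels are diamond-shaped. d_ks[k] is a (2k+1)×(2k+1) matrix whose 1s are the pixels within Manhattan distance k of the center. Larger k values follow the same pattern.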
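The plain-Python sketch below mirrors how d_ks[k] is built: place a single 1 at the center, then dilate it k times with the 3×3 cross. It checks that the result is the diamond described above. As a consequence, expand_boundaries(boundaries, r) grows the boundary band by r pixels in every direction, measured in Manhattan (L1) distance. The helper names are hypothetical.

```python
CROSS = [[0, 1, 0],
         [1, 1, 1],
         [0, 1, 0]]  # like ndi.generate_binary_structure(2, 1)


def dilate_binary(img, selem=CROSS):
    """Binary dilation of a 0/1 grid with a 3x3 structuring element."""
    h, w = len(img), len(img[0])
    return [[int(any(img[i + a - 1][j + b - 1]
                     for a in range(3) for b in range(3)
                     if selem[a][b] and 0 <= i + a - 1 < h and 0 <= j + b - 1 < w))
             for j in range(w)]
            for i in range(h)]


def make_d_k(k):
    """Mimic how d_ks[k] is built: a single 1 at the centre, dilated k times."""
    n = 2 * k + 1
    v = [[0] * n for _ in range(n)]
    v[k][k] = 1
    for _ in range(k):
        v = dilate_binary(v)
    return v


for row in make_d_k(2):
    print(' '.join(map(str, row)))
# 0 0 1 0 0
# 0 1 1 1 0
# 1 1 1 1 1
# 0 1 1 1 0
# 0 0 1 0 0
```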