当前位置:网站首页>DRConv-pytorch改成输出和输入一样的尺寸
DRConv-pytorch改成输出和输入一样的尺寸
2022-07-27 06:11:00 【zouxiaolv】
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.autograd import Variable, Function
def _region_one_hot(guide_feature):
    """Return a one-hot mask along dim 1 marking the argmax of ``guide_feature``.

    The result has an extra singleton dimension inserted at position 2 so it
    broadcasts against a tensor that carries a channel axis there.
    """
    winners = guide_feature.argmax(dim=1, keepdim=True)
    mask = torch.zeros_like(guide_feature).scatter_(1, winners, 1)
    return mask.unsqueeze(2)


class asign_index(torch.autograd.Function):
    """Pick, per location, the slice of ``kernel`` chosen by ``guide_feature``.

    Forward:
        ``kernel`` is indexed along dim 1 by the argmax of ``guide_feature``
        along dim 1 (implemented as multiply-by-one-hot and sum over dim 1).
        As used by ``DRConv2d`` in this file, ``kernel`` is
        ``B x R x C x H x W`` and ``guide_feature`` is ``B x R x H x W``,
        giving a ``B x C x H x W`` result.

    Backward:
        * ``kernel`` receives ``grad_output`` only in the selected slice.
        * ``guide_feature`` receives the gradient of a softmax-weighted
          selection (softmax over dim 1) instead of the hard argmax, which
          has no useful gradient of its own.
    """

    @staticmethod
    def forward(ctx, kernel, guide_feature):
        ctx.save_for_backward(kernel, guide_feature)
        guide_mask = _region_one_hot(guide_feature)
        return torch.sum(kernel * guide_mask, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        kernel, guide_feature = ctx.saved_tensors
        guide_mask = _region_one_hot(guide_feature)

        # Broadcasting already allocates fresh tensors, so no clone is needed.
        grad_out = grad_output.unsqueeze(1)
        grad_kernel = grad_out * guide_mask

        # Collapse the channel axis to get one score gradient per region.
        grad_guide = (grad_out * kernel).sum(dim=2)
        softmax = F.softmax(guide_feature, 1)
        # Jacobian-vector product of softmax: s * (g - sum(s * g)).
        grad_guide = softmax * (grad_guide - (softmax * grad_guide).sum(dim=1, keepdim=True))
        return grad_kernel, grad_guide
def xcorr_slow(x, kernel, kwargs=None):
    """Cross-correlate each sample with its own kernel using a Python loop.

    Args:
        x: input of shape ``B x C x H x W``.
        kernel: per-sample kernels of shape ``B x K x kh x kw``; each sample's
            kernel is reinterpreted as ``(K // C) x C x kh x kw`` filters.
        kwargs: optional dict of extra keyword arguments for ``F.conv2d``.
            ``stride`` and ``padding`` default to 1 and may be overridden
            here (previously supplying either raised a duplicate-keyword
            ``TypeError``).

    Returns:
        Tensor of shape ``B x (K // C) x H' x W'``, the per-sample results
        concatenated along the batch dimension.
    """
    conv_args = {"stride": 1, "padding": 1}
    conv_args.update(kwargs or {})

    in_channels = x.size(1)
    kh, kw = kernel.size(2), kernel.size(3)
    responses = [
        # x[i:i + 1] keeps the batch axis so conv2d sees a 1 x C x H x W input.
        F.conv2d(x[i:i + 1], kernel[i].view(-1, in_channels, kh, kw), **conv_args)
        for i in range(x.size(0))
    ]
    return torch.cat(responses, 0)
def xcorr_fast(x, kernel, kwargs=None):
    """Cross-correlate each sample with its own kernel in one grouped conv2d.

    The batch is folded into the channel axis and ``groups`` is set to the
    batch size, so every sample is convolved only with its own filters.

    Args:
        x: input of shape ``B x C x H x W``.
        kernel: per-sample kernels of shape ``B x K x kh x kw``; reinterpreted
            as ``(B * K // C) x C x kh x kw`` filters.
        kwargs: optional dict of extra keyword arguments for ``F.conv2d``.
            ``stride`` and ``padding`` default to 1 and may be overridden
            here (previously supplying either raised a duplicate-keyword
            ``TypeError``).

    Returns:
        Tensor of shape ``B x (K // C) x H' x W'``.
    """
    conv_args = {"stride": 1, "padding": 1}
    conv_args.update(kwargs or {})

    batch = kernel.size(0)
    # reshape (rather than view) also accepts non-contiguous inputs.
    pk = kernel.reshape(-1, x.size(1), kernel.size(2), kernel.size(3))
    px = x.reshape(1, -1, x.size(2), x.size(3))
    po = F.conv2d(px, pk, groups=batch, **conv_args)
    return po.view(batch, -1, po.size(2), po.size(3))
class Corr(Function):
    """Grouped cross-correlation wrapped as an autograd ``Function``.

    Only ``symbolic`` and ``forward`` are defined; this class provides no
    ``backward``.
    """

    # NOTE(review): symbolic takes one argument fewer than forward (there is
    # no ``kwargs``) -- confirm this matches how export invokes it.
    @staticmethod
    def symbolic(g, x, kernel, groups):
        return g.op("Corr", x, kernel, groups_i=groups)

    @staticmethod
    def forward(ctx, x, kernel, groups, kwargs=None):
        """Group conv2d to calculate cross correlation.

        Args:
            ctx: autograd context object (unused).
            x: input of shape ``B x C x H x W``.
            kernel: per-sample kernels of shape ``B x K x kh x kw``.
            groups: number of channel groups per sample; the convolution
                runs with ``groups * B`` groups in total.
            kwargs: optional dict of extra keyword arguments for
                ``F.conv2d``. ``stride`` and ``padding`` default to 1 and may
                be overridden here (previously supplying either raised a
                duplicate-keyword ``TypeError``).

        Returns:
            Tensor whose first dimension is ``B``; the remaining output
            channels are split evenly across samples.
        """
        conv_args = {"stride": 1, "padding": 1}
        conv_args.update(kwargs or {})

        batch, channel = x.size(0), x.size(1)
        # Fold the batch into the channel axis so one grouped conv handles all samples.
        x = x.view(1, -1, x.size(2), x.size(3))
        kernel = kernel.view(-1, channel // groups, kernel.size(2), kernel.size(3))
        out = F.conv2d(x, kernel, groups=groups * batch, **conv_args)
        return out.view(batch, -1, out.size(2), out.size(3))
class Correlation(nn.Module):
    """Module front-end that dispatches to one of the correlation routines.

    In training mode it calls ``xcorr_slow`` or ``xcorr_fast`` depending on
    ``use_slow``; otherwise it calls ``Corr.apply`` with ``groups=1``.
    """

    # Class-wide default, used when the constructor receives ``use_slow=None``.
    use_slow = True

    def __init__(self, use_slow=None):
        super().__init__()
        self.use_slow = Correlation.use_slow if use_slow is None else use_slow

    def extra_repr(self):
        """Name of the training-mode routine, shown in the module repr."""
        return "xcorr_slow" if self.use_slow else "xcorr_fast"

    def forward(self, x, kernel, **kwargs):
        """Correlate ``x`` with ``kernel``; ``kwargs`` is forwarded as a dict."""
        if not self.training:
            return Corr.apply(x, kernel, 1, kwargs)
        correlate = xcorr_slow if self.use_slow else xcorr_fast
        return correlate(x, kernel, kwargs)
class DRConv2d(nn.Module):
    """Dynamic region-aware convolution.

    A kernel-generating branch produces ``region_num`` filter banks from the
    input itself, a guide branch scores every spatial location for each
    region, and each location's output is taken from the filter bank of its
    highest-scoring region.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the output.
        kernel_size: spatial size of the generated kernels and of the guide
            convolution. Padding is fixed at 1 with stride 1, so the output
            keeps the input's spatial size when ``kernel_size`` is 3.
        region_num: number of regions (filter banks).
        **kwargs: forwarded both to the guide ``nn.Conv2d`` and to the
            correlation call.
    """

    def __init__(self, in_channels, out_channels, kernel_size, region_num=8, **kwargs):
        super(DRConv2d, self).__init__()
        self.region_num = region_num

        # Kernel generator: pool the input down to kernel_size x kernel_size,
        # then two 1x1 convs emit region_num * in * out channels per sample.
        self.conv_kernel = nn.Sequential(
            nn.AdaptiveAvgPool2d((kernel_size, kernel_size)),
            nn.Conv2d(in_channels, region_num * region_num, kernel_size=1),
            nn.Sigmoid(),
            nn.Conv2d(region_num * region_num, region_num * in_channels * out_channels, kernel_size=1, groups=region_num),
        )
        # Guide branch: one score map per region.
        self.conv_guide = nn.Conv2d(in_channels, region_num, kernel_size=kernel_size, stride=1, padding=1, **kwargs)
        self.corr = Correlation(use_slow=False)
        self.kwargs = kwargs
        self.asign_index = asign_index.apply

    def forward(self, input):
        """Apply the region-aware convolution to ``input`` (``B x in x H x W``)."""
        kernel = self.conv_kernel(input)
        kernel = kernel.view(kernel.size(0), -1, kernel.size(2), kernel.size(3))  # B x (r*in*out) x k x k
        output = self.corr(input, kernel, **self.kwargs)  # B x (r*out) x H x W
        output = output.view(output.size(0), self.region_num, -1, output.size(2), output.size(3))  # B x r x out x H x W
        guide_feature = self.conv_guide(input)  # B x r x H x W
        return self.asign_index(output, guide_feature)  # B x out x H x W
if __name__ == '__main__':
    # Demo: run DRConv2d once, then compare its MACs and parameter count
    # against a plain 3x3 convolution with the same channel sizes.
    B = 16
    in_channels = 16
    out_channels = 32
    size = (64, 32)

    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # demo still runs instead of failing inside .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    conv = DRConv2d(in_channels, out_channels, kernel_size=3, region_num=8).to(device)
    conv.train()
    x = torch.ones(B, in_channels, size[0], size[1], device=device)
    output = conv(x)
    print(x.shape, output.shape)

    # flops, params
    from thop import profile
    from thop import clever_format  # imported by the original script; currently unused

    class Conv2d(nn.Module):
        """Plain 3x3 convolution used as the baseline for the comparison."""

        def __init__(self):
            super(Conv2d, self).__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        def forward(self, input):
            return self.conv(input)

    conv2 = Conv2d().to(device)
    conv2.train()
    macs2, params2 = profile(conv2, inputs=(x, ))
    macs, params = profile(conv, inputs=(x, ))
    print(macs2, params2)
    print(macs, params)
# NOTE: the scraped page's sidebar heading ("边栏推荐", i.e. "sidebar
# recommendations") was fused onto the final print line; it is not code.
- “蔚来杯“2022牛客暑期多校训练营1
- Tableau prep is connected to maxcompute and only writes simple SQL. Why is this error reported?
- MySQL index failure and solution practice
- 使用sqlplus显示中文为乱码的解决办法
- 【golang学习笔记2.1】 golang中的数组中的排序和查找
- Jmeter: interface automation test - BeanShell compares database data and return data
- LogCat工具
- jjwt 生成token
- Gbase 8C - SQL reference 6 SQL syntax (11)
- 一个优先级顺序的SQL问题
猜你喜欢

在mac中使用docker来搭建oracle数据库服务器

使用pip命令切换不同的镜像源

Watermelon book learning notes - Chapter 4 decision tree

C# 常用功能整合-3

jjwt 生成token

MySQL2

tigervnc的使用

C4D云渲染平台选哪家合作?

Digital image processing -- Chapter 3 gray scale transformation and spatial filtering

Visual horizontal topic bug1:filenotfounderror: could not find module 'mvcameracontrol dll‘ (or one of it
随机推荐
LogCat工具
Digital image processing - Chapter 6 color image processing
使用反射实现动态修改@Excel的注解属性
如何取得对象的DDL信息
Pytorch model
Oracle数据库问题
MySQL: 提高最大连接数
How does golang assign values to empty structures
请问有人使用oracle xstream 时出现个别capture延迟很大的吗,该如何解决延迟问题呢
Sort increment with typescript
2021 interview question of php+go for Zhongda factory (1)
Vscode creates golang development environment and debug unit test of golang
高级IO提纲
PHP defines the array using commas,
Internal class -- just read this article~
How MySQL executes query statements
Gbase 8C - SQL reference 6 SQL syntax (13)
Jmeter:接口自动化测试-BeanShell对数据库数据和返回数据比较
Watermelon book learning notes - Chapter 1 and 2
Jmeter: interface automation test - BeanShell compares database data and return data