当前位置:网站首页>Mixconv code
Mixconv code
2022-07-01 22:42:00 【zouxiaolv】
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def split_layer(total_channels, num_groups):
    """Split ``total_channels`` into ``num_groups`` positive chunk sizes.

    The first ``num_groups - 1`` chunks each get ``ceil(total_channels / num_groups)``
    channels and the last chunk takes whatever remains, e.g. ``(10, 4) -> [3, 3, 3, 1]``.
    When that layout would leave the last chunk empty or negative (e.g. ``(6, 4)``
    would give ``[2, 2, 2, 0]``), the channels are instead spread as evenly as
    possible, larger chunks first: ``(6, 4) -> [2, 2, 1, 1]``.

    Args:
        total_channels: Number of channels to split.
        num_groups: Number of chunks to produce.

    Returns:
        A list of ``num_groups`` positive ints that sums to ``total_channels``.

    Raises:
        ValueError: If ``num_groups < 1`` or ``total_channels < num_groups``
            (no split into non-empty chunks exists).
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")
    if total_channels < num_groups:
        raise ValueError(
            f"cannot split {total_channels} channels into {num_groups} non-empty groups"
        )

    # Integer ceiling division; exact for any int, unlike float-based np.ceil.
    chunk = -(-total_channels // num_groups)
    split = [chunk] * num_groups
    split[-1] = total_channels - chunk * (num_groups - 1)
    if split[-1] > 0:
        return split

    # The ceil layout over-allocated the leading chunks; balance them instead.
    base, extra = divmod(total_channels, num_groups)
    return [base + 1] * extra + [base] * (num_groups - extra)
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution: each input channel is filtered independently.

    Args:
        in_channels: Number of channels; the output has the same count.
        kernal_size: Square kernel size. The misspelled name is part of the
            public signature and is kept so existing keyword callers still work.
        stride: Convolution stride.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, kernal_size, stride, bias=False):
        super().__init__()
        # (k - 1) // 2 preserves the spatial size for odd k at stride 1.
        pad = (kernal_size - 1) // 2
        self.depthwise_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernal_size,
            stride=stride,
            padding=pad,
            groups=in_channels,  # groups == channels -> one filter per channel
            bias=bias,
        )

    def forward(self, x):
        return self.depthwise_conv(x)
class GroupConv2D(nn.Module):
    """Grouped 2-D convolution whose groups may have uneven sizes.

    Input and output channels are each partitioned with ``split_layer`` into
    ``n_chunks`` pieces. Input piece ``i`` is convolved by its own ``nn.Conv2d``
    into output piece ``i``, and the results are concatenated along dim 1.
    With ``n_chunks == 1`` this is a single ordinary convolution. No padding is
    applied, so ``kernel_size > 1`` shrinks the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        self.split_in_channels = split_layer(in_channels, n_chunks)
        out_splits = split_layer(out_channels, n_chunks)

        if n_chunks == 1:
            # One chunk: a plain convolution over all channels.
            self.group_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias)
            return

        # One independent convolution per (input piece, output piece) pair.
        self.group_layers = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=kernel_size, bias=bias)
            for c_in, c_out in zip(self.split_in_channels, out_splits)
        )

    def forward(self, x):
        if self.n_chunks == 1:
            return self.group_conv(x)
        pieces = torch.split(x, self.split_in_channels, dim=1)
        outputs = [conv(piece) for conv, piece in zip(self.group_layers, pieces)]
        return torch.cat(outputs, dim=1)
class MDConv(nn.Module):
    """Mixed depthwise convolution (MixConv).

    The channels are partitioned with ``split_layer`` into ``n_chunks`` groups.
    Group ``i`` goes through a ``DepthwiseConv2D`` with kernel size ``2 * i + 3``
    (3, 5, 7, ...), and the group outputs are concatenated back along dim 1.
    Every kernel size is odd and padded by ``(k - 1) // 2``, so all groups yield
    the same spatial size for a given ``stride`` and the concatenation is well
    formed.

    Args:
        out_channels: Number of channels. Depthwise convolutions keep the channel
            count, so this is both the expected input channel count and the
            output channel count.
        n_chunks: Number of channel groups, i.e. number of distinct kernel sizes.
        stride: Stride shared by every depthwise convolution.
        bias: Whether each depthwise convolution has a bias term.
    """

    def __init__(self, out_channels, n_chunks, stride=1, bias=False):
        super(MDConv, self).__init__()
        self.n_chunks = n_chunks
        self.split_out_channels = split_layer(out_channels, n_chunks)
        self.layers = nn.ModuleList()
        for idx in range(self.n_chunks):
            kernel_size = 2 * idx + 3  # 3, 5, 7, ... for successive groups
            self.layers.append(
                DepthwiseConv2D(
                    self.split_out_channels[idx],
                    kernal_size=kernel_size,
                    stride=stride,
                    bias=bias,
                )
            )

    def forward(self, x):
        """Apply each group's depthwise kernel and concatenate on dim 1.

        ``x`` is split on dim 1, which must have exactly ``sum(split_out_channels)``
        entries (``torch.split`` raises otherwise); presumably an (N, C, H, W) batch.
        """
        split = torch.split(x, self.split_out_channels, dim=1)
        return torch.cat([layer(s) for layer, s in zip(self.layers, split)], dim=1)
# Explanation: MixConv first splits the input channels into groups; each group is
# then convolved with a different kernel size using depthwise convolution, and
# the group outputs are concatenated.
if __name__ == "__main__":
    # Quick shape check on a batch of 16 three-channel 32x32 inputs.
    temp = torch.randn((16, 3, 32, 32))
    # Alternative demo for the grouped 1x1 convolution:
    # group = GroupConv2D(3, 16, n_chunks=2)
    # print(group(temp).size())
    group = MDConv(3, n_chunks=2)
    # Odd kernels with (k - 1) // 2 padding at stride 1 keep the shape: [16, 3, 32, 32].
    print(group(temp).size())
Drawback: as with Inception, the more groups there are, the more the speed drops; the design only focuses on reducing the parameter count.
边栏推荐
- MySQL stored procedure
- C#/VB.NET: add text/image watermarks to PDF documents
- 园区全光技术选型-中篇
- 【nn.Parameter】Pytorch特征融合自适应权重设置(可学习权重使用)
- C#/VB.NET 给PDF文档添加文本/图像水印
- 固定资产管理子系统报表分为什么大类,包括哪些科目
- [jetcache] how to use jetcache
- 好友新书发布,祝贺(送福利)
- Intelligent computing architecture design of Internet
- Why must digital transformation strategies include continuous testing?
猜你喜欢
随机推荐
Indicator trap: seven KPI mistakes that IT leaders are prone to make
园区全光技术选型-中篇
91.(cesium篇)cesium火箭发射模拟
Fully annotated SSM framework construction
[jetcache] how to use jetcache
小红书Scheme跳转到指定页面
QT 使用FFmpeg4将argb的Qimage转换成YUV422P
There is no signal in HDMI in computer games caused by memory, so it crashes
MySQL5.7 设置密码策略(等保三级密码改造)
LC669. 修剪二叉搜索树
MySQL的视图练习题
Resttemplate remote call tool class
【c语言】malloc函数详解[通俗易懂]
Basic knowledge of nginx
447-哔哩哔哩面经1
[ecological partner] Kunpeng system engineer training
Qt: using FFmpeg 4 to convert an ARGB QImage to YUV422P
Smart micro mm32 multi-channel adc-dma configuration
高攀不起的希尔排序,直接插入排序
完全注解的ssm框架搭建







