当前位置:网站首页>mixconv代码
mixconv代码
2022-07-01 21:47:00 【zouxiaolv】
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def split_layer(total_channels, num_groups):
    """Partition ``total_channels`` into ``num_groups`` positive chunk sizes.

    The preferred layout gives every group ``ceil(total_channels /
    num_groups)`` channels and lets the last group absorb the difference,
    e.g. ``split_layer(10, 4) == [3, 3, 3, 1]``. That layout is kept
    whenever it is valid so that layer shapes built from it do not change.

    Rounding up can over-allocate so much that the last group would end up
    with zero or a negative number of channels (e.g. 5 channels in 4
    groups). In that case the channels are spread as evenly as possible
    instead, with the larger chunks first (``[2, 1, 1, 1]``).

    Args:
        total_channels: Number of channels to distribute.
        num_groups: Number of chunks to produce.

    Returns:
        A list of ``num_groups`` positive ints that sums to
        ``total_channels``.

    Raises:
        ValueError: If ``num_groups`` is smaller than 1 or larger than
            ``total_channels`` (no split into positive chunks exists).
    """
    if num_groups < 1:
        raise ValueError("num_groups must be >= 1, got {}".format(num_groups))
    if total_channels < num_groups:
        raise ValueError(
            "cannot split {} channels into {} non-empty groups".format(
                total_channels, num_groups
            )
        )

    # Integer ceiling division; avoids a float round trip.
    chunk = -(-total_channels // num_groups)
    split = [chunk] * num_groups
    # The last group takes whatever is left after the rounded-up chunks.
    split[-1] += total_channels - chunk * num_groups

    if split[-1] < 1:
        # Rounding up left nothing (or less) for the last group:
        # fall back to a balanced split, larger chunks first.
        base, extra = divmod(total_channels, num_groups)
        split = [base + 1] * extra + [base] * (num_groups - extra)
    return split
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution that keeps the channel count unchanged.

    Wraps a single ``nn.Conv2d`` whose ``groups`` equals ``in_channels``,
    so every input channel is filtered independently of the others.

    Args:
        in_channels: Number of input (and output) channels.
        kernal_size: Kernel size passed to ``nn.Conv2d`` (the spelling is
            part of the public signature).
        stride: Convolution stride.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, kernal_size, stride, bias=False):
        super().__init__()
        self.depthwise_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernal_size,
            stride=stride,
            # Half-kernel padding on each side.
            padding=(kernal_size - 1) // 2,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution to ``x`` and return the result."""
        return self.depthwise_conv(x)
class GroupConv2D(nn.Module):
    """Convolution applied separately to consecutive channel chunks.

    With ``n_chunks == 1`` this is a single ``nn.Conv2d``. Otherwise the
    input and output channel counts are each divided with ``split_layer``
    and one ``nn.Conv2d`` is created per chunk; ``forward`` splits the
    input along dim 1, runs each chunk through its own convolution and
    concatenates the results along dim 1.

    Args:
        in_channels: Channel count of the input.
        out_channels: Total channel count of the output.
        kernel_size: Kernel size of every convolution (no padding is added).
        n_chunks: Number of channel chunks.
        bias: Whether the convolutions have a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        # Stored on the instance because forward() needs the sizes for torch.split.
        self.split_in_channels = split_layer(in_channels, n_chunks)
        split_out_channels = split_layer(out_channels, n_chunks)

        if n_chunks != 1:
            self.group_layers = nn.ModuleList()
            for chunk_in, chunk_out in zip(self.split_in_channels, split_out_channels):
                self.group_layers.append(
                    nn.Conv2d(chunk_in, chunk_out, kernel_size=kernel_size, bias=bias)
                )
        else:
            self.group_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size, bias=bias
            )

    def forward(self, x):
        """Convolve ``x`` chunk by chunk and concatenate along dim 1."""
        if self.n_chunks == 1:
            return self.group_conv(x)

        pieces = torch.split(x, self.split_in_channels, dim=1)
        outputs = [conv(piece) for conv, piece in zip(self.group_layers, pieces)]
        return torch.cat(outputs, dim=1)
class MDConv(nn.Module):
    """Mixed depthwise convolution: a different kernel size per channel chunk.

    The channels are divided into ``n_chunks`` chunks with ``split_layer``.
    Chunk ``idx`` is processed by a ``DepthwiseConv2D`` with kernel size
    ``2 * idx + 3`` (3, 5, 7, ...), and the chunk outputs are concatenated
    along dim 1. Because every branch is depthwise, the input passed to
    ``forward`` must have ``out_channels`` channels.

    Args:
        out_channels: Channel count of the input and of the output.
        n_chunks: Number of channel chunks / distinct kernel sizes.
        stride: Stride shared by all branches.
        bias: Whether the branch convolutions have a learnable bias.
    """

    def __init__(self, out_channels, n_chunks, stride=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        # Stored on the instance because forward() needs the sizes for torch.split.
        self.split_out_channels = split_layer(out_channels, n_chunks)

        self.layers = nn.ModuleList()
        for idx in range(self.n_chunks):
            # Kernel grows by 2 per chunk: 3x3, 5x5, 7x7, ...
            self.layers.append(
                DepthwiseConv2D(
                    self.split_out_channels[idx],
                    kernal_size=2 * idx + 3,
                    stride=stride,
                    bias=bias,
                )
            )

    def forward(self, x):
        """Run each channel chunk through its branch and concatenate on dim 1."""
        chunks = torch.split(x, self.split_out_channels, dim=1)
        outputs = [branch(chunk) for branch, chunk in zip(self.layers, chunks)]
        return torch.cat(outputs, dim=1)
# Demo input: a random batch with 3 channels at 32x32 resolution.
temp = torch.randn(16, 3, 32, 32)
# Alternative demo using GroupConv2D instead of MDConv:
# group = GroupConv2D(3, 16, n_chunks=2)
# print(group(temp).size())
# MDConv over 3 channels in 2 chunks (kernel sizes 3 and 5).
group = MDConv(out_channels=3, n_chunks=2)
print(group(temp).size())
讲解:先将输入按通道进行分组;分完组后,每一个组采用不同的卷积核大小,用深度卷积(depthwise convolution,即深度可分离卷积中的逐通道卷积部分)完成卷积过程。
缺点:类似 Inception,分组越多,速度下降越快;该设计只关注参数量的降低。
边栏推荐
- MySQL之MHA高可用配置及故障切换
- FFMpeg学习笔记
- Clean up system cache and free memory under Linux
- Communication between browser tab pages
- Basic knowledge of ngnix
- Sonic cloud real machine learning summary 6 - 1.4.1 server and agent deployment
- 2020-ViT ICLR
- Count the number of each character in the character
- 企业架构与项目管理的关联和区别
- 从零开始学 MySQL —数据库和数据表操作
猜你喜欢
随机推荐
MySQL的视图练习题
Redis configuration and optimization
Make a three digit number of all daffodils "recommended collection"
Is PMP certificate really useful?
多种智能指针
leetcode - 287. 寻找重复数
【c语言】malloc函数详解[通俗易懂]
Object memory layout
【MySQL】explain的基本使用以及各列的作用
Qtreeview+qabstractitemmodel custom model: the third of a series of tutorials [easy to understand]
牛客月赛-分组求对数和
Recent public ancestor (LCA) online practices
13th Blue Bridge Cup group B national tournament
【JetCache】JetCache的使用方法与步骤
Basic knowledge of ngnix
LC669. 修剪二叉搜索树
Clean up system cache and free memory under Linux
Delete AWS bound credit card account
Clean up system cache and free memory under Linux
【juc学习之路第9天】屏障衍生工具









![快乐数[环类问题之快慢指针]](/img/37/5c94b9b062a54067a50918f94e61ea.png)