当前位置:网站首页>mixconv代码
mixconv代码
2022-07-01 21:47:00 【zouxiaolv】
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def split_layer(total_channels, num_groups):
    """Partition ``total_channels`` into ``num_groups`` positive chunk sizes.

    The preferred layout gives every group except the last
    ``ceil(total_channels / num_groups)`` channels and lets the last group
    absorb the remainder.  When that would leave the last group with zero or
    a negative number of channels (e.g. 5 channels in 4 groups), the
    channels are instead spread as evenly as possible, with the larger
    chunks first.

    Args:
        total_channels: Integer number of channels to distribute.
        num_groups: Integer number of chunks to produce.

    Returns:
        A list of ``num_groups`` positive integers summing to
        ``total_channels``.

    Raises:
        ValueError: If ``num_groups`` is smaller than 1, or if there are
            fewer channels than groups (some group would be empty).
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")
    if total_channels < num_groups:
        raise ValueError(
            f"cannot split {total_channels} channels into {num_groups} "
            "non-empty groups"
        )

    # Integer ceiling division; avoids a float round-trip.
    chunk = -(-total_channels // num_groups)
    last = total_channels - chunk * (num_groups - 1)
    if last > 0:
        return [chunk] * (num_groups - 1) + [last]

    # The ceil-based layout over-allocates here, so fall back to a balanced
    # split: the first `extra` groups get one more channel than the rest.
    base, extra = divmod(total_channels, num_groups)
    return [base + 1] * extra + [base] * (num_groups - extra)
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution: one filter group per input channel.

    Wraps a single ``nn.Conv2d`` whose input and output channel counts are
    both ``in_channels`` and whose ``groups`` equals ``in_channels``.
    Padding is ``(kernal_size - 1) // 2``.

    Args:
        in_channels: Number of input (and output) channels.
        kernal_size: Kernel size forwarded to ``nn.Conv2d`` as
            ``kernel_size`` (parameter spelling kept for callers).
        stride: Stride forwarded to ``nn.Conv2d``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, kernal_size, stride, bias=False):
        super().__init__()
        half_kernel = (kernal_size - 1) // 2
        conv_options = dict(
            kernel_size=kernal_size,
            stride=stride,
            padding=half_kernel,
            groups=in_channels,  # one group per channel => depthwise
            bias=bias,
        )
        self.depthwise_conv = nn.Conv2d(in_channels, in_channels, **conv_options)

    def forward(self, x):
        """Apply the depthwise convolution to ``x`` and return the result."""
        return self.depthwise_conv(x)
class GroupConv2D(nn.Module):
    """Convolution applied independently to channel chunks of the input.

    With ``n_chunks == 1`` this is a single ``nn.Conv2d``.  Otherwise the
    input and output channel counts are each partitioned with
    ``split_layer`` and one ``nn.Conv2d`` is created per chunk; the forward
    pass splits the input along dim 1, runs each chunk through its own
    convolution and concatenates the results along dim 1.

    No ``padding`` argument is passed to the convolutions, so they use the
    ``nn.Conv2d`` default.

    Args:
        in_channels: Total number of input channels.
        out_channels: Total number of output channels.
        kernel_size: Kernel size of every convolution.
        n_chunks: Number of channel chunks.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        self.split_in_channels = split_layer(in_channels, n_chunks)
        split_out_channels = split_layer(out_channels, n_chunks)

        if n_chunks == 1:
            # Single chunk: a plain convolution over all channels.
            self.group_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size, bias=bias
            )
            return

        # One convolution per (input chunk, output chunk) pair, in order.
        self.group_layers = nn.ModuleList()
        for chunk_in, chunk_out in zip(self.split_in_channels, split_out_channels):
            self.group_layers.append(
                nn.Conv2d(chunk_in, chunk_out, kernel_size=kernel_size, bias=bias)
            )

    def forward(self, x):
        """Run ``x`` through the chunk-wise convolutions."""
        if self.n_chunks == 1:
            return self.group_conv(x)

        pieces = torch.split(x, self.split_in_channels, dim=1)
        results = [conv(piece) for conv, piece in zip(self.group_layers, pieces)]
        return torch.cat(results, dim=1)
class MDConv(nn.Module):
    """Mixed depthwise convolution with a different kernel size per chunk.

    The channels are partitioned with ``split_layer`` into ``n_chunks``
    chunks.  Chunk ``idx`` is processed by a ``DepthwiseConv2D`` with kernel
    size ``2 * idx + 3`` (3, 5, 7, ...), and the per-chunk outputs are
    concatenated along dim 1.

    Args:
        out_channels: Total number of channels; used to size the chunks the
            input is split into.
        n_chunks: Number of channel chunks (and of distinct kernel sizes).
        stride: Stride forwarded to every depthwise convolution.
        bias: Whether the depthwise convolutions have a bias term.
    """

    def __init__(self, out_channels, n_chunks, stride=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        self.split_out_channels = split_layer(out_channels, n_chunks)

        # Kernel size grows by 2 with each chunk index: 3, 5, 7, ...
        self.layers = nn.ModuleList([
            DepthwiseConv2D(channels, kernal_size=2 * idx + 3, stride=stride, bias=bias)
            for idx, channels in enumerate(self.split_out_channels)
        ])

    def forward(self, x):
        """Split ``x`` along dim 1, convolve each chunk, and concatenate."""
        chunks = torch.split(x, self.split_out_channels, dim=1)
        outputs = [layer(chunk) for layer, chunk in zip(self.layers, chunks)]
        return torch.cat(outputs, dim=1)
# Smoke test: push a random (16, 3, 32, 32) tensor through MDConv and print
# the output size.
temp = torch.randn(16, 3, 32, 32)
# Alternative check using GroupConv2D:
# group = GroupConv2D(3, 16, n_chunks=2)
# print(group(temp).size())
group = MDConv(3, n_chunks=2)
print(group(temp).shape)
讲解:先将输入按通道进行分组,分完组后,每一个组采用不同大小的卷积核,用深度(depthwise)卷积完成卷积过程,最后把各组的输出在通道维度上拼接起来。
缺点:类似 Inception,分组越多,速度下降越快;该设计只关注参数量的降低。
边栏推荐
- Redis配置与优化
- 为什么数字化转型战略必须包括持续测试?
- LIS (longest ascending subsequence) problem that can be understood [easy to understand]
- Count the number of each character in the character
- 性能测试计划怎么编写
- MySQL的视图练习题
- 互联网的智算架构设计
- PHP reflective XSS, reflective XSS test and repair
- Indicator trap: seven KPI mistakes that it leaders are prone to make
- CIO's discussion and Analysis on the definition of high-performance it team
猜你喜欢
Slope compensation
详解JMM
企业架构与项目管理的关联和区别
Spark interview questions
Easyexcel complex data export
性能测试计划怎么编写
CIO's discussion and Analysis on the definition of high-performance it team
The second anniversary of the three winged bird: the wings are getting richer and the take-off is just around the corner
Mysql——》Innodb存储引擎的索引
3DE 资源没东西或不对
随机推荐
MySQL之MHA高可用配置及故障切换
切面条 C语言
Four methods of JS array splicing [easy to understand]
100年仅6款产品获批,疫苗竞争背后的“佐剂”江湖
效率提升 - 鼓捣个性化容器开发环境
牛客月赛-分组求对数和
Count the number of each character in the character
Slope compensation
【juc学习之路第8天】Condition
Redis配置与优化
In the past 100 years, only 6 products have been approved, which is the "adjuvant" behind the vaccine competition
[ecological partner] Kunpeng system engineer training
快乐数[环类问题之快慢指针]
Flume interview questions
Fully annotated SSM framework construction
【MySQL】数据库优化方法
MySQL learning notes - SQL optimization of optimization
微信开放平台扫码登录[通俗易懂]
[commercial terminal simulation solution] Shanghai daoning brings you Georgia introduction, trial and tutorial
性能测试计划怎么编写