当前位置:网站首页>mixconv代码
mixconv代码
2022-07-01 21:47:00 【zouxiaolv】
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def split_layer(total_channels, num_groups):
    """Split ``total_channels`` into ``num_groups`` integer chunk sizes.

    The preferred layout gives every group except the last
    ``ceil(total_channels / num_groups)`` channels and lets the last group
    absorb the remainder, e.g. ``split_layer(7, 3) == [3, 3, 1]``.

    That layout can leave the last group with zero or a negative number of
    channels (``5`` channels in ``4`` groups would yield ``[2, 2, 2, -1]``).
    In that case a balanced split is returned instead, where the first
    ``total_channels % num_groups`` groups get one extra channel, e.g.
    ``split_layer(5, 4) == [2, 1, 1, 1]``.

    Args:
        total_channels: Number of channels to distribute.
        num_groups: Number of groups to distribute them over.

    Returns:
        A list of ``num_groups`` ints that sums to ``total_channels``.
        Entries can only be zero when ``total_channels < num_groups``.

    Raises:
        ValueError: If ``num_groups`` is smaller than 1.
    """
    if num_groups < 1:
        raise ValueError(
            "num_groups must be a positive integer, got {}".format(num_groups)
        )

    # Integer ceiling division; no float round-trip needed.
    chunk = -(-total_channels // num_groups)
    last = total_channels - chunk * (num_groups - 1)
    if last > 0:
        return [chunk] * (num_groups - 1) + [last]

    # The ceiling layout overshoots: spread the channels as evenly as possible.
    base, extra = divmod(total_channels, num_groups)
    return [base + 1 if idx < extra else base for idx in range(num_groups)]
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution: each input channel gets its own filter.

    The wrapped ``nn.Conv2d`` has as many groups and output channels as it
    has input channels. Zero-padding is ``(kernal_size - 1) // 2`` per side,
    so with an odd kernel and ``stride=1`` the spatial size is unchanged.

    Args:
        in_channels: Number of input (and output) channels.
        kernal_size: Side length of the square kernel (spelling kept for
            compatibility with existing callers).
        stride: Convolution stride.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, kernal_size, stride, bias=False):
        super().__init__()
        half_kernel = (kernal_size - 1) // 2
        self.depthwise_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernal_size,
            stride=stride,
            padding=half_kernel,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution to ``x`` and return the result."""
        return self.depthwise_conv(x)
class GroupConv2D(nn.Module):
    """Convolution applied independently to channel chunks of the input.

    With ``n_chunks == 1`` this is a single ``nn.Conv2d``. Otherwise the
    input and output channel counts are each divided by ``split_layer`` and
    one ``nn.Conv2d`` is created per chunk; ``forward`` splits the input
    along dim 1, runs each chunk through its own convolution and
    concatenates the results along dim 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size passed to every ``nn.Conv2d``.
        n_chunks: Number of channel chunks.
        bias: Whether the convolutions have a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        self.split_in_channels = split_layer(in_channels, n_chunks)
        chunk_out_sizes = split_layer(out_channels, n_chunks)

        if n_chunks != 1:
            # One independent convolution per (input chunk, output chunk) pair.
            self.group_layers = nn.ModuleList(
                [
                    nn.Conv2d(c_in, c_out, kernel_size=kernel_size, bias=bias)
                    for c_in, c_out in zip(self.split_in_channels, chunk_out_sizes)
                ]
            )
            return

        self.group_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, bias=bias
        )

    def forward(self, x):
        """Run the (chunked) convolution over ``x``."""
        if self.n_chunks == 1:
            return self.group_conv(x)

        pieces = torch.split(x, self.split_in_channels, dim=1)
        outputs = [conv(piece) for conv, piece in zip(self.group_layers, pieces)]
        return torch.cat(outputs, dim=1)
class MDConv(nn.Module):
    """Mixed depthwise convolution: a different kernel size per channel group.

    The channels are divided into ``n_chunks`` groups by ``split_layer``.
    Group ``idx`` is processed by a ``DepthwiseConv2D`` with kernel size
    ``2 * idx + 3`` (3, 5, 7, ...). Because every branch is depthwise, the
    input must have exactly ``out_channels`` channels, and the output has
    the same channel count.

    Args:
        out_channels: Number of channels of the input and of the output.
        n_chunks: Number of channel groups (one kernel size per group).
        stride: Stride used by every depthwise convolution.
        bias: Whether the depthwise convolutions have a learnable bias.
    """

    def __init__(self, out_channels, n_chunks, stride=1, bias=False):
        super(MDConv, self).__init__()
        self.n_chunks = n_chunks
        self.split_out_channels = split_layer(out_channels, n_chunks)
        self.layers = nn.ModuleList()
        for idx in range(self.n_chunks):
            # Kernel grows with the group index: 3x3, 5x5, 7x7, ...
            kernel_size = 2 * idx + 3
            self.layers.append(
                DepthwiseConv2D(
                    self.split_out_channels[idx],
                    kernal_size=kernel_size,
                    stride=stride,
                    bias=bias,
                )
            )

    def forward(self, x):
        """Split ``x`` along dim 1, convolve each group, and concatenate."""
        split = torch.split(x, self.split_out_channels, dim=1)
        out = torch.cat([layer(s) for layer, s in zip(self.layers, split)], dim=1)
        return out
# Demo input: a random tensor of shape (16, 3, 32, 32).
temp = torch.randn((16, 3, 32, 32))
# Alternative demo (disabled): GroupConv2D from 3 to 16 channels in 2 chunks.
# group = GroupConv2D(3, 16, n_chunks=2)
# print(group(temp).size())
# Mixed depthwise convolution over 3 channels divided into 2 groups.
group = MDConv(3, n_chunks=2)
print(group(temp).size())

讲解:先将输入按通道分组;分完组后,每个组采用不同大小的卷积核,并用深度(depthwise)卷积完成卷积过程。
缺点:与 Inception 类似,分组越多,速度下降越快;该方法只关注参数量的降低。
边栏推荐
- 详解Kubernetes网络模型
- LIS (longest ascending subsequence) problem that can be understood [easy to understand]
- Communication between browser tab pages
- 分享一个一年经历两次裁员的程序员的一些感触
- Redis配置与优化
- PHP reflective XSS, reflective XSS test and repair
- 删除AWS绑定的信用卡账户
- Airserver mobile phone third-party screen projection computer software
- 基准环路增益与相位裕度的测量
- Can you get a raise? Analysis on gold content of PMP certificate
猜你喜欢
随机推荐
多种智能指针
FFMpeg学习笔记
The leader of the cloud native theme group of beacon Committee has a long way to go!
Copy ‘XXXX‘ to effectively final temp variable
小红书Scheme跳转到指定页面
为什么数字化转型战略必须包括持续测试?
LC669. 修剪二叉搜索树
[STM32] stm32cubemx tutorial II - basic use (new projects light up LED lights)
C#/VB. Net to add text / image watermarks to PDF documents
Basic knowledge of ngnix
RestTemplate 远程调用工具类
100年仅6款产品获批,疫苗竞争背后的“佐剂”江湖
完全注解的ssm框架搭建
In the past 100 years, only 6 products have been approved, which is the "adjuvant" behind the vaccine competition
【MySQL】索引的分类
flink sql 命令行 连接 yarn
Can you get a raise? Analysis on gold content of PMP certificate
【juc学习之路第9天】屏障衍生工具
详解Kubernetes网络模型
【MySQL】索引的创建、查看和删除




![[jetcache] how to use jetcache](/img/fa/5b3abe53bb7e9db6af2dbb1cb76a31.png)




