当前位置:网站首页>Mixconv code
Mixconv code
2022-07-01 22:42:00 【zouxiaolv】
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def split_layer(total_channels, num_groups):
    """Split ``total_channels`` into ``num_groups`` positive channel counts.

    Primary scheme (kept so existing models keep the same channel layout):
    every group gets ``ceil(total_channels / num_groups)`` channels and the
    last group absorbs the difference, e.g. ``(10, 4) -> [3, 3, 3, 1]``.

    That scheme can leave the last group empty or negative, e.g.
    ``(5, 4) -> [2, 2, 2, -1]``, which later breaks ``torch.split`` /
    ``nn.Conv2d``. In that case the channels are spread as evenly as
    possible instead: the first ``total_channels % num_groups`` groups
    receive one extra channel, e.g. ``(5, 4) -> [2, 1, 1, 1]``.

    Args:
        total_channels: Number of channels to divide.
        num_groups: Number of groups; must satisfy
            ``1 <= num_groups <= total_channels``.

    Returns:
        A list of ``num_groups`` positive ints summing to ``total_channels``.

    Raises:
        ValueError: If ``num_groups`` is below 1 or larger than
            ``total_channels`` (some group would necessarily be empty).
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")
    if num_groups > total_channels:
        raise ValueError(
            f"cannot split {total_channels} channels into "
            f"{num_groups} non-empty groups"
        )
    # Integer ceil division: exact for any int size, unlike np.ceil on a float.
    per_group = -(-total_channels // num_groups)
    split = [per_group] * num_groups
    split[-1] += total_channels - sum(split)
    if split[-1] > 0:
        return split
    # Ceil scheme over-allocated so much that the last group vanished;
    # fall back to a balanced split with the remainder spread over the front.
    base, extra = divmod(total_channels, num_groups)
    return [base + 1 if idx < extra else base for idx in range(num_groups)]
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution: one square filter per input channel.

    Wraps a single ``nn.Conv2d`` whose ``groups`` equals the channel count,
    so the number of output channels equals the number of input channels.
    Padding is ``(kernal_size - 1) // 2``, which keeps the spatial size
    unchanged at stride 1 when ``kernal_size`` is odd.

    Args:
        in_channels: Number of input (and output) channels.
        kernal_size: Side length of the square kernel (the parameter name's
            spelling is kept for existing callers).
        stride: Convolution stride.
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, in_channels, kernal_size, stride, bias=False):
        super().__init__()
        pad = (kernal_size - 1) // 2
        self.depthwise_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernal_size,
            stride=stride,
            padding=pad,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution to ``x``."""
        return self.depthwise_conv(x)
class GroupConv2D(nn.Module):
    """Grouped 2-D convolution whose groups may have unequal channel counts.

    With ``n_chunks == 1`` this is one ordinary ``nn.Conv2d`` stored as
    ``self.group_conv``. Otherwise both the input and the output channel
    counts are divided with ``split_layer``. Input chunk ``i`` is convolved
    by its own ``nn.Conv2d`` in ``self.group_layers[i]``, and the chunk
    outputs are concatenated along dim 1.

    No padding is applied, so a ``kernel_size`` above 1 shrinks the spatial
    dimensions.

    Args:
        in_channels: Total number of input channels.
        out_channels: Total number of output channels.
        kernel_size: Kernel size of every convolution (default 1, pointwise).
        n_chunks: Number of channel groups.
        bias: Whether each convolution carries a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False):
        super().__init__()
        self.n_chunks = n_chunks
        self.split_in_channels = split_layer(in_channels, n_chunks)
        out_splits = split_layer(out_channels, n_chunks)
        if n_chunks != 1:
            # One independent convolution per (input chunk, output chunk) pair.
            self.group_layers = nn.ModuleList(
                nn.Conv2d(c_in, c_out, kernel_size=kernel_size, bias=bias)
                for c_in, c_out in zip(self.split_in_channels, out_splits)
            )
        else:
            self.group_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size, bias=bias
            )

    def forward(self, x):
        """Convolve each channel chunk of ``x`` and concatenate the results."""
        if self.n_chunks != 1:
            chunks = torch.split(x, self.split_in_channels, dim=1)
            outputs = [conv(chunk) for conv, chunk in zip(self.group_layers, chunks)]
            return torch.cat(outputs, dim=1)
        return self.group_conv(x)
class MDConv(nn.Module):
    """Mixed depthwise convolution (MixConv).

    The input channels are divided into ``n_chunks`` groups with
    ``split_layer``. Group ``i`` is processed by a ``DepthwiseConv2D`` with
    kernel size ``2 * i + 3`` (3, 5, 7, ...), and the group outputs are
    concatenated back along dim 1.

    Every kernel size is odd and each depthwise layer pads by ``(k - 1) // 2``.
    All groups therefore produce the same spatial size for a given stride, so
    the concatenation is valid.

    Args:
        out_channels: Channel count of both input and output (depthwise
            convolutions keep the channel count).
        n_chunks: Number of channel groups, i.e. distinct kernel sizes.
        stride: Stride shared by every group.
        bias: Whether each depthwise convolution carries a bias term.
    """

    def __init__(self, out_channels, n_chunks, stride=1, bias=False):
        super(MDConv, self).__init__()
        self.n_chunks = n_chunks
        self.split_out_channels = split_layer(out_channels, n_chunks)
        # Group idx uses kernel 2 * idx + 3: 3, 5, 7, ... (always odd).
        self.layers = nn.ModuleList(
            DepthwiseConv2D(channels, kernal_size=2 * idx + 3, stride=stride, bias=bias)
            for idx, channels in enumerate(self.split_out_channels)
        )

    def forward(self, x):
        """Split ``x`` by channel group, convolve each group, and concatenate."""
        split = torch.split(x, self.split_out_channels, dim=1)
        out = torch.cat([layer(s) for layer, s in zip(self.layers, split)], dim=1)
        return out
if __name__ == "__main__":
    # Demo. MixConv works as follows:
    # - the input channels are first split into groups;
    # - each group is convolved with a different kernel size, using
    #   depthwise convolution;
    # - the group outputs are concatenated.
    # Guarded so that importing this module does not run the demo.
    temp = torch.randn((16, 3, 32, 32))

    group = GroupConv2D(3, 16, n_chunks=2)
    print(group(temp).size())

    group = MDConv(3, n_chunks=2)
    print(group(temp).size())
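Shortcoming: as with Inception, the more groups there are, the slower the module runs, because each group is a separate convolution call (plus a split and a concatenation). The design only targets parameter reduction, not speed.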
边栏推荐
- Selection of all-optical technology in the park - Part 2
- Lc669. Prune binary search tree
- EasyExcel 复杂数据导出
- Mask wearing detection method based on yolov5
- flink sql 命令行 连接 yarn
- 性能测试计划怎么编写
- Mysql——》MyISAM存储引擎的索引
- CSDN购买的课程从哪里可以进入
- 恶意软件反向关闭EDR的原理、测试和反制思考
- The fixed assets management subsystem reports are divided into what categories and which accounts are included
猜你喜欢

Ida dynamic debugging apk

Easyexcel complex data export

Intelligent computing architecture design of Internet

keras训练的H5模型转tflite

GenICam GenTL 标准 ver1.5(4)第五章 采集引擎

The fixed assets management subsystem reports are divided into what categories and which accounts are included

Selection of all-optical technology in the park - Part 2

C#/VB.NET 给PDF文档添加文本/图像水印

SAP 智能机器人流程自动化(iRPA)解决方案分享

91.(cesium篇)cesium火箭发射模拟
随机推荐
flink sql 命令行 连接 yarn
Flume interview questions
20220701
Pytorch sharpening chapter | argmax and argmin functions
【MySQL】explain的基本使用以及各列的作用
灵动微 MM32 多路ADC-DMA配置
flink sql-client 使用 对照并熟悉官方文档
Ida dynamic debugging apk
Selection of all-optical technology in the park - Part 2
EasyExcel 复杂数据导出
详解JMM
从零开始学 MySQL —数据库和数据表操作
Yyds dry goods inventory # solve the real problem of famous enterprises: egg twisting machine
#yyds干货盘点# 解决名企真题:扭蛋机
搜狗微信APP逆向(二)so层
3DE resources have nothing or nothing wrong
人体姿态估计的热图变成坐标点的两种方案
plantuml介绍与使用
效率提升 - 鼓捣个性化容器开发环境
Object memory layout