当前位置:网站首页>Convolution和Batch normalization的融合
Convolution和Batch normalization的融合
2022-07-02 22:14:00 【点PY】
理论推算
当前CNN卷积层的基本组成单元标配:Conv + BN +ReLU 三个子模块。但其实在网络的推理阶段,可以将BN层的运算融合到Conv层中,减少运算量,加速推理。本质上是修改了卷积核的参数,在不增加Conv层计算量的同时,略去了BN层的计算量。公式推导如下。
conv层的参数
BN层的参数
假设输入为x,则x->Conv->BN的输出便是:
做个简单的公式变形:
代码实现
在实际使用时,首先要定位conv和bn的位置,根据实际情况进行替换或者删除BN层。在本次实施例中,以开源分割模型库https://github.com/qubvel/segmentation_models.pytorch为案例进行融合实验,对BN层进行了替换。
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
if use_batchnorm == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
)
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()
elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)
else:
bn = nn.Identity()
super(Conv2dReLU, self).__init__(conv, bn, relu)
from turtle import forward
from torch.fx.experimental.optimization import fuse
import torch
import torch.nn as nn
import time
import segmentation_models_pytorch.base.modules as md
from utils.torchUtils import fuse_conv_and_bn
def fuseModel(model): # fuse model Conv2d() + BatchNorm2d() layers
for m in model.modules():
if isinstance(m, (md.Conv2dReLU)) and isinstance(m[1], (nn.BatchNorm2d)):
m[0] = fuse_conv_and_bn(m[0], m[1]) # update conv
m[1] = nn.Identity()
count += 1
return model
边栏推荐
- 用matlab调用vs2015来编译vs工程
- 基于Pyqt5工具栏按钮可实现界面切换-2
- C#中Linq用法汇集
- 抖音实战~点赞数量弹框
- 20220524_ Database process_ Statement retention
- 海思调用接口之Makefile配置
- 2022年最新最全软件测试面试题大全
- How difficult is it to be high? AI rolls into the mathematics circle, and the accuracy rate of advanced mathematics examination is 81%!
- SQL advanced syntax
- Tronapi wave field interface - source code without encryption - can be opened twice - interface document attached - packaging based on thinkphp5 - detailed guidance of the author - July 1, 2022 08:43:
猜你喜欢
抖音实战~点赞数量弹框
Start from the bottom structure to learn the customization and testing of FPGA --- Xilinx ROM IP
The first batch of Tencent cloud completed the first cloud native security maturity assessment in China
[npuctf2020]ezlogin XPath injection
Win11启用粘滞键关闭不了怎么办?粘滞键取消了但不管用怎么解决
STM32之ADC
Third party payment function test point [Hangzhou multi tester _ Wang Sir] [Hangzhou multi tester]
内网渗透 | 手把手教你如何进行内网渗透
Print out mode of go
基于Pyqt5工具栏按钮可实现界面切换-1
随机推荐
Typical case of data annotation: how does jinglianwen technology help enterprises build data solutions
海思调用接口之Makefile配置
2016. maximum difference between incremental elements
RuntimeError: no valid convolution algorithms available in CuDNN
Pandora IOT development board learning (HAL Library) - Experiment 4 serial port communication experiment (learning notes)
Explain promise usage in detail
Interface switching based on pyqt5 toolbar button -1
C MVC creates a view to get rid of the influence of layout
Yolox enhanced feature extraction network panet analysis
LINQ usage collection in C #
Numerical solution of partial differential equations with MATLAB
Win11启用粘滞键关闭不了怎么办?粘滞键取消了但不管用怎么解决
ADC of stm32
STM32串口DAM接收253字节就死机原因排查
【直播预约】数据库OBCP认证全面升级公开课
What experience is there only one test in the company? Listen to what they say
“一个优秀程序员可抵五个普通程序员!”
Solution to boost library link error
Win11麦克风测试在哪里?Win11测试麦克风的方法
Golang common settings - modify background