Fusing Convolution and Batch Normalization
2022-07-02 23:25:00 【點PY】
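Theoretical derivation
The standard building block of a convolutional layer in current CNNs consists of three submodules: Conv + BN + ReLU. At inference time, however, the BN computation can be folded into the Conv layer, which reduces the amount of computation and speeds up inference. In essence, this modifies the parameters of the convolution kernel: the cost of the Conv layer does not increase, while the cost of the BN layer is eliminated. The derivation is as follows.
Parameters of the conv layer
Parameters of the BN layer
Assuming the input is x, the output of x -> Conv -> BN is:
A simple rearrangement of the formula gives: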
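The steps listed here can be written out as below. This is a sketch in the usual notation, and the symbols are assumptions rather than the author's own: $w$ and $b$ are the conv weight and bias, $\gamma$ and $\beta$ are the BN scale and shift, $\mu$ and $\sigma^2$ are the BN running mean and running variance, $\epsilon$ is the BN stability constant, and $*$ denotes convolution. All BN quantities act per output channel.

```latex
% Conv layer
x_{\text{conv}} = w * x + b

% BN layer applied to the conv output (inference mode, running statistics)
y = \gamma \cdot \frac{x_{\text{conv}} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = \gamma \cdot \frac{w * x + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

% Rearranged into the form of a single convolution
y = \underbrace{\frac{\gamma \, w}{\sqrt{\sigma^2 + \epsilon}}}_{w'} * x
  + \underbrace{\frac{\gamma \, (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta}_{b'}
```

The fused layer is therefore an ordinary convolution with kernel $w'$ and bias $b'$. If the conv has no bias, $b = 0$ and the fused bias reduces to $\beta - \gamma \mu / \sqrt{\sigma^2 + \epsilon}$.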

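Code implementation
In practice, first locate the conv and BN layers in the model, then replace or delete the BN layers as the situation requires. In this example, the open-source segmentation model library https://github.com/qubvel/segmentation_models.pytorch is used for the fusion experiment, and the BN layers are replaced.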
class Conv2dReLU(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):

        if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()

        elif use_batchnorm and use_batchnorm != "inplace":
            bn = nn.BatchNorm2d(out_channels)

        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)

import time

import torch
import torch.nn as nn
import segmentation_models_pytorch.base.modules as md
from torch.fx.experimental.optimization import fuse
from utils.torchUtils import fuse_conv_and_bn


def fuseModel(model):  # fuse model Conv2d() + BatchNorm2d() layers
    count = 0  # number of Conv + BN pairs fused
    for m in model.modules():
        # Conv2dReLU is an nn.Sequential of (conv, bn, relu)
        if isinstance(m, md.Conv2dReLU) and isinstance(m[1], nn.BatchNorm2d):
            m[0] = fuse_conv_and_bn(m[0], m[1])  # update conv
            m[1] = nn.Identity()  # BN is now folded into the conv
            count += 1
    print(f"Fused {count} Conv + BN pairs")
    return model
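The helper `fuse_conv_and_bn` is imported from the author's `utils.torchUtils`, whose source is not shown. As a rough idea of what such a helper might look like, here is a minimal sketch that scales each output channel's kernel by γ/√(σ² + ε) and adjusts the bias accordingly. It assumes a plain `nn.Conv2d` with the default zero padding mode and a `nn.BatchNorm2d` with affine parameters and tracked running statistics; it is not necessarily the author's implementation. The short check afterwards compares the fused layer against Conv followed by BN in eval mode.

```python
import torch
import torch.nn as nn


def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a new Conv2d equivalent to bn(conv(x)) in eval mode (sketch)."""
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,  # the fused layer always needs a bias to absorb the BN shift
    ).to(conv.weight.device)

    with torch.no_grad():
        # Per-output-channel scale: gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # Scale every kernel that produces output channel i by scale[i]
        fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
        # The conv usually has no bias when it is followed by BN
        if conv.bias is not None:
            conv_bias = conv.bias
        else:
            conv_bias = torch.zeros_like(bn.running_mean)
        # Fused bias: (b - running_mean) * scale + beta
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused


# Numerical check on random data
torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
with torch.no_grad():
    # Give BN non-trivial statistics and affine parameters
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    fused_out = fuse_conv_and_bn(conv, bn)(x)
max_abs_diff = (reference - fused_out).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The fusion is only valid in eval mode, because in training mode BN normalizes with batch statistics rather than the running mean and variance. The two outputs should agree up to floating-point rounding.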