Batch Normalization (BN): A Deep Learning Fundamental
2022-08-02 05:29:00 【hello689】
1. What does the BN layer do, and how does it work?
During model training, batch normalization uses the mean and variance of each mini-batch to continually adjust the intermediate outputs of the network (making their distributions consistent), so that the intermediate output values of every layer become more stable. (This is what BN does.)
In the original paper, the authors' explanation is that BN reduces internal covariate shift and thus makes the final training result more stable. In fact, this explanation is imprecise; the benefits of BN come more from the normalization itself.
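As a minimal sketch of this core operation (the batch shape here is an illustrative assumption), normalizing a mini-batch with its own statistics looks like:

import torch

X = torch.randn(32, 10)                      # a mini-batch: 32 samples, 10 features
mean = X.mean(dim=0)                         # per-feature mean over the batch
var = X.var(dim=0, unbiased=False)           # per-feature (biased) variance
X_hat = (X - mean) / torch.sqrt(var + 1e-5)  # zero-mean, unit-variance output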
2. Internal covariate shift
First, the concept: internal covariate shift means that the distribution of the data shifts as it propagates through the network.
More concretely, a deep neural network stacks many layers of neurons. A parameter change in one layer changes the input distribution of the layers after it; accumulated layer by layer, the input distribution of the high layers can change dramatically, so the upper layers must constantly re-adapt to the parameter updates of the lower layers. This phenomenon is called internal covariate shift.
Some scaling and shifting is therefore needed to alleviate it.
Here is an example to explain it (admittedly a bit far-fetched). Suppose we have a deep network for recognizing roses, a binary classifier where 1 means "rose" and 0 means "not a rose". First look at part of the training set:
As shown above, all the roses are red, which leads the model to believe that anything red is a rose.
But the next picture contains roses of many colors, whose feature distribution differs from the figure above. Informally, the model starts out adapted to the distribution "red means rose"; when multi-colored roses suddenly appear, the model has trouble adjusting, which hurts its convergence speed and accuracy. This is internal covariate shift. During training, BN learns two scalar parameters: a scale gamma and a shift beta. BN normalizes the input values and rescales them to a suitable range, thereby speeding up training.
Reference blog: https://www.cnblogs.com/itmorn/p/11241236.html#ct6
3. The effect of normalization
Normalization unifies the statistical distribution of the samples. It makes subsequent data processing more convenient, speeds up convergence, and puts features on a common scale.
As shown below, the left figure is the search for the optimal solution without normalization, and the right figure is the same search after normalization.
It is evident from the figures that when gradient descent searches for the optimum on unnormalized data, it is likely to take a zigzag path (moving perpendicular to the contour lines), so many iterations are needed to converge. In the right figure, the two raw features have been normalized; the corresponding contour lines look nearly circular, and gradient descent converges much faster.
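A minimal illustration of why this helps (the feature values are made up): standardizing two features with very different scales puts them on comparable ranges, which is what rounds out the contour lines:

import torch

area = torch.tensor([80., 120., 60., 200.])   # feature 1: values around 60-200
rooms = torch.tensor([2., 3., 1., 4.])        # feature 2: values around 1-4
features = torch.stack([area, rooms], dim=1)

# z-score normalization: both columns now have mean 0 and std 1
normed = (features - features.mean(dim=0)) / features.std(dim=0)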
Types of normalization (see the sketch after this list):
- Linear (min-max) normalization uses the max and min values to normalize; if max and min are unstable, the result is unstable too.
- Standard-deviation normalization: the processed data follows a standard normal distribution. This is the one BN uses.
- Non-linear normalization maps the original values through a mathematical function, such as log or an exponential.
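A minimal sketch of the three types on a small tensor (the values are illustrative):

import torch

x = torch.tensor([1., 2., 4., 8., 16.])

min_max = (x - x.min()) / (x.max() - x.min())  # linear: rescale to [0, 1]
z_score = (x - x.mean()) / x.std()             # standard deviation: what BN uses
log_norm = torch.log(x)                        # non-linear: compress the range via log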
4. Where the BN layer is used
Batch normalization is generally used with fully connected and convolutional layers, with slightly different placement (see the sketch after this list):
- For convolutional layers, BN can be used after the convolutional layer, before the non-linear activation function;
- For fully connected layers, BN is placed between the affine transformation and the activation function;
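In PyTorch this placement looks like the following sketch (the layer sizes are illustrative assumptions):

import torch
from torch import nn

net = nn.Sequential(
    # convolution -> BN -> activation
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.ReLU(),
    nn.Flatten(),
    # affine transformation -> BN -> activation (assumes 28x28 inputs)
    nn.Linear(6 * 24 * 24, 120), nn.BatchNorm1d(120), nn.ReLU(),
)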

For a more visual walk-through of the operation, see https://www.cnblogs.com/itmorn/p/11241236.html#ct6
BN algorithm flow
Input: a mini-batch $X = x_1, x_2, ..., x_m$ and learnable parameters $\gamma, \beta$.
- Compute the mean of the previous layer's output;
- Compute the standard deviation of the previous layer's output;
- Normalize; to avoid a zero denominator, add a small value $\epsilon$;
- Reconstruct the distribution (scale and shift): $y_i = \gamma \cdot \hat{x}_i + \beta$, where $\hat{x}_i$ is the value obtained in step 3.
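Written out as equations (the standard form from the original BN paper), the four steps are:

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_B)^2$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$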
Why are scaling and shifting needed after normalization?
Subtracting the mean and dividing by the standard deviation yields a standard normal distribution. Take the roses as an example: after subtracting the mean, all the samples are distributed around 0. Scaling and shifting then let the network adjust this distribution, making model training more stable.
How does the computation differ between train and eval?
The figure below shows the error message of a network during training. In training mode, with batch size set to 4 on 4 GPUs in distributed training, a batch of size 1 appeared, which made the code throw an exception (the guess here is that this is due to data parallelism, with one batch per card). The source code enforces that in training mode the batch used by BN cannot be of size 1, because BN computes the variance and mean of the batch data.
During training, we cannot use the entire dataset to estimate the mean and variance, so the model can only be trained continually with the mean and variance of each mini-batch. In prediction mode, by contrast, the mean and variance required by batch normalization can be computed exactly over the entire dataset.
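A minimal PyTorch sketch of this train/eval difference (using torch.nn.BatchNorm1d; the failing call is left commented out):

import torch
from torch import nn

bn = nn.BatchNorm1d(3)   # 3 features; running stats start at mean=0, var=1
x = torch.randn(4, 3)    # a mini-batch of 4 samples

bn.train()               # training mode: normalize with this batch's statistics
y_train = bn(x)          # also updates bn.running_mean / bn.running_var

bn.eval()                # eval mode: normalize with the accumulated running stats
y_eval = bn(x)

# In training mode, a batch with only one value per channel is rejected:
# bn.train(); bn(torch.randn(1, 3))  # ValueError: Expected more than 1 value per channel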
5. Code
The following is the BN implementation code from Mu Li's Dive into Deep Learning (动手学深度学习).
import torch
from torch import nn
from d2l import torch as d2l

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use is_grad_enabled to tell whether we are in training or prediction mode
    if not torch.is_grad_enabled():
        # In prediction mode, directly use the mean and variance obtained
        # from the passed-in moving averages
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: compute the mean and variance
            # over the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2D convolutional layer: compute the mean and variance over the
            # channel dimension (axis=1). Keep X's shape so that broadcasting
            # works later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, standardize with the current mean and variance
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the moving averages of the mean and variance
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # scale and shift
    return Y, moving_mean.data, moving_var.data
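A minimal usage sketch (the shapes and hyperparameter values below are illustrative assumptions, not part of the original code):

# Hypothetical call on a conv feature map of shape (N, C, H, W)
X = torch.randn(4, 6, 8, 8)
gamma = torch.ones(1, 6, 1, 1)         # learnable scale, one per channel
beta = torch.zeros(1, 6, 1, 1)         # learnable shift, one per channel
moving_mean = torch.zeros(1, 6, 1, 1)
moving_var = torch.ones(1, 6, 1, 1)

# Training mode (gradients enabled): batch statistics are used
Y, moving_mean, moving_var = batch_norm(
    X, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9)

# Prediction mode: wrap in no_grad so the moving averages are used instead
with torch.no_grad():
    Y_eval, _, _ = batch_norm(
        X, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9)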