BatchNormalization
1、简介
BatchNormalization原文: BN原论文)
Batch Normalization是google团队在2015年提出的, 该方法能够加速网络的收敛并提高准确率。
本文分为以下几个部分:
- BN的原理
- 使用pytorch验证本文观点
- BN使用注意事项
2、Batch Normalization原理
在图像预处理中通常会对图像进行标准化处理, 这样能够加速网络的收敛, 对于Conv1来说, 输入就是满足某一分布的特征矩阵, 但是对Conv2而言的feature map就不一定满足某一分布规律了(注意这里所说的满足某一分布规律并不是指某一个feature map的数据要满足分布规律, 理论上指整个训练样本集所对应的feature map的数据要满足分布规律)。而我们的Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。
下面是从原论文中截取的原话
对于一个拥有d维的输入x, 我们将对它的每一个维度进行标准化处理, 假设我们输入的x是RGB三通道的彩色图像,这里的d就是输入图像的channels即d=3, , 其中代表的就是R通道对应的特征矩阵, 以此类推。标准化处理也是分别对我们的R通道, G通道, B通道进行处理。
原文中提供了更加相似的计算公式。
上面提到让feature map满足某一分布规律, 理论上指整个训练样本集所对应的feature map的数据要满足分布规律,也就是说要算出整个训练集的feature map然后再进行标准化处理。对于一个大型的数据集明显是不可能的,所以论文中说的是batch normalization, 也就是我们计算一个batch数据的feature map然后再进行标准化(batch越大越接近整个数据集的分布, 效果越好) .
根据上面的公式可以知道代表计算的feature map每个维度(channel)的均值,注意是一个向量,而不是一个值, 向量的每个元素代表一个维度(channel)的均值。代表计算的feature map每个维度(channels)的方差, 是一个向量, 而不是一个值, 向量的每一个元素代表一个维度(channel)的方差, 然后根据和计算标准化处理得到的值。下图给出了一个计算和的示例。
所以可以得出
上面的示例展示了一个batch为2(两张图片)的Batch Normalization的计算过程, 假设feature1, feature2分别是由image1、image2经过一系列卷积池化后的得到的特征矩阵, feature的channel为2, 那么代表该batch的所有feature的channel1的数据, 同理代表该batch的所有feature的channel2的数据。然后分别计算和的均值和方差, 得到和两个向量。
然后再根据标准差计算公式分别计算每个channel的值(公式中的是一个很小的常量, 防止分母为零的情况)。
batch normalization之后,每个元素的计算公式为:
网络训练过程中, 通过一个batch一个batch的数据进行训练, 但是在预测过程中通常是输入一张图片进行预测,因此预测是batch size=1, 如果再通过上述方法计算均值和方差就没有意义了。
所以在训练过程中要去不断地计算每个batch的均值和方差, 并使用移动平均(moving average)的方法记录统计的均值和方差。在训练完成后我们可以近似认为所统计的均值和方法就等于整个训练集的均值和方差。然后在验证以及预测过程中, 就使用统计得到的均值和方差进行标准化处理。
其实还可以发现论文中还有和两个参数, 用来调整数值分布的方差大小, 用来调整数据均值的位置。这两个参数是在反向传播过程中学习得到的, 默认为1, 默认为0。
2、使用pytorch进行试验
上面提到,在训练过程中, 均值和方差是通过计算当前批次数据得到的,而在验证和预测过程中使用的均值和方差是一个统计量,记为和。
\mu_{statistic + 1} = (1 - momentum) \mu_{statistic} + momentum \mu_{now}
\sigma^2_{statistic + 1} = (1 - momentum) \sigma^2_{statistic} + momentum\sigma^2_{statistic}
\sigma^2_{now} = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_{now})^2
\sigma^2_{now} = \frac{1}{m-1}\sum_{i=1}^{m}(x_i - \mu_{now})^2
1 | # !/usr/bin/env python |
结果明显一样,只是精度不同。
4、使用BN时需要注意的问题
训练时training参数设置为true, 验证时设置为false。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。
batch size尽可能设置大一点, 设置很小后表现很糟糕, 设置越大,求解出来的均值和方差越接近整个训练集的均值和方差。
建议将bn层凡在conv和激活层之间,且卷积层不要使用偏置bias, 因为设置了偏置,最后的结果也一样。