浅谈Batch Normalization

最近面试被问到了一些BN相关的知识,有些忘记,重新拾起一下。从字面上看,Batch Normalization就是要对每一批数据进行归一化操作,确实如此,对于训练中某个batch的数据,注意这个是数据也可以是输入也可以是网络中间某层的输出,对数据进行如下操作:

动机

知道BN的具体操作,那为啥需要对数据进行BN操作呢?动机又是啥?对于传统的浅层学习模型,如单层的逻辑回归,SVM以及BP等模型,每次更新参数均从稳定的训练数据上拟合。深度学习因其多层结构,浅层输出作为下一层输入,除了存在梯度消失问题外,在学习过程中,每层网络参数不断更新 ,导致下一层的输入的分布不断的变化,因而无法跟浅层模型一样,每次都在稳定的数据上学习参数。

减均值除方差,数据就被移到中心区域。对大多数激活函数而言,这个区域的梯度都是最大的或者是有梯度的,这可以看做是一种对抗梯度小时的手段。对于一层是如此,如果对于每一层数据都是这么操作的话,那么数据的分布就总在随着输入变化敏感的区域,相当于不用考虑数据分布变来变去,这样训练效率就高很多。
不过,这里问题还没有结束,因为减均值除方差未必是最好的分布。比如数据本身就是很不对称,或者激活函数未必是对方差为1的数据有最好的结果。而且如果仅仅是把特征都normalize到N(0,1),那么因为特征只在激活函数上线性区域上激活,会降低特征的表达能力。
针对上述问题,在算法结束的之前,作者对normalization之后的数据设置了两个参数γ和β,做了上述第四步的线性变换,这两个参数是需要学习的参数。其实BN的本质就是利用优化变一下方差大小和均值的位置。

caffe 实现

实际应用中,$μβ$和$σ^2β$通常是在训练集上计算,测试的时候直接使用训练时计算得到的值. 此外,Batch Normalization Layer的backward pass实际并没有被调用。
BN在proto文件中默认参数配置如下:

1
2
3
4
5
6
7
8
9
10
11
message BatchNormParameter {
// 如果为真,则使用保存的均值和方差,否则采用滑动平均计算新的均值和方差。
// 该参数缺省的时候,如果是测试阶段则等价为真,如果是训练阶段则等价为假。
optional bool use_global_stats = 1;

// 滑动平均的衰减系数,默认为0.999
optional float moving_average_fraction = 2 [default = .999];

// 分母附加值,防止除以方差时出现除0操作,默认为1e-5
optional float eps = 3 [default = 1e-5];
}

因为一些历史原因, Caffe的normalize step 和scale and shift step至今不在同一个layer中实现,对于BN的使用可以参考ResNet。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

bottom: "res2a_branch2b"
top: "res2a_branch2b"
name: "bn2a_branch2b"
type: "BatchNorm"
batch_norm_param {
use_global_stats: false //训练阶段和测试阶段不同,
}
include: { phase: TRAIN }

}
layer {
bottom: "res2a_branch2b"
top: "res2a_branch2b"
name: "bn2a_branch2b"
type: "BatchNorm"
batch_norm_param {
use_global_stats: true
}
include: { phase: TEST }

}

此外建议在BatchNorm中不显示配置batch_norm_param,而是有代码运行时自动判断是否use_global_stats。

1
2
3
4
5
BatchNormParameter param = this->layer_param_.batch_norm_param();
moving_average_fraction_ = param.moving_average_fraction();
use_global_stats_ = this->phase_ == TEST;
if (param.has_use_global_stats())
use_global_stats_ = param.use_global_stats();

对于use_global_stats参数值,测试阶段则等价为真,如果是训练阶段则等价为假。此外在caffe中使用BN的时候需要注意一定要配置scale层一起使用。

使用Python实现BN

BN的前向操作只需要实现下面这些公式就行

对于反向传播的话,我们需要对x,r,β进行求导获得其梯度,然后进行反向传播。下面是原文给出的求导公式,原文给出的公式省略了一些中间步骤,下面我自己也推导了一下,如下图所示。


具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
def bn_forward(x,gamma,beta,eps=1e-5):#eps: 除以方差时为了防止方差太小而导致数值计算不稳定
#计算均值
sample_mean = np.mean(x, axis=0)
#计算方差
sample_var = np.var(x, axis=0)
#归一化
out_ = (x - sample_mean) / np.sqrt(sample_var + eps)
#缩放平移
out=gamma*out_+beta
#缓存中间变量用于反向传播计算
cache=(x,out_,gamma,beta,sample_mean,sample_var,eps)

return out,cache


def bn_backward(dout,cache):

N,D = dout.shape

x,out_,gamma,beta,sample_mean,sample_var,eps=cache

dbeta=np.sum(dout,axis=0)

dgamma=np.sum(out_*dout,axis=0)

dxhat=gamma*dout

dvar=np.sum(dxhat*out_,axis=0)*(-1/2.)*(1./sample_var+eps)

dmean=np.sum(dxhat,axis=0)/np.sqrt(sample_var + eps) + dvar*(-2)*sample_mean

dx=dxhat/np.sqrt(sample_var + eps) + dvar*2*(x-sample_mean)/N + dmean/N

return dx,dgamma,dbeta

谈谈batch_size大小影响

对于一些视觉任务比如分类、目标检测等任务引入BN之后,大的batch往往会带来large margin的收益。最明显的例子就是Face++提出的MegDet,以256个batch size训练检测网络,一举将coco2017目标检测mAP刷到了52.5%,稳居榜首。先解释一下batch size的一些作用:

  • 首先我们谈谈大batch的作用。以sgd举例,事实上sgd大家都知道选一批样本去更新梯度,学这个梯度下降的方向。我们选用大batch,样本更有代表性,因为选择的样本越多,下降肯定越趋于整个数据集收敛应该朝向的方向,所以大batch会更稳定,基本不会在局部震荡, 因为每次参数更新所用到的数据越多,越能代表整体损失函数的梯度,因此梯度精确度更高。另一点是快,而且可以更好利用矩阵计算库里大矩阵乘法的效率,还可以提高显存利用率。
  • 那小batch自然没有上述的好处,取一个极端例子,如果我们把batch size设为1,这样称为在线学习,每次修正方向以各自样本的梯度方向修正,难以达到收敛。可以理解为因为每次梯度更新只用一个样本,所以后果就是很可能在某一个局部极值点附近震荡……同理可辐射到其他小batch size上,但是同样也有好处,可以在达到一定程度后用来精细的去磨一下你的模型。因为小批量在学习过程中加入了噪声,会有一些正则化效果,但是需要小的学习率保持稳定性,否则还是会跳进局部极值难以自拔。
  • 对于大batch size,达到相同精度需要更多epoch,因为每轮训练中的迭代次数更少。

所以可以总结下,适当增大batch size好处在于:

  • 内存利用率提高了。
  • 对于相同数据量的处理速度更快。
  • 在一定范围内,一般来说batch size 越大,其确定的下降方向越准,引起训练震荡越小。

但不能盲目增大,毕竟硬件限制摆在那,batch size上去,首先要有足够地显存和内存支撑。不是每个人都可以动不动上256的batch size来训练网络。