BatchNorm是在batch维度上计算每个相同通道上的均值和方差,通常情况下,训练阶段的batchsize较大,而推理时batchsize基本为1。这样的话,就会导致训练和推理阶段得到不同的标准化,均值和方差时靠每一个mini-batch的统计得到的,因为推理时只有一个样本,在只有1个向量的数据组上进行标准化后,成了一个全0向量,导致模型出现BUG。为了解决这个问题,不改变训练时的BatchNorm计算方式,仅仅改变推理时计算均值和方差方法。
做法就是用训练集来估计总体均值μ \mu μ和总体标准差σ \sigma σ。主要有两种方法: 简单平均法和 移动指数平均
- 简单平均法
把每个mini-batch的均值和方差都保存下来,然后训练完了求均值的均值,方差的均值即可。
- 移动指数平均(Exponential Moving Average)
本文仅以μ \mu μ的计算为例:
μ t o t a l = d e c a y ∗ μ t o t a l + ( 1 − d e c a y ) ∗ μ \mu_{total}=decay\mu_{total}+(1-decay)\mu μt o t a l =d e c a y ∗μt o t a l +(1 −d e c a y )∗μ
其中decay是衰减系数。即总均值μ t o t a l \mu_{total}μt o t a l 是前一个mini-batch统计的总均值和本次mini-batch的μ \mu μ加权求和。至于衰减率 decay在区间[0,1]之间,decay越接近1,结果μ t o t a l \mu_{total}μt o t a l 越稳定,越受较远的大范围的样本影响;decay越接近0,结果μ t o t a l \mu_{total}μt o t a l 越波动,越受较近的小范围的样本影响。
事实上,简单平均可能更好,简单平均本质上是平均权重,但是简单平均需要保存所有BN层在所有mini-batch上的均值向量和方差向量,如果训练数据量很大,会有较可观的存储代价。移动指数平均在实际的框架中更常见(例如tensorflow),可能的好处是EMA不需要存储每一个mini-batch的值,永远只保存着三个值:总统计值、本batch的统计值,decay系数。
在训练阶段同步获得了μ t o t a l \mu_{total}μt o t a l 和σ t o t a l \sigma_{total}σt o t a l 后,在推理时即可对样本进行BN操作。
Original: https://blog.csdn.net/weixin_42211626/article/details/122857223
Author: macan_dct
Title: BatchNorm怎样解决训练和推理时batch size 不同的问题?
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/496877/
转载文章受原作者版权保护。转载请注明原作者出处!