BatchNorm怎样解决训练和推理时batch size 不同的问题?

BatchNorm是在batch维度上计算每个相同通道上的均值和方差,通常情况下,训练阶段的batchsize较大,而推理时batchsize基本为1。这样的话,就会导致训练和推理阶段得到不同的标准化,均值和方差时靠每一个mini-batch的统计得到的,因为推理时只有一个样本,在只有1个向量的数据组上进行标准化后,成了一个全0向量,导致模型出现BUG。为了解决这个问题,不改变训练时的BatchNorm计算方式,仅仅改变推理时计算均值和方差方法。

做法就是用训练集来估计总体均值μ \mu μ和总体标准差σ \sigma σ。主要有两种方法: 简单平均法移动指数平均

  1. 简单平均法
    把每个mini-batch的均值和方差都保存下来,然后训练完了求均值的均值,方差的均值即可。
  1. 移动指数平均(Exponential Moving Average)
    本文仅以μ \mu μ的计算为例:
    μ t o t a l = d e c a y ∗ μ t o t a l + ( 1 − d e c a y ) ∗ μ \mu_{total}=decay\mu_{total}+(1-decay)\mu μt o t a l ​=d e c a y ∗μt o t a l ​+(1 −d e c a y )∗μ
    其中decay是衰减系数。即总均值μ t o t a l \mu_{total}μt o t a l ​是前一个mini-batch统计的总均值和本次mini-batch的μ \mu μ加权求和。至于衰减率 decay在区间[0,1]之间,decay越接近1,结果μ t o t a l \mu_{total}μt o t a l ​越稳定,越受较远的大范围的样本影响;decay越接近0,结果μ t o t a l \mu_{total}μt o t a l ​越波动,越受较近的小范围的样本影响。

事实上,简单平均可能更好,简单平均本质上是平均权重,但是简单平均需要保存所有BN层在所有mini-batch上的均值向量和方差向量,如果训练数据量很大,会有较可观的存储代价。移动指数平均在实际的框架中更常见(例如tensorflow),可能的好处是EMA不需要存储每一个mini-batch的值,永远只保存着三个值:总统计值、本batch的统计值,decay系数。

在训练阶段同步获得了μ t o t a l \mu_{total}μt o t a l ​和σ t o t a l \sigma_{total}σt o t a l ​后,在推理时即可对样本进行BN操作。

Original: https://blog.csdn.net/weixin_42211626/article/details/122857223
Author: macan_dct
Title: BatchNorm怎样解决训练和推理时batch size 不同的问题?

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/496877/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球