nn.BatchNorm2d——批量标准化操作解读

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

功能:对输入的四维数组进行批量标准化处理,具体计算公式如下:
y = x − m e a n [ x ] V a r [ x ] + e p s ∗ g a m m a + b e t a y=\frac{x-mean[x]}{\sqrt{Var[x]+eps}}*gamma+beta y =Va r [x ]+e p s ​x −m e an [x ]​∗g amma +b e t a

对于 所有的batch中的 同一个channel的数据元素进行标准化处理,即如果有C个通道,无论有多少个batch,都会在通道维度上进行标准化处理,一共进行C次。

训练阶段的均值和方差计算方法相同,将所有batch相同通道的值取出来,一块计算均值和方差,即计算当前观测值的均值和方差。

测试阶段的均值和方差有两种计算方法:
①估计所有图片的均值和方差,即做全局计算,具体计算方法如下:
模型分别储存 各个通道(通道数需要预先定义)的均值和方差数据(初始为0和1),在每次 训练过程中, 每标准化一组数据,都利用计算得到的局部观测值的均值和方差对储存的数据做 更新测试阶段利用模型存储的两个数据做标准化处理,更新公式如下:
X n e w = ( 1 − m o m e n t u m ) × X o l d + m o m e n t u m × X t 其中, X n e w 是模型的新参数, X o l d 是模型原来的参数, X t 是当前观测值的参数 X_{new}=(1-momentum)\times X_{old} + momentum\times X_t\ 其中,X_{new}是模型的新参数,X_{old}是模型原来的参数,X_t是当前观测值的参数X n e w ​=(1 −m o m e n t u m )×X o l d ​+m o m e n t u m ×X t ​其中,X n e w ​是模型的新参数,X o l d ​是模型原来的参数,X t ​是当前观测值的参数
②采用和训练阶段相同的计算方法,即只计算当前输入数据的均值和方差

输入:

  • num_features:输入图像的通道数量。
  • eps:稳定系数,防止分母出现0。
  • momentum:模型均值和方差更新时的参数,见上述公式。
  • affine:代表gamma,beta是否可学。如果设为 True,代表两个参数是通过学习得到的;如果设为 False,代表两个参数是固定值,默认情况下,gamma是1,beta是0。
  • track_running_stats:代表训练阶段是否更新模型存储的均值和方差,即测试阶段的均值与方差的计算方法采用第一种方法还是第二种方法。如果设为 True,则代表训练阶段每次迭代都会更新模型存储的均值和方差(计算全局数据),测试过程中利用存储的均值和方差对各个通道进行标准化处理;如果设为 False,则模型不会存储均值和方差,训练过程中也不会更新均值和方差的数据,测试过程中只计算当前输入图像的均值和方差数据(局部数据)。具体区别见代码案例。

注意:

  • 训练阶段的标准化过程中,均值和方差来源途径只有一种方式,即利用当前输入的数据进行计算。
  • 测试阶段的标准化过程中,均值和方差来源途径有两种方式,一是来源于全局的数据,即模型本身存储一组均值和方差数据,在训练过程中,不断更新它们,使其具有 描述全局数据的统计特性;二是来源于当前的输入数据,即和训练阶段计算方法一样,但这样会在测试过程中带来 统计特性偏移的弊端,一般 track_running_stats设置为 True,即采用第一种来源途径。
  • 换句话说,就是训练阶段和测试阶段所承载的任务不同,训练阶段主要是通过已知的数据去优化模型,而测试阶段主要是利用已知的模型去预测未知的数据。

用途:

  • 训练过程中遇到收敛速度很慢的问题时,可以通过引入BN层来加快网络模型的收敛速度
  • 遇到梯度消失或者梯度爆炸的问题时,可以考虑引入BN层来解决
  • 一般情况下,还可以通过引入BN层来加快网络的训练速度

一般用法

import torch
from torch import nn

img=torch.rand(2,2,2,3)
bn=nn.BatchNorm2d(2)
img_2=bn(img)
print(img)
print(img_2)

tensor([[[[0.5330, 0.7753, 0.6192],
          [0.9190, 0.1657, 0.5841]],

         [[0.7766, 0.7864, 0.2004],
          [0.9379, 0.3253, 0.1964]]],

        [[[0.7448, 0.9222, 0.1860],
          [0.3829, 0.8812, 0.2508]],

         [[0.0130, 0.0405, 0.2205],
          [0.8997, 0.5143, 0.9414]]]])

tensor([[[[-0.1764,  0.7257,  0.1446],
          [ 1.2605, -1.5434,  0.0140]],

         [[ 0.8332,  0.8615, -0.8287],
          [ 1.2987, -0.4685, -0.8403]]],

        [[[ 0.6121,  1.2726, -1.4678],
          [-0.7350,  1.1199, -1.2269]],

         [[-1.3693, -1.2899, -0.7707],
          [ 1.1883,  0.0769,  1.3088]]]], grad_fn=<NativeBatchNormBackward>)

标准化过程是以通道为维度计算的,即所有batch下,相同通道(channel)下的数据合并到一块,做标准化处理。若有C个通道,无论batch是多少,都会有C次标准化。

import torch
from torch import nn
img=torch.rand(2,2,2,3)

a=torch.cat((img[0,0,:,:],img[1,0,:,:]),dim=0)

b=a.numpy()
import numpy as np
mean=np.mean(b)
std=np.std(b)+1e-5

img_2=(b-mean)/std
bn=nn.BatchNorm2d(2)

img_3=bn(img)
print(img_2)
print(img_3)

手动标准化得到的数据,前两行代表第一个batch下第一个通道标准化后的数据,与利用BatchNorm2d的前两行数据相等;后两行代表第二个batch下第一个通道标准化后的数据,与利用BatchNorm2d的前五六行数据相等。


[[-0.8814389  -1.3535967   0.05035681]
 [-0.5180839  -1.396645    1.8198812 ]
 [ 0.9151892   0.9469903  -0.7903797 ]
 [ 0.35690263  1.135288   -0.28446582]]

tensor([[[[-0.8814, -1.3535,  0.0504],
          [-0.5181, -1.3966,  1.8198]],

         [[-1.5779, -0.5996, -1.0233],
          [-0.3919, -0.6692,  0.6693]]],

        [[[ 0.9151,  0.9469, -0.7903],
          [ 0.3569,  1.1352, -0.2844]],

         [[ 0.5829, -0.7664,  1.1329],
          [ 1.4469, -0.4100,  1.6063]]]], grad_fn=<NativeBatchNormBackward>)

训练过程

import torch
from torch import nn
img=torch.rand(2,2,2,3)
bn_t=nn.BatchNorm2d(2,track_running_stats=True)
bn_f=nn.BatchNorm2d(2,track_running_stats=False)

print('bn_t,mean:',bn_t.running_mean,'var:',bn_t.running_var)
print('bn_f,mean:',bn_f.running_mean,'var:',bn_f.running_var)

bn_t.train()
bn_f.train()
img_t=bn_t(img)
img_f=bn_f(img)
print(img_t)
print(img_f)
print('一次迭代更新bn_t,mean:',bn_t.running_mean,'var:',bn_t.running_var)
img_t=bn_t(img)
print('两次迭代更新bn_t,mean:',bn_t.running_mean,'var:',bn_t.running_var)

bn_t,mean: tensor([0., 0.]) var: tensor([1., 1.])

bn_f,mean: None var: None

tensor([[[[-1.0599,  0.9532, -0.2647],
          [ 0.8146,  0.2971, -1.7099]],

         [[ 1.0554,  0.9239,  1.9331],
          [ 0.0334, -1.3058, -0.0804]]],

        [[[ 1.0146,  0.7528, -0.1986],
          [ 1.3564, -1.6232, -0.3325]],

         [[-1.6591, -0.7690, -0.3045],
          [ 0.7691,  0.1344, -0.7306]]]], grad_fn=<NativeBatchNormBackward>)
tensor([[[[-1.0599,  0.9532, -0.2647],
          [ 0.8146,  0.2971, -1.7099]],

         [[ 1.0554,  0.9239,  1.9331],
          [ 0.0334, -1.3058, -0.0804]]],

        [[[ 1.0146,  0.7528, -0.1986],
          [ 1.3564, -1.6232, -0.3325]],

         [[-1.6591, -0.7690, -0.3045],
          [ 0.7691,  0.1344, -0.7306]]]], grad_fn=<NativeBatchNormBackward>)

一次迭代更新bn_t,mean: tensor([0.0562, 0.0586]) var: tensor([0.9092, 0.9043])
两次迭代更新bn_t,mean: tensor([0.1068, 0.1114]) var: tensor([0.8275, 0.8183])

测试过程

import torch
from torch import nn

bn_t=nn.BatchNorm2d(2,track_running_stats=True)
bn_f=nn.BatchNorm2d(2,track_running_stats=False)

print('bn_t,mean:',bn_t.running_mean,'var:',bn_t.running_var)

bn_t.eval()
bn_f.eval()
img_t=bn_t(img)
img_f=bn_f(img)
print(img)
print(img_t)
print(img_f)
bn_t.train()
img_t=bn_t(img)
bn_t.eval()
img_t=bn_t(img)
print('更新后bn_t,mean:',bn_t.running_mean,'var:',bn_t.running_var)
print(img_t)

bn_t,mean: tensor([0., 0.]) var: tensor([1., 1.])

tensor([[[[0.2542, 0.8395, 0.4854],
          [0.7992, 0.6488, 0.0652]],

         [[0.7970, 0.7707, 0.9722],
          [0.5929, 0.3256, 0.5702]]],

        [[[0.8574, 0.7813, 0.5046],
          [0.9568, 0.0904, 0.4657]],

         [[0.2550, 0.4327, 0.5255],
          [0.7398, 0.6131, 0.4404]]]])

tensor([[[[0.2542, 0.8395, 0.4854],
          [0.7992, 0.6488, 0.0652]],

         [[0.7970, 0.7707, 0.9722],
          [0.5929, 0.3256, 0.5702]]],

        [[[0.8574, 0.7813, 0.5046],
          [0.9568, 0.0904, 0.4657]],

         [[0.2550, 0.4327, 0.5255],
          [0.7398, 0.6131, 0.4404]]]], grad_fn=<NativeBatchNormBackward>)

tensor([[[[-1.0599,  0.9532, -0.2647],
          [ 0.8146,  0.2971, -1.7099]],

         [[ 1.0554,  0.9239,  1.9331],
          [ 0.0334, -1.3058, -0.0804]]],

        [[[ 1.0146,  0.7528, -0.1986],
          [ 1.3564, -1.6232, -0.3325]],

         [[-1.6591, -0.7690, -0.3045],
          [ 0.7691,  0.1344, -0.7306]]]], grad_fn=<NativeBatchNormBackward>)

更新后bn_t,mean: tensor([0.0562, 0.0586]) var: tensor([0.9092, 0.9043])

tensor([[[[0.2076, 0.8215, 0.4501],
          [0.7792, 0.6214, 0.0094]],

         [[0.7764, 0.7488, 0.9607],
          [0.5619, 0.2807, 0.5380]]],

        [[[0.8402, 0.7604, 0.4702],
          [0.9444, 0.0358, 0.4294]],

         [[0.2065, 0.3934, 0.4909],
          [0.7163, 0.5831, 0.4015]]]], grad_fn=<NativeBatchNormBackward>)

Original: https://blog.csdn.net/qq_50001789/article/details/120507768
Author: 视觉萌新、
Title: nn.BatchNorm2d——批量标准化操作解读

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/623421/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球