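I. Implementation
When preprocessing data in PyTorch, torchvision.transforms.Normalize(mean, std) is commonly used to standardize the data. The parameters mean and std are sequences giving the mean and the standard deviation of each channel of the image dataset.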
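To make the role of the two parameters concrete, the sketch below applies the documented per-channel formula of Normalize, output[c] = (input[c] - mean[c]) / std[c], by hand. The random tensor stands in for a ToTensor() output, and the mean and std values are illustrative placeholders, not statistics of any real dataset.

```python
import torch

# Hypothetical 3-channel "image" in [0, 1] with shape (C, H, W),
# standing in for the output of transforms.ToTensor()
img = torch.rand(3, 4, 4)

# Illustrative per-channel statistics (assumed values, not from a real dataset)
mean = torch.tensor([0.5, 0.4, 0.3])
std = torch.tensor([0.2, 0.25, 0.3])

# Broadcast the (C,) statistics over height and width:
# normalized[c] = (img[c] - mean[c]) / std[c]
normalized = (img - mean[:, None, None]) / std[:, None, None]

print(normalized.shape)
```

Each channel is shifted and scaled independently, which is why mean and std need one entry per channel.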
First, mean and std are defined mathematically as follows:
Suppose we have a dataset $X_i,\ i\in\{1,2,\cdots,n\}$. Its mean is:
$$mean=\frac{\displaystyle\sum_{i=1}^nX_i}{n}\tag{1}$$
The mean of the data is usually denoted by $\overline X$.
The standard deviation of the dataset is:
$$\begin{aligned}
std&=\sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i-\overline X\right)^2}{n}}\\[2ex]
&=\sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i^2-2X_i\overline X+\overline X^2\right)}{n}}\\[2ex]
&=\sqrt{\frac{\left(\displaystyle\sum_{i=1}^nX_i^2\right)-n\overline X^2}{n}}\\[2ex]
&=\sqrt{\frac{\displaystyle\sum_{i=1}^nX_i^2}{n}-\overline X^2}
\end{aligned}\tag{2}$$
The third line follows because $\sum_{i=1}^n X_i = n\overline X$, so the cross term contributes $-2n\overline X^2$ and the constant term contributes $+n\overline X^2$.
The following code computes the mean and standard deviation of each channel of an image dataset:
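import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

batch_size = 64
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10', train=True, download=False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)


def get_mean_std_value(loader):
    '''
    Compute the per-channel mean and standard deviation of a dataset.
    :param loader: DataLoader yielding (data, label) batches, data of shape (N, C, H, W)
    :return: (mean, std), each a tensor of shape (C,)
    '''
    data_sum, data_squared_sum, num_samples = 0, 0, 0
    for data, _ in loader:
        n = data.size(0)
        # Weight each batch by its size so that a smaller final batch
        # (50000 is not a multiple of 64) is not over-counted
        data_sum += torch.mean(data, dim=[0, 2, 3]) * n
        data_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * n
        num_samples += n
    mean = data_sum / num_samples
    # std = sqrt(E[X^2] - (E[X])^2), the last form of Eq. (2)
    std = (data_squared_sum / num_samples - mean ** 2) ** 0.5
    return mean, std


mean, std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean, std))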
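As a sanity check, the running-sum approach can be compared against a direct computation on synthetic data. The sketch below is an illustration under stated assumptions, not the author's code: `channel_stats` is a hypothetical helper name, the random tensor stands in for a real image dataset, and the batch size is chosen on purpose so that the last batch is smaller than the others. It accumulates per-pixel sums and pixel counts rather than averaging per-batch means, so uneven batch sizes do not skew the result.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_stats(loader):
    """Per-channel mean/std from running sums (hypothetical helper name)."""
    total, total_sq, count = 0.0, 0.0, 0
    for data, _ in loader:
        # Sum over batch, height and width; keep the channel dimension
        total = total + data.sum(dim=[0, 2, 3])
        total_sq = total_sq + (data ** 2).sum(dim=[0, 2, 3])
        # Number of values accumulated per channel in this batch
        count += data.size(0) * data.size(2) * data.size(3)
    mean = total / count
    # Last form of Eq. (2): sqrt(E[X^2] - (E[X])^2)
    std = (total_sq / count - mean ** 2).sqrt()
    return mean, std


torch.manual_seed(0)
# Synthetic stand-in for an image dataset: 100 samples, 3 channels, 8x8 pixels
images = torch.rand(100, 3, 8, 8, dtype=torch.float64)
labels = torch.zeros(100, dtype=torch.long)
# 100 is not a multiple of 32, so the last batch holds only 4 samples
loader = DataLoader(TensorDataset(images, labels), batch_size=32)

mean, std = channel_stats(loader)

# Direct computation over the whole tensor: Eq. (1) and the first form of Eq. (2)
ref_mean = images.mean(dim=[0, 2, 3])
ref_std = ((images - ref_mean[None, :, None, None]) ** 2).mean(dim=[0, 2, 3]).sqrt()

print(torch.allclose(mean, ref_mean), torch.allclose(std, ref_std))
```

Both comparisons should agree to within floating-point tolerance. Averaging unweighted per-batch means instead would give the 4-sample final batch the same weight as a full batch of 32.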
The mean and standard deviation of the CIFAR10 dataset are:
mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])
The mean and standard deviation of the MNIST dataset are:
mean = tensor([0.1307]),std = tensor([0.3081])
Original: https://blog.csdn.net/weixin_43821559/article/details/123459085
Author: 心️升明月
Title: Computing the Mean and Standard Deviation of an Image Dataset with PyTorch