Computing the Mean and Standard Deviation of an Image Dataset with PyTorch

I. Implementation

When preprocessing data with PyTorch, standardization is usually done with torchvision.transforms.Normalize(mean, std), where the parameters mean and std are sequences holding the mean and standard deviation of each channel of the image set.
First, mean and std are defined mathematically as follows.
Given a dataset $X_i,\ i\in\{1,2,\cdots,n\}$, its mean is:
$$mean=\frac{\displaystyle\sum_{i=1}^n X_i}{n}\tag{1}$$
The mean of the data is usually written $\overline X$.
The standard deviation of the dataset is:
$$
\begin{aligned}
std &= \sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i-\overline X\right)^2}{n}} \\[2ex]
&= \sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i^2-2X_i\overline X+\overline X^2\right)}{n}} \\[2ex]
&= \sqrt{\frac{\left(\displaystyle\sum_{i=1}^n X_i^2\right)-n\overline X^2}{n}} \\[2ex]
&= \sqrt{\frac{\displaystyle\sum_{i=1}^n X_i^2}{n}-\overline X^2}
\end{aligned}
\tag{2}
$$
The third step uses $\sum_{i=1}^n X_i = n\overline X$, so that $-2\overline X\sum_{i=1}^n X_i + n\overline X^2 = -n\overline X^2$.
The following function computes the mean and standard deviation of each channel of an image dataset:

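import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

batch_size = 64

# ToTensor() scales pixel values to [0, 1], so the statistics are on that scale.
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10', train=True, download=False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

def get_mean_std_value(loader):
    '''
    Compute the per-channel mean and standard deviation of a dataset.
    :param loader: DataLoader yielding (images, labels) batches, images shaped (N, C, H, W)
    :return: (mean, std), each a tensor with one entry per channel
    '''
    data_sum, data_squared_sum, num_batches = 0, 0, 0

    for data, _ in loader:
        # Averaging over batch, height and width leaves one value per channel.
        data_sum += torch.mean(data, dim=[0, 2, 3])
        # Per-channel mean of squares, the first term under the root in equation (2).
        data_squared_sum += torch.mean(data**2, dim=[0, 2, 3])
        num_batches += 1

    # Averaging per-batch means is exact only when every batch has the same size.
    mean = data_sum / num_batches
    # Last form of equation (2): sqrt(mean of squares - square of mean).
    std = (data_squared_sum / num_batches - mean**2)**0.5
    return mean, std

mean, std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean, std))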
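
The function above averages per-batch means, which weights a smaller final batch the same as a full one (50000 CIFAR10 training images do not divide evenly into batches of 64). The sketch below is one possible variant, not part of the original article, that accumulates raw sums and pixel counts instead so every pixel carries equal weight. The name get_mean_std_exact and the synthetic random dataset are assumptions made for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def get_mean_std_exact(loader):
    """Per-channel mean and population std, weighting every pixel equally."""
    channel_sum = 0.0
    channel_squared_sum = 0.0
    num_pixels = 0  # number of pixels seen per channel

    for data, _ in loader:
        # data has shape (batch, channels, height, width).
        data = data.double()  # accumulate in float64 to limit rounding error
        channel_sum = channel_sum + data.sum(dim=[0, 2, 3])
        channel_squared_sum = channel_squared_sum + (data ** 2).sum(dim=[0, 2, 3])
        num_pixels += data.shape[0] * data.shape[2] * data.shape[3]

    mean = channel_sum / num_pixels
    # Last form of equation (2): sqrt(mean of squares - square of mean).
    std = (channel_squared_sum / num_pixels - mean ** 2).sqrt()
    return mean.float(), std.float()


# Synthetic stand-in for an image dataset: 10 RGB images of 4x4 pixels.
torch.manual_seed(0)
images = torch.rand(10, 3, 4, 4)
labels = torch.zeros(10, dtype=torch.long)
# batch_size=4 yields batches of 4, 4 and 2 images, so the last batch is smaller.
loader = DataLoader(TensorDataset(images, labels), batch_size=4)

mean, std = get_mean_std_exact(loader)

# Reference values computed directly from the definitions on the full tensor.
ref_mean = images.mean(dim=[0, 2, 3])
ref_std = ((images - ref_mean.view(1, 3, 1, 1)) ** 2).mean(dim=[0, 2, 3]).sqrt()

print(torch.allclose(mean, ref_mean, atol=1e-6))
print(torch.allclose(std, ref_std, atol=1e-6))
```

On a real dataset the difference from the batch-averaged version is usually small, since only the last batch is affected; passing drop_last=True to the DataLoader is another way to keep all batches the same size, at the cost of ignoring a few images.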

The mean and standard deviation of the CIFAR10 dataset are:

mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])

The mean and standard deviation of the MNIST dataset are:

mean = tensor([0.1307]),std = tensor([0.3081])


Original: https://blog.csdn.net/weixin_43821559/article/details/123459085
Author: 心️升明月
Title: Computing the Mean and Standard Deviation of an Image Dataset with PyTorch
