6. MNIST Handwritten Digit Image Dataset

MNIST dataset (http://yann.lecun.com/exdb/mnist/)

A dataset of handwritten digit images with 60,000 training samples and 10,000 test samples. Each sample is a 28×28-pixel grayscale image.
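As a quick sanity check on these numbers, the sketch below works out how many pixels each split contains and how large the uncompressed raw files should be. It assumes the IDX layout described on the MNIST page: a 16-byte header for image files, an 8-byte header for label files, then one unsigned byte per pixel or label. The constant and function names are hypothetical helpers for illustration.

```python
# Sketch: sizes implied by the MNIST description, assuming the IDX layout
# from the MNIST page (16-byte header for image files, 8-byte header for
# label files, one unsigned byte per pixel or label).
ROWS, COLS = 28, 28
PIXELS_PER_IMAGE = ROWS * COLS  # 784 pixels per image

IMAGE_HEADER_BYTES = 16  # magic number, count, rows, cols (4 x 32-bit ints)
LABEL_HEADER_BYTES = 8   # magic number, count (2 x 32-bit ints)


def image_file_size(num_images: int) -> int:
    """Uncompressed size in bytes of an IDX image file."""
    return IMAGE_HEADER_BYTES + num_images * PIXELS_PER_IMAGE


def label_file_size(num_labels: int) -> int:
    """Uncompressed size in bytes of an IDX label file."""
    return LABEL_HEADER_BYTES + num_labels


for name, n in [("train", 60000), ("test", 10000)]:
    print(f"{name}: {n * PIXELS_PER_IMAGE:,} pixels, "
          f"images file {image_file_size(n):,} B, "
          f"labels file {label_file_size(n):,} B")
# train: 47,040,000 pixels, images file 47,040,016 B, labels file 60,008 B
# test: 7,840,000 pixels, images file 7,840,016 B, labels file 10,008 B
```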


The dataset mainly consists of four compressed files:

  1. train-images-idx3-ubyte.gz: raw pixel data of the training images
  2. train-labels-idx1-ubyte.gz: labels of the training images
  3. t10k-images-idx3-ubyte.gz: raw pixel data of the test images
  4. t10k-labels-idx1-ubyte.gz: labels of the test images

Step 1: Download the dataset

MNIST (Torchvision 0.12 documentation): https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST
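from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

# Training set (60,000 samples). ToTensor() turns each PIL image into a
# float tensor of shape [1, 28, 28] with values scaled to [0, 1].
trainData = MNIST(root = "./",
                  train = True,
                  transform=ToTensor(),
                  download = True)
# Test set (10,000 samples), same preprocessing
testData = MNIST(root = "./",
                 train = False,
                 transform=ToTensor(),
                 download = True)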

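With download=True, torchvision creates an MNIST folder under root (here the current directory) and downloads the four files into ./MNIST/raw; if they are already present, the download is skipped. With download=False, the four files are loaded directly from ./MNIST/raw, and an error is raised if they are missing.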
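The files in ./MNIST/raw use the IDX format described on the MNIST page: two zero bytes, a type code (0x08 for unsigned bytes), the number of dimensions, one big-endian 32-bit size per dimension, then the raw values. torchvision parses them for you; the sketch below is an independent, standard-library-only reader (the name read_idx is hypothetical, and this is not torchvision's implementation). To stay self-contained it is exercised on a tiny synthetic file built in memory rather than on the real data.

```python
import gzip
import struct


def read_idx(raw: bytes):
    """Parse IDX bytes (gzip-compressed or not) into (dims, flat list of values).

    Assumes unsigned-byte data (type code 0x08), which all four MNIST files use.
    """
    if raw[:2] == b"\x1f\x8b":            # gzip magic bytes: decompress first
        raw = gzip.decompress(raw)
    zero, dtype, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    dims = struct.unpack(f">{ndim}I", raw[4:4 + 4 * ndim])  # big-endian sizes
    data = raw[4 + 4 * ndim:]
    count = 1
    for d in dims:
        count *= d
    if len(data) != count:
        raise ValueError("payload size does not match header")
    return dims, list(data)


# Synthetic "image file": 2 images of 2x3 pixels (header 00 00 08 03, like MNIST images).
header = struct.pack(">HBB3I", 0, 0x08, 3, 2, 2, 3)
pixels = bytes(range(12))
dims, values = read_idx(gzip.compress(header + pixels))
print(dims)        # (2, 2, 3)
print(values[:6])  # first image, row-major: [0, 1, 2, 3, 4, 5]
```

To try it on the real data, one could pass open("./MNIST/raw/train-labels-idx1-ubyte.gz", "rb").read(); following the format, the label file should report a single dimension of 60000.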

Step 2: Load the dataset

torch.utils.data (PyTorch 1.11.0 documentation): https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

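from torch.utils.data import DataLoader
batch_size = 64
# Shuffle the training data so each epoch sees the samples in a new order
trainData_loader = DataLoader(dataset = trainData,
                              batch_size = batch_size,
                              shuffle = True)
# Sample order does not matter for evaluation, so the test data is not shuffled
testData_loader = DataLoader(dataset = testData,
                             batch_size = batch_size,
                             shuffle = False)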

batch_size = 64 means that each iteration of the loader returns a batch of 64 samples.
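Since 60,000 is not a multiple of 64, the final training batch is smaller than the rest. The sketch below computes how many batches one pass over each split yields and the size of the last batch; batches_per_epoch is a hypothetical helper. drop_last is a real DataLoader parameter (default False) that discards the incomplete final batch, and the helper models both settings.

```python
def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches per epoch and the size of the last batch."""
    if drop_last:
        # The incomplete final batch is discarded, so every batch is full
        return num_samples // batch_size, batch_size
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceiling division
    last = num_samples - (num_batches - 1) * batch_size
    return num_batches, last


print(batches_per_epoch(60000, 64))                  # (938, 32)
print(batches_per_epoch(10000, 64))                  # (157, 16)
print(batches_per_epoch(60000, 64, drop_last=True))  # (937, 64)
```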

Step 3: Understand the sample data

3.1 Inspecting the data

examples = enumerate(trainData_loader)   # yields (batch index, (data, labels)) pairs
idx, (data,labels) = next(examples)      # take the first batch
print(data.shape)
print(labels)

Output:

torch.Size([64, 1, 28, 28])
tensor([3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
        7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
        4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7])
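enumerate wraps the loader so that each item is a pair (batch index, batch), which is why the code unpacks idx, (data, labels); for the first batch idx is 0. The standard-library sketch below reproduces the same unpacking with a hypothetical stand-in loader (a plain list of (data, labels) tuples) and shows that next(iter(...)) returns the batch without the index.

```python
# Hypothetical stand-in for a DataLoader: any iterable of (data, labels) pairs.
fake_loader = [("batch0_data", [3, 9, 0]), ("batch1_data", [7, 4, 5])]

examples = enumerate(fake_loader)
idx, (data, labels) = next(examples)    # first item of enumerate is (0, first batch)
print(idx, data, labels)                # 0 batch0_data [3, 9, 0]

# Equivalent without the index:
data2, labels2 = next(iter(fake_loader))
print(data2 == data, labels2 == labels)  # True True

# Calling next() again on the same enumerate object advances to the next batch.
idx, (data, labels) = next(examples)
print(idx, labels)                      # 1 [7, 4, 5]
```

With a real DataLoader and shuffle=True, every new iterator draws a fresh random order, so the first batch (and therefore the printed labels) changes from run to run.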

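data.shape [64, 1, 28, 28]: 64 samples, each with one channel, and each channel holds 28×28 pixels.

labels: the labels of these 64 sample images.

Note: a grayscale image usually has a single channel; a color image has three channels, corresponding to the RGB primaries (red, green, blue).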
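The shape follows the [N, C, H, W] layout (batch size, channels, height, width). As a small worked example, the sketch below counts the values in one batch for grayscale MNIST and for hypothetical three-channel RGB images of the same size, and computes where pixel (n, c, h, w) sits in memory for a contiguous, row-major tensor. It only does arithmetic on shape tuples and does not use PyTorch; the helper names are hypothetical.

```python
from math import prod


def num_elements(shape):
    """Number of values in a tensor with the given shape."""
    return prod(shape)


def flat_index(index, shape):
    """Offset of element (n, c, h, w) in a contiguous row-major [N, C, H, W] tensor."""
    n, c, h, w = index
    _, C, H, W = shape
    return ((n * C + c) * H + h) * W + w


gray_batch = (64, 1, 28, 28)   # N, C, H, W of the MNIST batch above
rgb_batch = (64, 3, 28, 28)    # hypothetical RGB images of the same size

print(num_elements(gray_batch))              # 50176
print(num_elements(rgb_batch))               # 150528
print(flat_index((1, 0, 0, 0), gray_batch))  # 784: the second image starts after 28*28 values
```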

3.2 Displaying the data

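import matplotlib.pyplot as plt
data = data.squeeze()   # remove the channel dimension: [64,1,28,28] -> [64,28,28]

fig = plt.figure(dpi=300)
for i in range(8):          # row of the 8x8 grid
    for j in range(8):      # column of the 8x8 grid
        plt.subplot(8,8, i*8+j+1 )                # subplot indices start at 1
        plt.imshow(data[i*8+j], cmap="gray")      # sample i*8+j, drawn in grayscale
        plt.xticks([])
        plt.yticks([])
plt.show()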
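squeeze() without an argument removes every dimension of size 1. With batch_size = 64 only the channel dimension goes away, but a batch containing a single image would lose its batch dimension as well; squeeze(1) removes only the channel dimension. The sketch below models both rules on plain shape tuples. The helper names are hypothetical and this mimics the behaviour of PyTorch's squeeze rather than calling it.

```python
def squeeze_all(shape):
    """Mimic tensor.squeeze(): drop every dimension of size 1."""
    return tuple(d for d in shape if d != 1)


def squeeze_dim(shape, dim):
    """Mimic tensor.squeeze(dim): drop dimension `dim` only if its size is 1."""
    dim = dim % len(shape)              # allow negative indices such as -1
    if shape[dim] != 1:
        return tuple(shape)             # size is not 1: shape is unchanged
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeeze_all((64, 1, 28, 28)))    # (64, 28, 28)
print(squeeze_all((1, 1, 28, 28)))     # (28, 28): the batch dimension is lost too
print(squeeze_dim((1, 1, 28, 28), 1))  # (1, 28, 28): only the channel dimension removed
```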


The generated images correspond one-to-one, row by row and left to right, to the labels printed in step 3.1.
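Because subplot(8, 8, i*8+j+1) fills the grid row by row, sample k = i*8+j is drawn in row k // 8 and column k % 8. The sketch below arranges the 64 labels from the example output in step 3.1 the same way, so each printed row can be read against the matching row of the figure. The label values are copied from that output; with shuffle=True a new run yields a different batch. grid_position is a hypothetical helper.

```python
# Labels of the first batch, copied from the example output in step 3.1.
labels = [3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
          7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
          4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7]

ROWS = COLS = 8
# Row i of the figure holds samples i*8 .. i*8+7
grid = [labels[i * COLS:(i + 1) * COLS] for i in range(ROWS)]


def grid_position(k):
    """0-based (row, column) of sample k in the 8x8 figure."""
    return divmod(k, COLS)


for row in grid:
    print(" ".join(str(d) for d in row))   # first line printed: 3 9 0 1 2 1 5 1
```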

Appendix: complete code

from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

trainData = MNIST(root = "./",
                  train = True,
                  transform=ToTensor(),
                  download = True)
testData = MNIST(root = "./",
                  train = False,
                  transform=ToTensor(),
                  download = True)

batch_size = 64
trainData_loader = DataLoader(dataset = trainData,
                              batch_size = batch_size,
                              shuffle = True)

testData_loader = DataLoader(dataset = testData,
                             batch_size = batch_size,
                             shuffle = False)   # no need to shuffle data used only for evaluation

examples = enumerate(trainData_loader)
idx, (data,labels) = next(examples)

fig = plt.figure()
for i in range(8):
    for j in range(8):
        plt.subplot(8,8, i*8+j+1 )
        plt.imshow(data.squeeze()[i*8+j], cmap="gray")   # draw each digit in grayscale
        plt.xticks([])
        plt.yticks([])
plt.show()

Original: https://blog.csdn.net/Austin6035/article/details/124542318
Author: Austin6035
Title: 6. MNIST Handwritten Digit Image Dataset (6. 手写数字图片数据集MNIST)

Original articles are protected by copyright. When reposting, please credit the source: https://www.johngo689.com/618442/

