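MNIST dataset (http://yann.lecun.com/exdb/mnist/)
MNIST is a dataset of handwritten digit images with 60,000 training samples and 10,000 test samples. Each sample is a 28x28-pixel image.
It is distributed mainly as four compressed files:
- train-images-idx3-ubyte.gz: raw image data for the training samples
- train-labels-idx1-ubyte.gz: labels for the training images
- t10k-images-idx3-ubyte.gz: raw image data for the test samples
- t10k-labels-idx1-ubyte.gz: labels for the test images
Step 1: Downloading the dataset
- MNIST (Torchvision 0.12 documentation): https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST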
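As background on the raw files this step fetches: the `idx3` and `idx1` parts of the file names refer to the IDX format described on the MNIST page linked at the top, in which a short big-endian header records the element type and the size of each dimension. The sketch below parses such a header from bytes. `parse_idx_header` is a hypothetical helper written for illustration (it is not part of torchvision), and the sample header is built by hand rather than read from a real file.

```python
import struct

def parse_idx_header(raw: bytes):
    """Return (type_code, dims) from the first bytes of an IDX file."""
    # First 4 bytes: two zero bytes, a type code, and the number of dimensions.
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", raw[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX header")
    # Each dimension size is a 4-byte big-endian unsigned integer.
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    return type_code, dims

# Hand-built header shaped like the training image file:
# type code 0x08 (unsigned byte), 3 dimensions, 60000 x 28 x 28.
sample = struct.pack(">BBBBIII", 0, 0, 0x08, 3, 60000, 28, 28)
print(parse_idx_header(sample))
```

With torchvision this parsing is done internally, so the helper is only useful for understanding what the downloaded files contain.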
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
trainData = MNIST(root="./",
                  train=True,
                  transform=ToTensor(),
                  download=True)
testData = MNIST(root="./",
                 train=False,
                 transform=ToTensor(),
                 download=True)
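If download is True, an MNIST folder is created under the current directory (the root passed above) and the four MNIST files are placed in ./MNIST/raw; if the files are already there, they are not downloaded again. If download is False, the four files are loaded directly from ./MNIST/raw, so they must already exist.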
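To see which of the two cases applies before constructing the dataset, the raw folder can be inspected directly. This is a minimal sketch using only the standard library; `missing_mnist_files` is a hypothetical helper, and accepting either the `.gz` archive or an extracted copy of each file is an assumption about what may be present in the folder, not a description of torchvision's own check.

```python
from pathlib import Path

# Base names of the four MNIST files (without the .gz suffix).
MNIST_FILES = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]

def missing_mnist_files(root="./"):
    """List the MNIST files that are not present under <root>/MNIST/raw."""
    raw_dir = Path(root) / "MNIST" / "raw"
    missing = []
    for name in MNIST_FILES:
        # Accept either the downloaded .gz archive or an extracted copy.
        has_plain = (raw_dir / name).exists()
        has_gz = (raw_dir / (name + ".gz")).exists()
        if not (has_plain or has_gz):
            missing.append(name)
    return missing

print(missing_mnist_files("./"))
```

An empty list suggests the files are already in place; a non-empty list suggests download = True is needed.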
Step 2: Loading the dataset
from torch.utils.data import DataLoader
batch_size = 64
trainData_loader = DataLoader(dataset=trainData,
                              batch_size=batch_size,
                              shuffle=True)
testData_loader = DataLoader(dataset=testData,
                             batch_size=batch_size,
                             shuffle=True)
batch_size = 64 means that each batch drawn from the loader contains 64 samples.
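The batch size also determines how many batches one pass over each set produces. The arithmetic below is a small sketch that assumes the DataLoader keeps its default drop_last=False, so a final, smaller batch is kept; `batch_layout` is a hypothetical helper name.

```python
import math

def batch_layout(num_samples, batch_size):
    """Return (number of batches, size of the last batch) for one epoch."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch

print(batch_layout(60000, 64))  # training set: (938, 32)
print(batch_layout(10000, 64))  # test set: (157, 16)
```

So the last training batch holds 32 samples rather than 64, which matters for any code that hard-codes the batch dimension.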
Step 3: Understanding the sample data
3.1 Inspecting the data
examples = enumerate(trainData_loader)
idx, (data,labels) = next(examples)
print(data.shape)
print(labels)
torch.Size([64, 1, 28, 28])
tensor([3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7])
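data.shape [64,1,28,28]: 64 samples, each with one channel, and each channel holding 28x28 pixels.
labels: the labels for these 64 sample images.
Note: a grayscale image normally has a single channel; a color image has three channels, one for each of the RGB primaries.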
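The four dimensions can be read off by indexing. The sketch below uses a random tensor as a stand-in for one batch (an assumption made so the example runs without downloading anything); only the shapes are meaningful here, not the values.

```python
import torch

# Stand-in for one batch from the loader: random values, same shape as in the text.
data = torch.rand(64, 1, 28, 28)

sample = data[0]      # one sample, channel dimension kept
image = data[0, 0]    # the single channel of that sample as a 2-D grid

print(data.size(0))   # 64, the number of samples in the batch
print(sample.shape)   # torch.Size([1, 28, 28])
print(image.shape)    # torch.Size([28, 28])
```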
3.2 Displaying the data
import matplotlib.pyplot as plt
data = data.squeeze()  # drop the channel dimension: [64,1,28,28] -> [64,28,28]
fig = plt.figure(dpi=300)
# Draw the 64 images of the batch on an 8x8 grid, row by row.
for i in range(8):
    for j in range(8):
        plt.subplot(8, 8, i*8 + j + 1)
        plt.imshow(data[i*8 + j])
        plt.xticks([])
        plt.yticks([])
plt.show()
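One detail about the data.squeeze() call used for plotting: without an argument, squeeze removes every dimension of size 1, not only the channel. That is harmless for a batch of 64, but a batch containing a single sample would lose its batch dimension as well. The sketch below illustrates this with zero-filled tensors; the one-sample batch is a hypothetical case, not something the loader above produces.

```python
import torch

batch = torch.zeros(64, 1, 28, 28)
single = torch.zeros(1, 1, 28, 28)   # hypothetical batch with one sample

print(batch.squeeze().shape)    # torch.Size([64, 28, 28])
print(single.squeeze().shape)   # torch.Size([28, 28]), batch dimension is lost too
print(single.squeeze(1).shape)  # torch.Size([1, 28, 28]), only the channel is removed
```

Passing the dimension explicitly, as in squeeze(1), is the more defensive choice when batch sizes can vary.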
The images in the resulting figure correspond one-to-one, in row order, with the labels printed in step 3.1.
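To make that comparison easier, the labels can be printed in the same 8x8 layout as the figure. This is a small sketch in plain Python; the label values are copied from the output shown in step 3.1, and `to_grid` is a hypothetical helper.

```python
# Labels copied from the batch printed in step 3.1.
labels = [3, 9, 0, 1, 2, 1, 5, 1, 8, 1, 9, 8, 3, 4, 3, 0, 9, 8, 3, 9, 4, 9, 6, 9,
          7, 4, 5, 3, 0, 6, 1, 4, 0, 6, 1, 8, 5, 0, 5, 8, 0, 7, 1, 8, 1, 4, 6, 9,
          4, 6, 7, 4, 2, 5, 4, 7, 1, 2, 6, 1, 9, 0, 0, 7]

def to_grid(values, cols=8):
    """Split a flat list into rows of `cols` items, matching subplot order."""
    return [values[i:i + cols] for i in range(0, len(values), cols)]

for row in to_grid(labels):
    print(" ".join(str(v) for v in row))
```

Because the loader shuffles, a fresh run yields a different batch; the grid only matches a figure drawn from the same data and labels pair.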
Appendix: Complete code
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
trainData = MNIST(root="./",
                  train=True,
                  transform=ToTensor(),
                  download=True)
testData = MNIST(root="./",
                 train=False,
                 transform=ToTensor(),
                 download=True)
batch_size = 64
trainData_loader = DataLoader(dataset=trainData,
                              batch_size=batch_size,
                              shuffle=True)
testData_loader = DataLoader(dataset=testData,
                             batch_size=batch_size,
                             shuffle=True)
examples = enumerate(trainData_loader)
idx, (data,labels) = next(examples)
fig = plt.figure()
# Draw the 64 images of the batch on an 8x8 grid, row by row.
for i in range(8):
    for j in range(8):
        plt.subplot(8, 8, i*8 + j + 1)
        plt.imshow(data.squeeze()[i*8 + j])
        plt.xticks([])
        plt.yticks([])
plt.show()
Original: https://blog.csdn.net/Austin6035/article/details/124542318
Author: Austin6035
Title: 6. The MNIST Handwritten Digit Image Dataset (手写数字图片数据集MNIST)