1.下载mnist
使用torchvision.datasets,其中含有一些常见的MNIST等数据集,使用方式:
train_data=torchvision.datasets.MNIST(
root='MNIST',
train=True,
transform=torchvision.transforms.ToTensor(),
download=True
)
test_data=torchvision.datasets.MNIST(
root='MNIST',
train=False,
transform=torchvision.transforms.ToTensor(),
download=True
)
root:表示下载位置,下载后,会在该位置中新建一个MNIST文件夹,底下还有一个raw文件夹
train:True下载就会是训练集,False下载就会是测试集
transform:表示转换方式
download:表示是否下载
下载完后会生成四个压缩包,分别代表着train的img和label以及test的img和label
变量train_data和test_data的类型分别为’torchvision.datasets.mnist.MNIST’,如果想用到pytorch中的进行训练,就必须将变量改为torch
2.torch.utils.data.DataLoader( )
用from torch.utils.data import DataLoader进行导入,
train_load=DataLoader(dataset=train_data,batch_size=100,shuffle=True)
test_load=DataLoader(dataset=test_data,batch_size=100,shuffle=True)
随机加载批量大小为l00数据给train_load和test_load,每个变量都由两部分组成,用迭代器将两部分分开
train_x,train_y=next(iter(train_load))
其中train_x为属性值,type(train_x)=torch.Size([100, 1, 28, 28])#100个,channel为1,长宽为28*28,type(train_y)=torch.size([100])
3.opencv显示图片
import cv2
img=torchvision.utils.make_grid(train_x,nrow=10)#将train_x赋给一个宽为10的网格
#因为cv2显示的图片格式是(size,size,channel),但是img格式为(channel,size,size)
img = img.numpy().transpose(1,2,0)
cv2.imshow('img', img)
cv2.waitKey()
Original: https://blog.csdn.net/weixin_45412737/article/details/120561883
Author: 啥也不会的阿兴
Title: pytorch下载加载mnist数据集
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/708854/
转载文章受原作者版权保护。转载请注明原作者出处!