最近刚学图神经网络,数据集导入折腾了很久,终于开窍了一点。
目前常用的数据导入方法主要有两种:
(1)torchvision自带的导入方式:
这种导入方式使用了torchvision自带的库,打开函数进去看它的说明是这样的:
直接翻译过来意思就是 图片要放在相应类别的文件夹下,文件夹名字就是图片所属的类别。
导入代码如下:
from torchvision import datasets
'''transform可自行定义'''
train_transforms = transforms.Compose(
[transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
train_dataset=datasets.ImageFolder(train_dir,transform=train_transforms)
2.自定义数据导入方式
现实使用过程中经常会遇到图片跟标签是分开放置的情况,如下面两张图所示,图片和label分别放置的,那么torchvision自带的库就不能用了,需要自定义数据读取方式。
首先用os库遍历文件,提取图片的名字和对应的label,保存在CSV文件中(当然完整的程序不保存也可以,这里是为了方便后面用),遍历的方式参考这篇博客。
开始自定义导入数据的类,这部分的格式都是统一的,最开始先写上这几个必须的函数,再往里面填东西:
from torch.utils.data import Dataset
class LoadData(Dataset):
def __init__(self,image_path,transform=None):
def __getitem__(self,index):
def __len__(self):
确定模板以后直接往里面填东西就可以了:
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
class LoadData(Dataset):
def __init__(self,image_path,transform=None):
self.imgs_info=pd.read_csv(image_path)
def __getitem__(self,index):
img_path,label=self.imgs_info['img_path'],self.imgs_info['weather']
img=Image.open(img_path)
img=img.convert('RGB')
if transform is not None:
img=transform(img)
returnimg,label
def __len__(self):
return len(self.imgs_info)
主函数中调用:
from torchvision import transforms
train_csv_path=r'./dataset/train.csv'
train_transforms=transforms.Compose(
[transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
train_dataset=LoadData(train_csv_path,transform=train_transforms)
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=True)
Original: https://blog.csdn.net/weixin_43760440/article/details/123120000
Author: 蒽,开心(∩_∩)
Title: pytorch导入自定义数据集
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/650661/
转载文章受原作者版权保护。转载请注明原作者出处!