pytorch导入自定义数据集

最近刚学图神经网络,数据集导入折腾了很久,终于开窍了一点。
目前常用的数据导入方法主要有两种:

(1)torchvision自带的导入方式:
这种导入方式使用了torchvision自带的库,打开函数进去看它的说明是这样的:

pytorch导入自定义数据集
直接翻译过来意思就是 图片要放在相应类别的文件夹下,文件夹名字就是图片所属的类别。

导入代码如下:

from torchvision import datasets
'''transform可自行定义'''
train_transforms = transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
train_dataset=datasets.ImageFolder(train_dir,transform=train_transforms)

2.自定义数据导入方式
现实使用过程中经常会遇到图片跟标签是分开放置的情况,如下面两张图所示,图片和label分别放置的,那么torchvision自带的库就不能用了,需要自定义数据读取方式。

pytorch导入自定义数据集
pytorch导入自定义数据集
首先用os库遍历文件,提取图片的名字和对应的label,保存在CSV文件中(当然完整的程序不保存也可以,这里是为了方便后面用),遍历的方式参考这篇博客。

开始自定义导入数据的类,这部分的格式都是统一的,最开始先写上这几个必须的函数,再往里面填东西:

from torch.utils.data import Dataset
class LoadData(Dataset):
    def __init__(self,image_path,transform=None):

    def __getitem__(self,index):

    def __len__(self):

确定模板以后直接往里面填东西就可以了:

from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
class LoadData(Dataset):
    def __init__(self,image_path,transform=None):
        self.imgs_info=pd.read_csv(image_path)
    def __getitem__(self,index):
        img_path,label=self.imgs_info['img_path'],self.imgs_info['weather']
        img=Image.open(img_path)
        img=img.convert('RGB')
        if transform is not None:
            img=transform(img)
        returnimg,label
    def __len__(self):
        return len(self.imgs_info)

主函数中调用:

from torchvision import transforms
train_csv_path=r'./dataset/train.csv'
train_transforms=transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
train_dataset=LoadData(train_csv_path,transform=train_transforms)
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=True)

Original: https://blog.csdn.net/weixin_43760440/article/details/123120000
Author: 蒽,开心(∩_∩)
Title: pytorch导入自定义数据集

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/650661/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球