pytorch dataloader详解

构建自己的dataloader是模型训练的第一步,本篇文章介绍下pytorch与dataloader以及与其相关的类的用法。

DataLoader类中有一个必填参数为 dataset,因此在构建自己的dataloader前,先要定义好自己的 Dataset类。这里先大致介绍下这两个类的作用:

  • Dataset:真正的”数据集”,它的作用是: 只要告诉它数据在哪里(初始化),就可以像使用iterator一样去拿到数据,继承该类后,需要重载 __len__()以及 __getitem__
  • DataLoader:数据加载器,设置一些参数后,可以按照一定规则加载数据,比如设置batch_size后,每次加载一个batch_siza的数据。它像一个生成器一样工作。

有小伙伴可能会疑惑,自己写一个加载数据的工具似乎也没有多”困难”,为何大费周章要继承pytorch中类,按照它的规则加载数据呢?关于这点可以参考这里:pytorch dataloader,总结一下就是:

  • 当数据量很大的时候,单进程加载数据很慢
  • 一次全加载过来,会占用很大的内存空间(因此dataloader是一个生成器,惰性加载)
  • 在进行训练前,往往需要一些数据预处理或数据增强等操作,pytorch的dataloader已经封装好了,避免了重复造轮子

一、使用方法

两步走:

  1. 定义自己的Dataset类,具体要做的事:
  2. 告诉它去哪儿读数据,并将数据resize为统一的shape(可以思考下为什么呢)
  3. 重写 __len__()以及 __getitem__,其中 __getitem__中要确定自己想要哪些数据,然后将其return出来。
  4. 将自己的Dataset实例传到Dataloder中并设置想要的参数,构建自己的dataloader

下面简单加载一个目录下的图片以及label:

import os
import numpy as np

from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import cv2

img_dir = '/home/jyz/Downloads/classify_example/val/骏马/'
anno_file = '/home/jyz/Downloads/classify_example/val/label.txt'

class MyDataset(Dataset):
    def __init__(self, img_dir, anno_file, imgsz=(640, 640)):
        self.img_dir = img_dir
        self.anno_file = anno_file
        self.imgsz = imgsz
        self.img_namelst = os.listdir(self.img_dir)

    def __len__(self):
        return len(self.img_namelst)

    def __getitem__(self, idx):
        with open(self.anno_file, 'r') as f:
            label = f.readline().strip()
        img = cv2.imread(os.path.join(img_dir, self.img_namelst[idx]))
        img = cv2.resize(img, self.imgsz)
        return img, label

dataset = MyDataset(img_dir, anno_file)
dataloader = DataLoader(dataset=dataset, batch_size=2)

for img_batch, label_batch in dataloader:
    img_batch = img_batch.numpy()
    print(img_batch.shape)

    if img_batch.shape[0] == 2:
        img = np.hstack((img_batch[0], img_batch[1]))
    else:
        img = np.squeeze(img_batch, axis=0)
    print(img.shape)
    cv2.imshow(label_batch[0], img)
    cv2.waitKey(0)

上面是一次加载两张图片,效果如下:

pytorch dataloader详解

二、结论

  1. 使用pytorch的dataloader,需要先构建自己的Dataset
  2. 构建自己的Dataset,需要重载 __len__()以及 __getitem__
  3. 数据地址:example data,提取码: a1ds

Original: https://blog.csdn.net/qq_34062683/article/details/126528869
Author: 惊瑟
Title: pytorch dataloader详解

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/670510/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球