用开源代码跑自己的数据集:修改dataloader

论文需要跑网络对比实验。那么如何用 Github 上的代码(或者其他开源代码) 跑我们需要它跑的数据集呢

下文将简要介绍与 PyTorch 框架的 dataloader 的相关知识。
首先引用 PyTorch 中文教程中关于 Dataset 抽象类的介绍和 Dataloader 的介绍 :

  • 我们在做深度学习训练时,首先要做的是做一个 数据集类,它可能需要完成 自动打乱数据数据处理批量提供 batchsize 数据等功能。 PyTorch 在 torch.utils.data 中提供了 Dataset抽象类,用于构建一个数据集类,可以对数据批量处理,可以构建一个数据集索引,PyTorch中的以方便批量训练数据时,方便调取。
  • 数据集创建完成后,我们 可以对数据进行索引,但是 还是无法实现批量获取数据,这时,我们就用到 DataLoader 去加载数据做一个数据加载器。

The DataLoader combines the dataset and a sampler, returning an iterable over the dataset.

它指出了 DataLoader 本质上是一个 迭代器,而且同时由 dataset 和 sampler 组成。一语道破,妙不可言。

上文中关于 “数据加载器” 的概念,同时出现 dataloaderDataloader。因为后者是 PyTorch 提供的。通常使用的时候,我们对 Dataloader 的参数赋值,然后将 Dataloader 赋值给一个 自己命名的 dataloader。如下所示:

train_loader = DataLoader(dataset = my_dataset,
                          batch_size = 32,
                          shuffle = True,
                          num_workers = 2)

下面的代码 ex1,我专门把 from torch.utils.data import Datasetfrom torch.utils.data import DataLoader 写出来了,

为什么?

因为在写自己的类 MyDataset 的时候,类 MyDataset 要继承 PyTorch 的抽象类 Dataset。

另外,也用到了 PyTorch 的 DataLoader 来得到参数 batch_size 等赋值后的我们自己的 train_loader 。


from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class MyDataset(Dataset):

    def __init__(self):

    def __getitem__(self,index)
        return

    def __len__(self):
        return

my_dataset = MyDataset()
train_loader = DataLoader(dataset = my_dataset,
                          batch_size = 32,
                          shuffle = True,
                          num_workers = 2)

ex2 这个代码的背景是要解决 分类 问题, 代码数据的来源是 data.csv。当然在 init 函数中, 还可以有其他一些代码,根据实际需求。比如 假设场景是 图像识别,那么在 init 函数中可能会有例如 ex3 的一段代码:


class MyDataset(Dataset):

    def __init__(self):
    xy = np.loadtxt('data.csv',delimiter=',',dtype=np.float32)
    self.len = xy.shape[0]
    self.data_input= torch.from_numpy(xy[:, 0:-1])
    self.label= torch.from_numpy(xy[:,[-1]])

    def __getitem__(self,index)
        return self.data_input[index], self.label[index]

    def __len__(self):
        return self.len

from torchvision import transforms as T
class MyDataset(Dataset):

    def __init__(self):
        上文代码省略
        transform = T.Compose([
            T.Resize(112,112),
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5])
        ])

    def __getitem__(self,index)
        return

    def __len__(self):
        return

for step, data in enumerate(train_loader):
    data_input, label = data

for epoch in range(max_epoch):

    model.train()
    for step, data in enumerate(train_loader):
        data_input, label = data
  • 本文得到了该视频的启发,该视频作者信息如下:
    PyTorch Zero To All Lecture by Sung Kim hunkim+ml@gmail.com at HKUST
    Code: https://github.com/hunkim/PyTorchZero…

Slides: http://bit.ly/PyTorchZeroAll
* PyTorch 中文教程:构建自己的数据集

Original: https://blog.csdn.net/OrdinaryMatthew/article/details/123182727
Author: 培之
Title: 用开源代码跑自己的数据集:修改dataloader

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710800/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球