pytorch中DataLoader详解

功能初体验


import torch
import torch.utils.data as Data

if __name__ == '__main__':
    torch.manual_seed(1)

    BATCH_SIZE = 5

    x = torch.linspace(11, 20, 10)
    y = torch.linspace(1, 10, 10)

    torch_dataset = Data.TensorDataset(x, y)

    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    for epoch in range(3):
        for step,(batch_x,batch_y) in enumerate(loader):

            print('Epoch: ', epoch, '| Step:', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

pytorch中DataLoader详解

参数简介

pytorch中DataLoader详解
上图为源码中dataloader中所有的可选参数。除了第一个dataset参数外,其他均为可选参数。
  • Dataset:处理好的所有数据
  • batch_size:批数量
  • shuffle:打乱数据
  • sampler:采样机制,即从数据集里面取样本的方式(迭代器,每次返回一个样本)
  • batch_sampler:把sampler的采样的样本根据batch_size组织成一个batch返回
  • num_worker:加载数据的线程数
  • collate_fn:把batch_sampler返回的list结构的一个batch的样本打包成一个tensor的结构
  • pin_memory:将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu
  • drop_last:丢弃余数
  • timeout:如果是正数,表明等待从加载一个batch等待的时间,若超出设定的时间还没有加载完,就放弃这个batch,如果是0,表示不设置限制时间。默认为0
  • worker_init_fn:如果不是None ,它将在每个worker子进程上以worker id ([0, num_workers – 1] )作为输入调用,在seeding之后和数据加载之前。
  • generater:如果不是None,这个RNG将被RandomSampler用来生成随机索引,并被multiprocessing用来为worker生成’ base_seed ‘。 (默认值:’ ‘没有’ ‘)
  • prefetch_factor:提前加载多少个batch的数据,可以保证线程不会等待,每个线程都总有至少一个数据在加载。提升显卡利用率。
  • persistent_workers:如果为True,数据加载器将不会在数据集运行完一个Epoch后关闭worker进程。这允许维护worker数据集实例保持激活。(默认值:False),意思是运行完一个Epoch后并不会关闭worker进程,而是保持现有的worker进程继续进行下一个Epoch的数据加载。好处是Epoch之间不必重复关闭启动worker进程,加快训练速度。

; Dataloader参数之间的互斥

值得注意的是,Dataloader的参数之间存在互斥的情况,主要针对自己定义的采样器:

  • sampler:如果自行指定了sampler参数,则shuffle必须保持默认值,即False
  • batch_sampler:如果自行指定了batch_sampler参数,则 batch_size、shuffle、sampler、drop_last 都必须保持默认值
  • 如果没有指定自己是采样器,那么默认的情况下(即sampler和batch_sampler均为None的情况下),dataloader的采样策略是如何的呢:

sampler:

  • shuffle = True:sampler采用 RandomSampler,即随机采样
  • shuffle = Flase:sampler采用 SequentialSampler,即按照顺序采样
  • batch_sampler:采用 BatchSampler,即根据 batch_size 进行batch采样
  • 上面提到的 RandomSampler、SequentialSampler和BatchSampler都是PyTorch自己实现的,且它们都是Sampler的子类。

Sampler

SequentialSampler

SequentialSampler就是一个按照顺序进行采样的采样器,接收一个数据集做参数(实际上任何可迭代对象都可),按照顺序对其进行采样:

from torch.utils.data import SequentialSampler

pseudo_dataset = list(range(10, 20))
for data in SequentialSampler(pseudo_dataset):
    print(data, end=" ")
0 1 2 3 4 5 6 7 8 9

RandomSampler

RandomSampler 即一个随机采样器,返回随机采样的值,第一个参数依然是一个数据集(或可迭代对象)。还有一组参数如下:

  • replacement:bool值,默认是False,设置为True时表示可以采出重复的样本
  • num_samples:只有在replacement设置为True的时候才能设置此参数,表示要采出样本的个数,默认为数据集的总长度。有时候由于replacement置True的原因导致重复数据被采样,导致有些数据被采不到,所以往往会设置一个比较大的值
from torch.utils.data import RandomSampler

pseudo_dataset = list(range(10, 20))

randomSampler1 = RandomSampler(pseudo_dataset)
randomSampler2 = RandomSampler(pseudo_dataset, replacement=True, num_samples=20)

print("for random sampler #1: ")
for data in randomSampler1:
    print(data, end=" ")

print("\n\nfor random sampler #2: ")
for data in randomSampler2:
    print(data, end=" ")

for random sampler
4 5 2 9 3 0 6 8 7 1

for random sampler
4 9 0 6 9 3 1 6 1 8 5 0 2 7 2 8 6 4 0 6

WeightedRandomSampler

WeightedRandomSampler和RandomSampler的参数一致,但是不在传入一个dataset,第一个参数变成了weights,只接收一个一定长度的list作为 weights 参数,表示采样的权重,采样时会根据权重随机从 list(range(len(weights))) 中采样,即WeightedRandomSampler并不需要传入样本集,而是只在一个根据weights长度创建的数组中采样,所以采样的结果可能需要进一步处理才能使用。weights的所有元素之和不需要为1。

from torch.utils.data import WeightedRandomSampler

weights = [1,1,10,10]

weightedRandomSampler = WeightedRandomSampler(weights, replacement=True, num_samples=20)

for data in weightedRandomSampler:
    print(data, end=" ")
2 2 2 3 2 2 3 2 3 3 1 3 2 2 1 3 3 2 3 3

详细使用可参考:WeightedRandomSampler使用案例

BatchSampler

其他Sampler在每次迭代都只返回一个索引,而BatchSampler的作用是对上述这类返回一个索引的采样器进行包装,按照设定的batch size返回 一组具体数据,因其他的参数和上述的有些不同:

  • sampler:一个Sampler对象(或者一个可迭代对象)
  • batch_size:batch的大小
  • drop_last:是否丢弃最后一个可能不足batch size大小的数据
from torch.utils.data import BatchSampler
pseudo_dataset = list(range(10, 20))

batchSampler1 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=False)
batchSampler2 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=True)

print("for batch sampler #1: ")
for data in batchSampler1:
    print(data, end=" ")

print("\n\nfor batch sampler #2: ")
for data in batchSampler2:
    print(data, end=" ")
for batch sampler
[10, 11, 12] [13, 14, 15] [16, 17, 18] [19]

for batch sampler
[10, 11, 12] [13, 14, 15] [16, 17, 18]

SubsetRandomSampler

SubsetRandomSampler 可以设置子集的随机采样,多用于将数据集分成多个集合,比如训练集和验证集的时候使用:

from torch.utils.data import SubsetRandomSampler

pseudo_dataset = list(range(10, 20))

subRandomSampler1 = SubsetRandomSampler(pseudo_dataset[:7])
subRandomSampler2 = SubsetRandomSampler(pseudo_dataset[7:])

print("for subset random sampler #1: ")
for data in subRandomSampler1:
    print(data, end=" ")

print("\n\nfor subset random sampler #2: ")
for data in subRandomSampler2:
    print(data, end=" ")

for subset random sampler
14 15 11 16 13 10 12

for subset random sampler
17 19 18

参考:https://blog.csdn.net/qq_38962621/article/details/111146427

Original: https://blog.csdn.net/EMIvv/article/details/122509200
Author: Shashank497
Title: pytorch中DataLoader详解

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/625702/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球