【深度学习】PyTorch Dataset类的使用与实例分析

Dataset类

介绍

当我们得到一个数据集时,Dataset类可以帮我们提取我们需要的数据,我们用子类继承Dataset类,我们先给每个数据一个编号(idx),在后面的神经网络中,初始化Dataset子类实例后,就可以通过这个编号去实例对象中读取相应的数据,会自动调用__getitem__方法,同时子类对象也会获取相应真实的Label(人为去复写即可)

Dataset类的作用:提供一种方式去获取数据及其对应的真实Label

在Dataset类的子类中,应该有以下函数以实现某些功能:

  1. 获取每一个数据及其对应的Label
  2. 统计数据集中的数据数量

关于2,神经网络经常需要对一个数据迭代多次,只有知道当前有多少个数据,进行训练时才知道要训练多少次,才能把整个数据集迭代完

Dataset官方文档解读

首先看一下Dataset的官方文档解释

导入Dataset类:

from torch.utils.data import Dataset

我们可以通过在 Jupyter中查看官方文档

from torch.utils.data import Dataset
help(Dataset)

输出:

Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:Dataset.

 |
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:__getitem__, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:__len__, which is expected to return the size of the dataset by many
 |  :class:~torch.utils.data.Sampler implementations and the default options
 |  of :class:~torch.utils.data.DataLoader.

 |
 |  .. note::
 |    :class:~torch.utils.data.DataLoader by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.

 |
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |
 |  __getattr__(self, attribute_name)
 |
 |  __getitem__(self, index) -> +T_co
 |
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |
 |  register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
 |
 |  register_function(function_name, function) from typing.GenericMeta
 |
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |
 |  __dict__
 |      dictionary for instance variables (if defined)
 |
 |  __weakref__
 |      list of weak references to the object (if defined)
 |
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |
 |  __abstractmethods__ = frozenset()
 |
 |  __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
 |
 |  __args__ = None
 |
 |  __extra__ = None
 |
 |  __next_in_mro__ =
 |      The most base type
 |
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |
 |  __origin__ = None
 |
 |  __parameters__ = (+T_co,)
 |
 |  __tree_hash__ = -9223371872509358054
 |
 |  functions = {'concat': functools.partial(

还有一种方式获取官方文档信息:

Dataset??

输出:

Init signature: Dataset(*args, **kwds)
Source:
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:Dataset.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:__getitem__, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:__len__, which is expected to return the size of the dataset by many
    :class:~torch.utils.data.Sampler implementations and the default options
    of :class:~torch.utils.data.DataLoader.

    .. note::
      :class:~torch.utils.data.DataLoader by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.

"""
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No def __len__(self) default?

    # See NOTE [ Lack of Default __len__ in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function
File:           d:\environment\anaconda3\envs\py-torch\lib\site-packages\torch\utils\data\dataset.py
Type:           GenericMeta
Subclasses:     Dataset, IterableDataset, Dataset, TensorDataset, ConcatDataset, Subset, Dataset, Subset, Dataset, IterableDataset[+T_co], ...

其中我们可以看到:

"""An abstract class representing a :class:Dataset.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:__getitem__, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:__len__, which is expected to return the size of the dataset by many
    :class:~torch.utils.data.Sampler implementations and the default options
    of :class:~torch.utils.data.DataLoader.

"""

以上内容显示:

该类是一个抽象类,所有的数据集想要在数据与标签之间建立映射,都需要继承这个类,所有的子类都需要重写 __getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label,子类也可以重写 __len__方法,返回数据集的size大小

实例:GetData类

准备工作

首先我们创建一个类,类名为GetData,这个类要继承Dataset类

class GetData(Dataset):

一般在类中首先需要写的是 __init__方法,此方法用于对象实例化,通常用来提供类中需要使用的变量,可以先不写

class GetData(Dataset):
    def __init__(self):
        pass

我们可以先写 __getitem__方法:

class GetData(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, idx):  # 默认是item,但常改为idx,是index的缩写
        pass

其中,idx是index的简称,就是一个编号,以便以后数据集获取后,我们使用索引编号访问每个数据

在实现GetData类之前,我们首先需要解决的问题就是如何读取一个图像数据,通常我们使用PIL来读取

PIL获取图像数据

我们使用 PIL来读取数据,它提供一个 Image模块,可以让我们提取图像数据,我们先导入这个模块

from PIL import Image

我们可以在Python Console中看看如何使用 Image

在Python Console中,输入代码:

from PIL import Image

将数据集放入项目文件夹,我们需要获取图片的绝对路径,选中具体的图片,右键选择Copy Path,然后选择 Absolute path(快捷键:Ctrl + Shift + C)

img_path = "D:\\DeepLearning\\dataset\\train\\ants\\0013035.jpg"

在Windows下,路径分割需要是 \\,来表示转译
也可以在字符串前面加 r 防转译

使用Image的 open方法读取图片:

img = Image.open(img_path)

可以在Python控制台看到读取出来的 img,是一个 JpegImageFile类的对象

【深度学习】PyTorch Dataset类的使用与实例分析

在图中,可以看到这个对象的一些属性,比如 size

我们查看这个属性的内容,输入以下代码:

img.size

输出:

(768, 512)

【深度学习】PyTorch Dataset类的使用与实例分析

我们可以看到此图的宽是768,高是512, __len__表示的是这个size元组的长度,有两个值,所以为 2

show方法显示图片:

img.show()

获取图片的文件名

从数据集路径中,获取所有文件的名字,存储到一个列表中

一个简单的例子(在Python Console中):

我们需要借助 os模块

import os
dir_path = "dataset/train/ants_image"
img_path_list = os.listdir(dir_path)

listdir方法会将路径下的所有文件名(包括后缀名)组成一个列表

【深度学习】PyTorch Dataset类的使用与实例分析

我们可以使用索引去访问列表中的每个文件名

img_path_list[0]
Out[14]: '0013035.jpg'

构建数据集路径

我们需要搭建数据集的路径表示,一个 根目录路径和一个具体的 子目录路径,以作为不同数据集的区分

一个简单的案例,在Python Console中输入:

root_dir = "dataset/train"
child_dir = "ants_image"

我们使用 os.path.join方法,将两个路径拼接起来,就得到了ants子数据集的相对路径

path = os.path.join(root_dir, child_dir)

path的值此时是:

path={str}'dataset/train\\ants_image'

我们有了这个数据集的路径后,就可以使用之前所讲的 listdir方法,获取这个路径中所有文件的文件名,存储到一个列表中

img_path_list = os.listdir(path)
idx = 0
img_path_list[idx]
Out[21]: '0013035.jpg'

可以看到结果与我们之前的小案例是一样的

有了具体的名字,我们还可以将这个文件名与路径进行组合,然后使用PIL获取具体的图像img对象

img_name = img_path_list[idx]
img_item_path = os.path.join(root_dir, child_dir, img_name)
img = Image.open(img_item_path)

在掌握了如何组装路径、获取路径中的文件名以及获取具体图像对象后,我们可以完善我们的 __init____getitem__方法了

完善__init__方法

在init中为啥使用self:一个函数中的变量是不能拿到另外一个函数中使用的,self可以当做类中的全局变量

class GetData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

很简单,就是接收实例化时传入的参数:获取根目录路径、子目录路径

然后将两个路径进行组合,就得到了目标数据集的路径

我们将这个路径作为参数传入listdir函数,从而让img_path_list中存储该目录下所有文件名(包含后缀名)

此时通过索引就可以轻松获取每个文件名

接下来,我们要使用这些初始化的信息去获取其中的每一个图片的JpegImageFile对象

完善__getitem__方法

我们在初始化中,已经通过组装数据集路径,进而通过listdir方法获取了数据集中每个文件的文件名,存入了一个列表中。

在__getitem__方法中,默认会有一个 item 参数,常命名为 idx,这个参数是一个 索引编号,用于对我们初始化中得到的文件名列表进行索引访问,我们就得到了具体的文件名,然后与根目录、子目录再次组装,得到具体数据的相对路径,我们可以通过这个路径获取到索引编号对应的数据对象本身。

这样巧妙的让索引与数据集中的具体数据对应了起来

def __getitem__(self, idx):
    img_name = self.img_path_list[idx]  # 从文件名列表中获取了文件名
    img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 组装路径,获得了图片具体的路径

获取了具体的图像路径后,我们需要使用PIL读取这个图像

def __getitem__(self, idx):
    img_name = self.img_path[idx]
    img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
    img = Image.open(img_item_path)
    label = self.label_dir
    return img, label

此处img是一个JpegImageFile对象,label是一个字符串

自此,这个函数我们就实现完成了

以后使用这个类进行实例化时,传入的参数是根目录路径,以及对应的label名,我们就可以得到一个GetData对象。

有了这个GetData对象后,我们可以直接使用索引来获取具体的图像对象(类:JpegImageFile),因为__getitem__方法已经帮我们实现了,我们只需要使用 索引即可调用__getitem__方法,会返回我们根据索引提取到的对应数据的图像对象以及其label

root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GetData(root_dir, bees_label_dir)
img1, label1 = ants_dataset[0]  # 返回一个元组,返回值是__getitem__方法的返回值
img2, label2 = bees_dataset[0]

完善__len__方法

__len__实现很简单

主要功能是获取数据集的长度,由于我们在初始化中已经获取了所有文件名的列表,所以只需要知道这个列表的长度,就知道了有多少个文件,也就是知道了有多少个具体的数据

def __len__(self):
    return len(self.img_path_list)

组合数据集

我们还可以将两个数据集对象进行组合,组合成一个大的数据集对象

train_dataset = ants_dataset + bees_dataset

我们看看这三个数据集对象的大小(在python Console中):

len1 = len(ants_dataset)
len2 = len(bees_dataset)
len3 = len(train_dataset)

输出:

124
121
245

我们可以看到刚好 $$124 + 121 = 245$$

而对这个组合的数据集的访问也很有意思,也同样是使用索引,0 ~ 123 都是ants数据集的内容,124 – 244 都是bees数据集的内容

img1, label1 = train_dataset[123]
img1.show()
img2, label2 = train_dataset[124]
img2.show()

完整代码

from torch.utils.data import Dataset
from PIL import Image
import os

class GetData(Dataset):

    # 初始化为整个class提供全局变量,为后续方法提供一些量
    def __init__(self, root_dir, label_dir):

        # self
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path_list[idx]  # 只获取了文件名
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每个图片的位置
        # 读取图片
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GeyData(root_dir, bees_label_dir)
img, lable = ants_dataset[0] # 返回一个元组,返回值就是__getitem__的返回值

获取整个训练集,就是对两个数据集进行了拼接
train_dataset = ants_dataset + bees_dataset

len1 = len(ants_dataset)  # 124
len2 = len(bees_dataset)  # 121
len = len(train_dataset) # 245

img1, label1 = train_dataset[123]  # 获取的是蚂蚁的最后一个
img2, label2 = train_dataset[124]  # 获取的是蜜蜂第一个

Original: https://www.cnblogs.com/seansheep/p/16163159.html
Author: 在青青草原上抓羊
Title: 【深度学习】PyTorch Dataset类的使用与实例分析

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/608995/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • linux free命令available小于free值

    问题:前段时间在做服务器巡检时发现系统可用内存值小于空闲内存值 分析:查询网上各种资料,都说的是 available=free + buff/cache 这样一个大致计算方式,按这…

    Linux 2023年6月14日
    0154
  • 用户身份标识与账号体系实践

    互联网的账号自带备忘机制; 一、业务背景 通常在系统研发的过程中,需要不断适配各种业务场景,扩展服务的领域和能力,一般会将构建的产品矩阵划分出多条业务线,以便更好的管理; 由于各个…

    Linux 2023年6月14日
    083
  • docker 安装redis

    1、获取 redis 镜像 2、查看本地镜像 bind 127.0.0.1 #注释掉这部分,这是限制redis只能本地访问 protected-mode no #默认yes,开启保…

    Linux 2023年5月28日
    083
  • role: org.apache.maven.model.validation.ModelValidator【maven】项目创建后pom一直不能build出来还爆红【转】

    role: org.apache.maven.model.validation.ModelValidator【maven】项目创建后pom一直不能build出来还爆红 问题是因为m…

    Linux 2023年6月8日
    082
  • Linux系统调用接口

    Linux系统调用接口 进程控制 系统调用 描述 fork 创建一个新进程 clone 按指定条件创建子进程 execve 运行可执行文件 exit 终止进程 _exit 立即终止…

    Linux 2023年6月13日
    095
  • LeetCode-16. 最接近的三数之和

    题目来源 题目详情 给你一个长度为 n 的整数数组 nums和 一个目标值 target。请你从 nums 中选出三个整数,使它们的和与 target 最接近。 返回这三个数的和。…

    Linux 2023年6月7日
    093
  • python练习题:利用切片操作,实现一个trim()函数,去除字符串首尾的空格,注意不要调用str的strip()方法

    方法一: 方法二: (此方法会有一个问题,当字符串仅仅是一个空格时’ ‘,会返回return s[1:0];虽然不会报错,但是会比较奇怪。测试了下,当s=&…

    Linux 2023年6月8日
    099
  • 事务与事务隔离级别详解

    事务基本概念 一组要么同时执行成功,要么同时执行失败的SQL 语句。是数据库操作的一个执行单元。 事务开始于: 连接到数据库上,并执行一条DML 语句in sert 、update…

    Linux 2023年6月14日
    091
  • JavaWeb创建一个公共的servlet

    对于初学者来说,每次前端传数据过来就要新建一个类创建一个doget、dopost方法,其实铁柱兄在大学的时候也是这么玩的。后面铁柱兄开始认真了,就想着学习点容易的编程方式,其实说白…

    Linux 2023年6月13日
    084
  • 系统复位到操作系统启动的简要流程图

    多核下,处理器由系统复位到操作系统启动的简要流程图; 其中第一列为处理器核初始化过程, 第二列为芯片核外部分初始化过程, 第三列为设备初始化过程, 第四列为内核加 载过程, 第五列…

    Linux 2023年6月14日
    087
  • 操作系统实现-printk

    博客网址:www.shicoder.top微信:18223081347欢迎加群聊天 :452380935 这一次我们来实现最基础,也是最常见的函数 print,大家都知道这个是可变…

    Linux 2023年6月13日
    096
  • docker 启动mysql

    创建配置文件 mysqld.cnf Original: https://www.cnblogs.com/outsrkem/p/15704614.htmlAuthor: Outsrk…

    Linux 2023年6月6日
    061
  • 安装 Ubuntu 20.04 之后要做的事(持续更新中)

    以 Ubuntu 20.04 LTS 为例,在安装完操作系统后,应进行以下操作,以方便我们日常的工作。 1. SSH 远程登录相关设置 安装 Ubuntu 操作系统之后,首先应该按…

    Linux 2023年5月27日
    080
  • MySQL — 索引

    索引(Index)是高效获取数据的数据结构,就像书的目录,提高检索数据的效率。 优点:提高数据检索效率,降低数据库的 IO 成本;通过索引列对数据进行排序,降低数据排序的成本,降低…

    Linux 2023年6月8日
    066
  • Linux 服务器巡检脚本

    bash;gutter:true;</p> <h1>!/bin/bash</h1> <p>cat < $RESULTFILE …

    Linux 2023年6月13日
    089
  • 回忆我的第一个软件项目

    2009年大学毕业我去了成都,一番面试后,入职武侯区磨子桥附近的一个小型创业公司。公司的主营业务是代理销售用友或者金蝶的ERP软件,创业团队都是川大毕业的。公司的办公条件很差,两间…

    Linux 2023年6月6日
    0102
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球