Pytorch框架训练时的数据预处理、数据集以及导入、加载数据

前言

目前刚刚接触深度学习方向,也在学习pytorch框架。本文是我在尝试相关网络的pytorch框架时遇到的一些问题以及认为有必要总结一下的内容。

此内容主要参考了以下博客:https://blog.csdn.net/m0_37867091/article/details/107150142​​​​​​

数据预处理

在网络开始训练之前,为了使训练更好的进行,我们需要对训练进行一些预处理操作。在pytorch中是由torchvision.transforms来操作的,torchvision.transforms中包含了一些常见的操作。以下是目前见到常用的几种:

transforms.Compose可以用来将多种操作集合到一起,打包了多个图片处理的方法,如:

transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])

transforms.ToTensor() 将 shape(H, W, C)nump.ndarrayimg转为 shape(C, H, W)tensor,其将每一个数值归一化到 [0,1],其归一化方法比较简单,直接除以255即可。

transforms.Normalize()其作用就是先将输入归一化到 (0,1),再使用公式 "(x-mean)/std",将每个元素分布到 (-1,1)。

torchvision是pytorch的一个图形库,它服务于PyTorch 深度学习框架的。其构成如下:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。

原文链接:https://blog.csdn.net/wangkaidehao/article/details/104520022/

数据集

各种网络模型的训练都离不开数据集的支持,当我们针对某个数据集时,往往是两种导入方法:1.pytorch内置的torchvision.datasets函数进行在线导入相关的数据集

Pytorch框架训练时的数据预处理、数据集以及导入、加载数据

2.导入个人制作的数据集

参考:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/data_set/README.md

个人的数据集需要划分为训练集、测试集两部分,下面是对数据集进行分类的脚本:

import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)

def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你解压后的flower_photos文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")

if __name__ == '__main__':
    main()

其中文件夹的名称根据自己的数据集进行替换。

导入、加载数据

对于在 torchvision图形库中在线导入的数据集代码如下:

导入训练集
train_set = torchvision.datasets.CIFAR10(root=’./data’, # 数据集存放目录
train=True, # 表示是数据集中的训练集
download=True, # 第一次运行时为True,下载数据集,下载完成后改为False
transform=transform) # 预处理过程
加载训练集
train_loader = torch.utils.data.DataLoader(train_set, # 导入的训练集
batch_size=50, # 每批训练的样本数
shuffle=False, # 是否打乱训练集
num_workers=0) # num_workers在windows下设置为0

对于个人划分的数据集代码如下:

获取图像数据集的路径
data_root = os.path.abspath(os.path.join(os.getcwd(), “../..”)) # get data root path
image_path = data_root + “/data_set/flower_data/” # flower data_set path

导入训练集并进行预处理
train_dataset = datasets.ImageFolder(root=image_path + “/train”,
transform=data_transform[“train”])
train_num = len(train_dataset)

按batch_size分批次加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset, # 导入的训练集
batch_size=32, # 每批训练的样本数
shuffle=True, # 是否打乱训练集
num_workers=0) # 使用线程数,在windows下设置为0

Original: https://blog.csdn.net/weixin_45929203/article/details/123276387
Author: 不要瞎搞
Title: Pytorch框架训练时的数据预处理、数据集以及导入、加载数据

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/709841/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球