前言
目前刚刚接触深度学习方向,也在学习pytorch框架。本文是我在尝试相关网络的pytorch框架时遇到的一些问题以及认为有必要总结一下的内容。
此内容主要参考了以下博客:https://blog.csdn.net/m0_37867091/article/details/107150142
数据预处理
在网络开始训练之前,为了使训练更好的进行,我们需要对训练进行一些预处理操作。在pytorch中是由torchvision.transforms来操作的,torchvision.transforms中包含了一些常见的操作。以下是目前见到常用的几种:
transforms.Compose可以用来将多种操作集合到一起,打包了多个图片处理的方法,如:
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
transforms.ToTensor() 将 shape
为 (H, W, C)
的 nump.ndarray
或 img
转为 shape
为 (C, H, W)
的 tensor
,其将每一个数值归一化到 [0,1]
,其归一化方法比较简单,直接除以255即可。
transforms.Normalize()其作用就是先将输入归一化到 (0,1)
,再使用公式 "(x-mean)/std"
,将每个元素分布到 (-1,1)。
torchvision
是pytorch的一个图形库,它服务于PyTorch 深度学习框架的。其构成如下:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。
原文链接:https://blog.csdn.net/wangkaidehao/article/details/104520022/
数据集
各种网络模型的训练都离不开数据集的支持,当我们针对某个数据集时,往往是两种导入方法:1.pytorch内置的torchvision.datasets函数进行在线导入相关的数据集
2.导入个人制作的数据集
参考:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/data_set/README.md
个人的数据集需要划分为训练集、测试集两部分,下面是对数据集进行分类的脚本:
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
if os.path.exists(file_path):
# 如果文件夹存在,则先删除原文件夹在重新创建
rmtree(file_path)
os.makedirs(file_path)
def main():
# 保证随机可复现
random.seed(0)
# 将数据集中10%的数据划分到验证集中
split_rate = 0.1
# 指向你解压后的flower_photos文件夹
cwd = os.getcwd()
data_root = os.path.join(cwd, "flower_data")
origin_flower_path = os.path.join(data_root, "flower_photos")
assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)
flower_class = [cla for cla in os.listdir(origin_flower_path)
if os.path.isdir(os.path.join(origin_flower_path, cla))]
# 建立保存训练集的文件夹
train_root = os.path.join(data_root, "train")
mk_file(train_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(train_root, cla))
# 建立保存验证集的文件夹
val_root = os.path.join(data_root, "val")
mk_file(val_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(val_root, cla))
for cla in flower_class:
cla_path = os.path.join(origin_flower_path, cla)
images = os.listdir(cla_path)
num = len(images)
# 随机采样验证集的索引
eval_index = random.sample(images, k=int(num*split_rate))
for index, image in enumerate(images):
if image in eval_index:
# 将分配至验证集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(val_root, cla)
copy(image_path, new_path)
else:
# 将分配至训练集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(train_root, cla)
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
print()
print("processing done!")
if __name__ == '__main__':
main()
其中文件夹的名称根据自己的数据集进行替换。
导入、加载数据
对于在 torchvision图形库中在线导入的数据集代码如下:
导入训练集
train_set = torchvision.datasets.CIFAR10(root=’./data’, # 数据集存放目录
train=True, # 表示是数据集中的训练集
download=True, # 第一次运行时为True,下载数据集,下载完成后改为False
transform=transform) # 预处理过程
加载训练集
train_loader = torch.utils.data.DataLoader(train_set, # 导入的训练集
batch_size=50, # 每批训练的样本数
shuffle=False, # 是否打乱训练集
num_workers=0) # num_workers在windows下设置为0
对于个人划分的数据集代码如下:
获取图像数据集的路径
data_root = os.path.abspath(os.path.join(os.getcwd(), “../..”)) # get data root path
image_path = data_root + “/data_set/flower_data/” # flower data_set path
导入训练集并进行预处理
train_dataset = datasets.ImageFolder(root=image_path + “/train”,
transform=data_transform[“train”])
train_num = len(train_dataset)
按batch_size分批次加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset, # 导入的训练集
batch_size=32, # 每批训练的样本数
shuffle=True, # 是否打乱训练集
num_workers=0) # 使用线程数,在windows下设置为0
Original: https://blog.csdn.net/weixin_45929203/article/details/123276387
Author: 不要瞎搞
Title: Pytorch框架训练时的数据预处理、数据集以及导入、加载数据
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/709841/
转载文章受原作者版权保护。转载请注明原作者出处!