文章目录
引入
在学习pytorch的过程中,用的一直都是教程中别人定义好从网上直接下载的数据集,不需要进行任何的处理,数据和标号都可以直接获取。但是,我想要进行自己的研究大多数情况需要我们自己收集数据并进行一些预处理在制作成数据集,然后通过pytorch读入后用来训练模型。这里记录的是一次对上万张验证码图片组成的数据集(标号是其名称)制作pytorch数据集的尝试。
部分数据如下:
大多数教程中并没有讲这些图片数据和标签是如何装载到torch中的,在分析了一个github项目https://github.com/braveryCHR/CNN_captcha 后我大概了解如何装载数据。
; 方法
如果我们需要利用pytorch装载数据以及标签,我们就必须自己写一个dataset类,该类要继承data.Dataset类,该类在torch.utils中,并实现该类的_getitem_和_len_方法。
示例:
为了实现将验证码分类,我们先定义label和字符互相转换的函数:
import os
import torch
from PIL import Image
from torch.utils import data
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms as T
def StrToLabel(Str):
label = []
for i in range(0, charNumber):
if '0' Str[i] '9':
label.append(ord(Str[i]) - ord('0'))
elif 'a' Str[i] 'z':
label.append(ord(Str[i]) - ord('a') + 10)
else:
label.append(ord(Str[i]) - ord('A') + 36)
return label
def LabelToStr(Label):
Str = ""
for i in Label:
if i 9:
Str += chr(ord('0') + i)
elif i 35:
Str += chr(ord('a') + i - 10)
else:
Str += chr(ord('A') + i - 36)
return Str
接下来是数据集合类的定义
class Captcha(data.Dataset):
def __init__(self, root, train=True):
self.imgPath = [os.path.join(root, img) for img in os.listdir(root)]
self.transform = T.Compose([
T.Resize((150, 30)),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def __getitem__(self, index):
img_path = self.imgPath[index]
label = img_path.split('\\')[-1].split('.')[0]
label_tensor = torch.Tensor(StrToLabel(label))
data=Image.open(img_path)
data = self.transform(data)
return data, label_tensor
def __len__(self):
return len(self.imgPath)
在init中的transform是预处理的定义。
getitem方法用来返回读取的图片数据和该图片的参数,我们将图片文件名获取到并转换为tensor,再使用PIL模块中的Image.open()读取图片数据,之后通过预处理transform转为tensor对象,最后返回图片数据data和图片标签label_tensor就可以了。
len函数返回文件中图片的数量。
dataloader会根据len读取文件中所有图片,每次读取图片的方法就是getitem中定义的方法。
测试
我们来使用一下这个Capthca类,看看能否正确读取图片数据data以及其标号label
import os.path
import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
img_data = Captcha("./data/train/train", train=True)
trainDataLoader = DataLoader(img_data, batch_size=1,
shuffle=False, num_workers=4)
if __name__ == '__main__':
it = trainDataLoader.__iter__()
data, label = it.next()
print(data)
print(label)
print(LabelToStr(int(x)for x in label.squeeze().tolist()))
由于在jupyter中运行该代码会报错所以我放上在pycharm上的运行结果:
总结
想要使用自己定义的数据集就必须实现一个dataset,使得dataloader知道如何获取数据以及标签。
Original: https://blog.csdn.net/weixin_46919419/article/details/123674117
Author: LiterMa
Title: 在pytorch中使用自己的数据集,dataset的写法
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710244/
转载文章受原作者版权保护。转载请注明原作者出处!