Implementing CIFAR-10 Image Classification with PyTorch

While learning image classification with deep learning, I followed my instructor's walkthrough and implemented a CIFAR-10 image classification task.

Dataset address: (cloud-drive link; the downloaded archive needs to be decoded into images).
The dataset contains 50,000 training images and 10,000 test images.
After downloading the data, create two new folders in the dataset directory, one named Train and one named Test, to hold the decoded images:

(Screenshots omitted: the folder layout after creating Train and Test, and the decoded training set with one sub-folder per class.)
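If you prefer to create the two folders from code rather than by hand, a one-off sketch like this works; the base path here is only a placeholder and has to be replaced with your own directory:

import os

base = r"D:\cifar-10-python\cifar-10-batches-py"   # placeholder path -- adjust to your setup
for folder in ("Train", "Test"):
    os.makedirs(os.path.join(base, folder), exist_ok=True)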

The official documentation already provides the function for decoding the files:

def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding="bytes")
    return dict
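Calling this function on one of the downloaded batch files returns a dict keyed by byte strings. A minimal sketch of what to expect (the path is only an example and has to be adapted to your own directory):

import pickle

# hypothetical path -- adjust to wherever you extracted cifar-10-batches-py
with open(r"D:\cifar-10-python\cifar-10-batches-py\data_batch_1", 'rb') as fo:
    batch = pickle.load(fo, encoding="bytes")

print(batch.keys())          # the four keys: b'batch_label', b'labels', b'data', b'filenames'
print(batch[b'data'].shape)  # (10000, 3072): each row is one 32x32 RGB image, channel planes first
print(batch[b'labels'][:5])  # integer labels in the range 0-9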

The full code of the .py file that decodes the training set is as follows:

import os
import pickle
import glob
import numpy as np
import cv2

def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding="bytes")
    return dict

lable_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

# raw strings keep the Windows backslashes from being read as escape sequences
# (replace * with your own folder name)
train_list = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\data_batch_*")
print(train_list)
save_path = r"D:\*\cifar-10-python\cifar-10-batches-py\Train"

for l in train_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_lable = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]

        im_lable_name = lable_name[im_lable]
        # each row holds 3072 bytes: reshape to 3x32x32, then move the channels to the last axis
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # CIFAR-10 stores RGB while OpenCV writes BGR, so convert before saving
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)

        # create one sub-folder per class and write the image into it
        if not os.path.exists("{}/{}".format(save_path, im_lable_name)):
            os.mkdir("{}/{}".format(save_path, im_lable_name))
        cv2.imwrite("{}/{}/{}".format(save_path,
                                      im_lable_name,
                                      im_name.decode("utf-8")),
                    im_data)

The full code of the .py file that decodes the test set is as follows:

import os
import pickle
import glob
import numpy as np
import cv2

def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding="bytes")
    return dict

lable_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

# raw strings for the Windows paths; the official test file is named test_batch
test_list = glob.glob(r"D:\genglijia\cifar-10-python\cifar-10-batches-py\test_batch")
print(test_list)
save_path = r"D:\genglijia\cifar-10-python\cifar-10-batches-py\Test"

for l in test_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_lable = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]

        im_lable_name = lable_name[im_lable]
        # each row holds 3072 bytes: reshape to 3x32x32, then move the channels to the last axis
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # CIFAR-10 stores RGB while OpenCV writes BGR, so convert before saving
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)

        # create one sub-folder per class and write the image into it
        if not os.path.exists("{}/{}".format(save_path, im_lable_name)):
            os.mkdir("{}/{}".format(save_path, im_lable_name))
        cv2.imwrite("{}/{}/{}".format(save_path,
                                      im_lable_name,
                                      im_name.decode("utf-8")),
                    im_data)
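Once a decoding script has finished, a quick glob can confirm that everything was written; this is only an optional check, and the path has to point at the Test (or Train) folder you actually used:

import glob

# point this at the save_path used by the decoding script
for class_dir in sorted(glob.glob(r"D:\genglijia\cifar-10-python\cifar-10-batches-py\Test\*")):
    print(class_dir, len(glob.glob(class_dir + r"\*.png")))  # the test set has 1,000 images per class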

The .py file that loads the local dataset:

import glob
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
import os
from PIL import Image
import numpy as np

lable_name = ["airplane","automobile","bird",
              "cat","deer","dog","frog",
              "horse","ship","truck"]

lable_dict = {}

for idx, name in enumerate(lable_name):
    lable_dict[name] = idx
print(lable_dict)
def default_loader(path):
    return Image.open(path).convert("RGB")

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.CenterCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

class MyDataset(Dataset):

    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # the class name is the parent folder of each image path, e.g. ...\Train\airplane\xxx.png
            im_lable_name = im_item.split("\\")[-2]
            imgs.append([im_item, lable_dict[im_lable_name]])

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im_path,im_lable = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_lable

    def __len__(self):
        return len(self.imgs)

im_train_list = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\Train\*\*.png")
im_test_list = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\Test\*\*.png")

train_dataset = MyDataset(im_train_list, transform=train_transform)
# use the test_transform defined above so the test images get the same normalization as the training images
test_dataset = MyDataset(im_test_list, transform=test_transform)

train_loader = DataLoader(dataset=train_dataset,
                            batch_size=64,
                            shuffle=True,
                            num_workers=0)

test_loader = DataLoader(dataset=test_dataset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=0)
print("num_of_train", len(train_dataset))
print("num_of_test", len(test_dataset))
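Before training, it is worth pulling a single batch from train_loader to confirm the tensor shapes; a minimal check that could be appended to the end of this file:

if __name__ == '__main__':
    images, labels = next(iter(train_loader))
    print(images.shape)  # torch.Size([64, 3, 32, 32])
    print(labels.shape)  # torch.Size([64]) -- integer class indices from 0 to 9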

The .py file that defines the network structure (a classic ResNet-style residual network is used here; other architectures such as VGGNet or MobileNet would also work):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, stride=1):
        super(ResBlock, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
        )
        self.shortcut = nn.Sequential()
        # use a projection shortcut when the number of channels or the spatial size changes
        if in_channel != out_channel or stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,
                          kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channel),
            )

    def forward(self, x):
        out1 = self.layer(x)
        out2 = self.shortcut(x)
        out = out1 + out2
        out = F.relu(out)
        return out

class ResNet(nn.Module):

    def make_layer(self, block, out_channel, stride, num_block):
        # stack num_block residual blocks; only the first block in the stack may downsample (stride > 1)
        layers_list = []
        for i in range(num_block):
            in_stride = stride if i == 0 else 1
            layers_list.append(block(self.in_channel, out_channel, in_stride))
            self.in_channel = out_channel
        return nn.Sequential(*layers_list)

    def __init__(self, ResBlock):
        super(ResNet, self).__init__()
        self.in_channel = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.layer1 = \
            self.make_layer(ResBlock, 64, 2, 2)

        self.layer2 = \
            self.make_layer(ResBlock, 128, 2, 2)

        self.layer3 = \
            self.make_layer(ResBlock, 256, 2, 2)

        self.layer4 = \
            self.make_layer(ResBlock, 512, 2, 2)

        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out = self.conv1(x)      # 3 x 32 x 32   ->  32 x 32 x 32
        out = self.layer1(out)   # 32 x 32 x 32  ->  64 x 16 x 16
        out = self.layer2(out)   # 64 x 16 x 16  -> 128 x 8 x 8
        out = self.layer3(out)   # 128 x 8 x 8   -> 256 x 4 x 4
        out = self.layer4(out)   # 256 x 4 x 4   -> 512 x 2 x 2
        out = F.avg_pool2d(out, 2)       # average over the remaining 2x2 spatial grid
        out = out.view(out.size(0), -1)  # flatten to (batch, 512)
        out = self.fc(out)
        return out

def resnet():
    return ResNet(ResBlock)
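Before wiring the network into the training script, a dummy forward pass is a cheap way to confirm that the layer sizes line up; this sanity check is only a sketch that could go at the end of this file:

if __name__ == '__main__':
    net = resnet()
    dummy = torch.randn(4, 3, 32, 32)   # a fake batch of four 32x32 RGB images
    print(net(dummy).shape)             # expected: torch.Size([4, 10])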

The .py file for training and testing:

import os
import torch
import torch.nn as nn
from resnet import resnet
from load_cifar10 import train_loader, test_loader

epoch_num = 1
lr = 0.01
# the batch size (64) is set in the DataLoaders defined in load_cifar10.py
net = resnet()

loss_func = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=lr)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

if __name__ == '__main__':
    for epoch in range(epoch_num):
        net.train()

        for i, data in enumerate(train_loader):
            inputs, labels = data

            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).sum()
            print("train step", i, "loss is:", loss.item(),
                  "mini-batch correct is:", correct.item() / inputs.size(0))

        if not os.path.exists("models"):
            os.mkdir("models")
        torch.save(net.state_dict(),"models/{}.pth".format(epoch+1))
        scheduler.step()

        # evaluate on the test set after every epoch
        sum_loss = 0
        sum_correct = 0
        sum_sample = 0

        net.eval()
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                inputs, labels = data

                outputs = net(inputs)
                loss = loss_func(outputs, labels)
                _, pred = torch.max(outputs.data, dim=1)
                correct = pred.eq(labels.data).sum()

                sum_loss += loss.item()
                sum_correct += correct.item()
                sum_sample += inputs.size(0)

        test_loss = sum_loss / len(test_loader)
        test_correct = sum_correct / sum_sample

        print("epoch", epoch+1, "test loss is:", test_loss, "test correct is:", test_correct)
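After training, the saved weights can be loaded back for inference on a single image. This is a minimal sketch, assuming the checkpoint models/1.pth from the run above exists; "some_image.png" is a placeholder for any decoded 32x32 CIFAR-10 image:

import torch
from PIL import Image
from torchvision import transforms
from resnet import resnet

lable_name = ["airplane", "automobile", "bird", "cat", "deer",
              "dog", "frog", "horse", "ship", "truck"]

# the same normalization that the training transform used
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

net = resnet()
net.load_state_dict(torch.load("models/1.pth", map_location="cpu"))
net.eval()

im = Image.open("some_image.png").convert("RGB")  # placeholder image path
x = transform(im).unsqueeze(0)                    # add the batch dimension: (1, 3, 32, 32)
with torch.no_grad():
    pred = net(x).argmax(dim=1).item()
print(lable_name[pred])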

Test results:


Here epoch_num is 1: without a GPU, training is slow, so I only trained for a single epoch, and the accuracy still reached about 70%. With more epochs it should get higher, and you can also change the network structure; there are many ways to improve it.
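If a GPU becomes available, only a few lines of the training script need to change: move the model and every batch to the device. This is just a sketch of those changes, not part of the original code:

import torch
from resnet import resnet

# pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = resnet().to(device)

# then, inside both the training loop and the test loop:
# inputs, labels = inputs.to(device), labels.to(device)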

Once everything has finished, you can see the images already sorted by class in the Test folder.

If you have any questions, feel free to leave a comment below. Let's keep learning together!!!

Original: https://blog.csdn.net/weixin_44250159/article/details/124518562
Author: 啊砉
Title: 基于Pytorch的cifar-10图像分类问题代码实现
