LeNet模型对CIFAR-10数据集分类【pytorch】

LeNet模型对CIFAR-10数据集分类【pytorch】

目录

本文为针对CIFAR-10数据集的基于简单神经网络LeNet分类实现(pytorch实现)

LeNet 网络模型

LeNet模型对CIFAR-10数据集分类【pytorch】
Tip:(以上为原始LeNet)为了更好的效果,我在模型实现时此处将池化层换为Max

; CIFAR-10 数据集

CIFAR-10数据集由60000张32×32的彩色图像组成,分为10类,每类有6000张图像。有50000张训练图像和10000张测试图像。

该数据集被分为五个训练批和一个测试批,每个批有10000张图像。测试批包含从每个类中随机选择的1000张图像。训练批包含其余的随机顺序的图像,但有些训练批可能包含一个类别的图像多于另一个。在它们之间,训练批次恰好包含了每个类别的5000张图像。

下面是数据集中的类别,以及每个类别的10张随机图像。

LeNet模型对CIFAR-10数据集分类【pytorch】

关于数据集更多详情请见:CIFAR-10数据集官方说明

Pytorch 实现代码

import torch
from torch import  nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms

class Lenet5(nn.Module):

    def __init__(self,input_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(input_channels , 6 , kernel_size = 5 , padding = 2)

        self.pooling1 = nn.MaxPool2d(kernel_size = 2, stride = 2)

        self.conv2= nn.Conv2d(6 , 16 , kernel_size=5)

        self.pooling2 = nn.MaxPool2d(kernel_size = 2, stride=2)

        self.Flatten = nn.Flatten()

        self.Linear1 = nn.Linear(16*6*6,120)
        self.Linear2 = nn.Linear(120,84)
        self.Linear3 = nn.Linear(84,10)

    def forward(self,X):
    ''' 前向推导 '''
        X = self.pooling1(F.relu(self.conv1(X)))
        X = self.pooling2(F.relu(self.conv2(X)))
        X = X.view(X.size()[0],-1)
        X = F.relu(self.Linear1(X))
        X = F.relu(self.Linear2(X))
        X = F.relu(self.Linear3(X))

        return X

def load_CIFAR10(batch_size, resize=None):
    """ 加载数据集到内存 """
    trans = [transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    if resize:
        trans.insert(0, transforms.Resize(resize))

    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.CIFAR10(root="dataset",
                                                    train=True,
                                                    transform=trans,
                                                    download=True)
    mnist_test = torchvision.datasets.CIFAR10(root="dataset",
                                                   train=False,
                                                   transform=trans,
                                                   download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=2),
                data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=2))

def get_labels(labels):
    '''    标签转换  '''
    text_labels = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]

def train(loss,updater,train_iter,net,epoches):
    '''  训练模型  '''
    for epoch in range(epoches):
        run_loss = 0
        for step,(X,y) in enumerate(train_iter):
            if torch.cuda.is_available():
                X = X.cuda()
                y = y.cuda()
            y_hat = net.forward(X)
            ls = loss(y_hat,y).sum()
            updater.zero_grad()
            ls.backward()
            run_loss += ls.item()
            updater.step()
        print( f'true:{y} preds:{y_hat.argmax(axis=1)} epoch:{epoch:02d}\t epoch_loss {run_loss/5000}\t ')
    print('finished training\n')

def predict(net,test_iter,n=6):
    '''   测试集预测 '''
    for X, y in test_iter:
        if torch.cuda.is_available():
            X = X.cuda()
            y = y.cuda()
        trues = get_labels(y)
        preds = get_labels(net(X).argmax(axis=1))
        titles = ['groundTruth :'+true + ' ' +'preds: '+ pred for true, pred in zip(trues, preds)]
        print(titles[0:n])

if __name__ == '__main__':

    batch_size,  learning_rate,  epoches = 10, 0.05, 1

    trainSet,testSet = load_CIFAR10(batch_size)

    net = Lenet5(3)
    if torch.cuda.is_available():
        net.cuda()

    loss = nn.CrossEntropyLoss()

    updater = torch.optim.SGD(net.parameters(), lr=learning_rate)

    train(loss,updater,trainSet,net,batch_size,epoches,learning_rate)

    predict(net,testSet)

Original: https://blog.csdn.net/qq_45810349/article/details/118967362
Author: LA-AL
Title: LeNet模型对CIFAR-10数据集分类【pytorch】

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/665838/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球