Pytorch分类网络入门(MNIST)

利用Pytorch构建VGG分类网络,对MNIST(60000张1 * 28 * 28的手写体)数据集进行识别和分类。 该任务是一个已知类别数进行的分类。

1.1 导入外部数据,构造Dataset类

在constructed_datasets.py文件定义自己的Dataset类,必须继承并重写父类(torch.utils.data.Dataset)的以下两个私有成员函数。
PS:因为DataLoader只能装载三维的图像,所以__getitem__() 返回的img必须是三维的tensor,这里通过转换后返回的是C * W * H的单通道图像。

import os
import torch
import struct
import numpy as np

#读入MNIST,构造自定义Dataset类
class Constructed_mnist(torch.utils.data.Dataset):
    def __init__(self,path,filekind='train'):
        self.data_path = path

        #'train'读训练集,'t10k'读测试集
        if filekind=='train':
            imgs_path = os.path.join(path,'%s-images.idx3-ubyte'%filekind)
            labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % filekind)
        elif filekind=='t10k':
            imgs_path = os.path.join(path,'%s-images.idx3-ubyte'%filekind)
            labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % filekind)
        else:
            print("Error:filekind is only a string with 'train' or 't10k' ")

        #读入特定格式数据
        with open(labels_path, 'rb') as lbpath:
            magic, n = struct.unpack('>II',lbpath.read(8))
            labels = np.fromfile(lbpath,dtype=np.uint8)

        with open(imgs_path, 'rb') as imgpath:
            magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))
            images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)
        #赋值为类的属性
        self.labels = labels
        self.images = images

    def __getitem__(self, index):
        temp_label = self.labels[index]
        temp_image = self.images[index]
        #####读入的数据是1个1*784的一维numpy,转换为1个1*28*28的三维tensor######
        temp_image = torch.tensor(temp_image.reshape(28,28),dtype=torch.float32)[None,...]

        return temp_label,temp_image

    def __len__(self):
        return len(self.labels)

1.2 DataLoader函数装载自定义Dataset

在进行数据集装载时,需要用到torch.utils.data.DataLoader函数。 装载的目的主要是把读入的数据进行分批,为训练做准备。DataLoader出来的单个数据一般是batch_size * C * W * H四维的tensor

my_mnist = Constructed_mnist(path,filekind)
#将导入的数据按照训练的batch_size要求进行分批
dataloader = torch.utils.data.DataLoader(dataset=my_mnist, batch_size=16, shuffle=True)

在mymodels.py文件构建VGG

2.1 定义模型结构

定义自己的Module类,必须继承父类(torch.nn.Module),在定义自己的Module类时,需要自己重新定义以下两个函数:

class Models(torch.nn.Module):
    def __init__(self):
    #TODO
    #调用父类的初始化方法
    #利用torch.nn.Sequential()函数定义神经网络模块
    def forward(self):
    #TODO
    #定义前向传播的过程,返回传播后的结果。

EXAMPLE

import torch

class VGG16(torch.nn.Module):
    def __init__(self):
        super(VGG16,self).__init__()
        self.Conv = torch.nn.Sequential(
            torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),  #1需要修改为实际通道数
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size= 2,stride=2),

            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.Classes = torch.nn.Sequential(
            torch.nn.Linear(512*1*1, 1024),  #512*1*1需要修改为forward中view一致
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024,1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024,10)   #10需要修改为实际分类出的数量
        )

    #x的shape决定了view的第二个值,view相当于reshape
    def forward(self, input):
        x = self.Conv(input)
        # print(x.shape)
        x = x.view(-1,512*1*1)
        x = self.Classes(x)
        return  x

2.1 定义损失函数和优化器

#实例化损失函数CrossEntropyLoss类
loss_f = torch.nn.CrossEntropyLoss()
#实例化优化器Adam类
optimizer = torch.optim.Adam(model.parameters(),lr = 0.00001)

3.1 定义训练网络

训练是在每次迭代下, 循环取出每一个batch进行前向运算,反馈loss,根据梯度调整参数;再下一个batch继续前向和反馈(调整参数)。
PS:喂入model(x)进行前向计算前,一定要把输入图x转换为Variable。

def mytrain(dataloader):

    mymodel = VGG16()
    mymodel.train()

    # 检查cuda是否可用,可用则配置cuda
    print(torch.cuda.is_available())
    if torch.cuda.is_available():
        mymodel = mymodel.cuda()

    #配置损失函数和优化策略
    loss_f = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(mymodel.parameters(), lr=0.00001)

    epoch_n = 2
    time_start = time.time()

    # 开始train
    for epoch in range(epoch_n):
        print('epoch:{}.....................'.format(epoch))
        running_loss = 0.0
        running_corrects = 0

        for batch, data in enumerate(dataloader, 1):
            y, x = data
            y = y.type(torch.LongTensor)
            if torch.cuda.is_available():
                x, y = Variable(x.cuda()), Variable(y.cuda())
            else:
                x, y = Variable(x), Variable(y)

            y_pred = mymodel(x)
            _, pred = torch.max(y_pred.data, 1)  # 取预测出来的标签

            optimizer.zero_grad()
            loss = loss_f(y_pred, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_corrects += torch.sum(pred == y.data)  # 用预测标签和真实标签比对,统计正确数

            batch_loss = running_loss
            batch_acc = running_corrects.float() * 100 / (16 * batch)

            if batch % 10 == 0:
                print("batch:{}---train loss:{:.4f}---train acc:{:.4f}%"
                                  .format(batch, batch_loss, batch_acc))

        # 结束一个epoch后统计改次迭代优化得到的参数对应的准确率
        epoch_loss = running_loss * 16
        epoch_acc = 100 * running_corrects.float()
        print("train (loss:{:.4f}---acc:{:.4f})".format(epoch_loss, epoch_acc))
        print("epoch time:{:.4f}".format(time.time()-time_start))

    torch.save(mymodel, 'model.pkl')

3.2 定义测试网络

测试网络中 不需要DataLoader,因为无需分成batch。所以直接读入测试集,把单个测试图通过前向计算得的 预测矩阵(N个类别的概率矩阵),取出概率最大的作为预测结果,与标签比对。 最后,通过很多次的比对,统计得到准确率。
PS:(1)由于网络需要的是batch * C * W * H的四维tensor,但是没通过DataLoader的测试数据是三维的,因此需要x = x.unsqueeze(0),增加一个维度(0表示加在第1个维度前)。
(2)train里面采用model.train(),test里面采用model.eval()。是否开启BN和Dropout层

def mytest(test_mnists):
    mymodel = torch.load('model.pkl')
    mymodel.eval()

    for y,x in test_mnists:
        x = x.unsqueeze(0)
        # 检查cuda是否可用,可用则配置cuda
        if torch.cuda.is_available():
            mymodel.cuda()
            x = Variable(x.cuda())
        else:
            x = Variable(x)

        y_pred_prox = mymodel(x)
        _, y_pred = torch.max(y_pred_prox.data, 1)
        print(y_pred.numpy())
        img = x.numpy()[0][0]
        plt.title('y_pred = %i' %y_pred.numpy())
        plt.imshow(img)
        plt.show()

4.1 主函数

主函数中选择’train’表示训练,’t10k’表示测试. 默认是train。
PS:由于整个程序中有太多数据集相关的参数,这些参数属性都相同(比如:分辨率,通道等),但值都不一样,所以 为了防止一个程序适用一种数据集,可以将数据集的参数通过config.json 配置文件进行配置,然后通过argparse和json.open解析出来,作为全局变量,整个程序可用*。

import matplotlib.pyplot as plt
import torch
import torchvision
import time
from torch.autograd import Variable

from mymodels import VGG16
from constructed_datasets import Constructed_mnist

if __name__ == '__main__':
    path = "E:\\MyLocGetHub\\MNIST\\datasets\\"
    filekind = 't10k'     #'train'表示训练,'t10k'表示测试.  默认是train
    my_mnist = Constructed_mnist(path,filekind)
    #将导入的数据按照训练的batch_size要求进行分批
    dataloader = torch.utils.data.DataLoader(dataset=my_mnist, batch_size=16, shuffle=True)
    if filekind == 'train':
        mytrain(dataloader)
    elif filekind == 't10k':
        mytest(my_mnist)
    else:
        print("Error:filekind is only a string with 'train' or 't10k' ")

4.2 参数配置(自行练习)

用到了config = json.load(open(args.config)),只需要主函数中手动配置args.config参数,后续全自动。


parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('-c', '--config', default='configs\config_PANNET.json',type=str,
                        help='Path to the config file')
args = parser.parse_args()

config = json.load(open(args.config))

Original: https://blog.csdn.net/yeen123/article/details/124470671
Author: yeen123
Title: Pytorch分类网络入门(MNIST)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/663152/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球