pytorch 搭建AlexNet 对花进行分类

目录

1. 介绍

2. 搭建AlexNet网络

3. 准备数据集

4. 训练网络

5. 预测图片

6. code

文章内容参考:霹雳吧啦Wz 的视频教程
代码的讲解可以参考之前的文章:pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类

  1. 介绍

AlexNet 网络的结构为:

pytorch 搭建AlexNet 对花进行分类

卷积层的计算公式为:

pytorch 搭建AlexNet 对花进行分类

通过计算,可以得到网络之间的参数为:

pytorch 搭建AlexNet 对花进行分类
  1. 搭建AlexNet网络

之前介绍过,卷积层相当于特征提取、全连接层相当于分类器

所以这里分开搭建不同的模块

特征提取的部分:

pytorch 搭建AlexNet 对花进行分类
  • 这里因为训练的数据较少,所以卷积核的数目都降为了一半
  • Sequential 是一个特殊的Module ,它包含了几个子module,前向传播时会将输入一层接着一层的传递下去
  • nn.ReLU(inplace = True) , ReLU有个inplace参数,设置为Ture 的时候,它会把输出之间覆盖到输入中,这样可以节省资源。因为ReLU计算反向传播的时候,只需要根据输出就能反推出反向传播的梯度。(ReLU 反向传来的梯度会传给输入为正的部分)

分类的部分:

pytorch 搭建AlexNet 对花进行分类
  • AlexNet 使用了dropout 随机失活来防止过拟合,这是针对于全连接层而言的,所以要在每个全连接层的前面加上Dropout
  • num_classes 是最后分类的个数

前向传播的部分:

pytorch 搭建AlexNet 对花进行分类
  • 因为出了特征提取层后,数据的size为(128,6,6),具体来说是n1286*6,这里的n是batch_size ,所以我们只对后面三个维度做 flatten ,例如:

打印的网络结构为:

pytorch 搭建AlexNet 对花进行分类
  1. 准备数据集

这里对五个不同属性的花做分类,都放在了flower_data 下,分别是:雏菊、蒲公英、玫瑰、向日葵、郁金香

通过split_data 文件对flower_data 划分出训练集和验证集,比例为9:1

这里的目录顺序不能错

pytorch 搭建AlexNet 对花进行分类
  1. 训练网络

因为大部分代码是重合的,所以只做少量介绍,具体的可以看之前的文章:pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类

首先是定义数据预处理函数,针对训练集和验证集定义不同的预处理

这里因为样本不够,所以随机裁剪以及随机翻转做数据增强

ToTensor 是归一化和改变通道顺序

pytorch 搭建AlexNet 对花进行分类

然后是载入数据,训练集和验证集等等

pytorch 搭建AlexNet 对花进行分类

接下来显示一下图片,需要把validate_loader 里面的batch_size 改成4,然后将shuffle 改为True,否则全部都是一种类型的图片

pytorch 搭建AlexNet 对花进行分类

打印的label 为

pytorch 搭建AlexNet 对花进行分类

显示的图像为:

pytorch 搭建AlexNet 对花进行分类

接下来实例化网络和定义优化器:

pytorch 搭建AlexNet 对花进行分类

这里因为网络结构较大,训练时间长,可以设置一个准确率最好的参数(best_acc)用来实时保存最好准确率的那个权重参数

然后开始训练网络:

pytorch 搭建AlexNet 对花进行分类

这里的net.train() 可以管理dropout方法,相当于开启dropout

最后就是计算准确率:

pytorch 搭建AlexNet 对花进行分类

net.eval() 用来关闭dropout 方法

最后保存网络的时候,我们根据best_acc 去保存最优的那个网络参数

torch.max 那里,dim = 1代表对第一个维度求取最大值,保留第零个维度,因为第零个维度是batch_size。然后后面的[1]因为torch.max会返回值、索引,这里我们只需要索引

  1. 预测图片

预测的代码基本上没有变化,就是满足几个步骤即可

  1. 将下载的图片进行预处理,这里的预处理要和之前的训练的预处理一样,并且要多一个将size改变成标准的输入size

  2. 增加维度,因为图片是3通道的,而我们输入多了一个batch_size 维度,所以通过unsqueeze增加一个维度

  3. 加载网络参数

  4. 做预测,读取最大的那个预测概率

pytorch 搭建AlexNet 对花进行分类
  1. code

搭建AlexNet 网络结构:

import torch.nn as nn
import torch

class AlexNet(nn.Module):       # 继承nn.Module 父类
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # 提取图像的特征
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # (input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),        # 会自动舍去小数部分,将最后一行和一列舍去,等价于左补2,右补1,上补2,下补1
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        # 对特征分类
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),          # 随机失活,对全连接层操作
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)   # batch * channel * height * width ,第一个batch不变
        x = self.classifier(x)
        return x

net = AlexNet()
print(net)

训练网络部分:

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪成224 *224
                                transforms.RandomHorizontalFlip(), # 水平方向随机翻转
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

train_dataset = datasets.ImageFolder(root="./flower_data/train",transform=data_transform["train"])   # 读取训练集
validate_dataset = datasets.ImageFolder(root='./flower_data/val',transform=data_transform["val"])     # 读取验证集

classes = ("daisy", "dandelion", "roses", "sunflowers", "tulips")   # 雏菊、蒲公英、玫瑰、向日葵、郁金香

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)     # 载入训练集
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)   # 载入验证集

#显示图像 code
test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.next()
#
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
#
print(' '.join('%5s' % classes[test_label[j].item()] for j in range(4)))
imshow(utils.make_grid(test_image))

net = AlexNet(num_classes=5)         # 实例化网络
loss_function = nn.CrossEntropyLoss()                   # 定义交叉熵损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0002)     # 定义优化器

save_path = './AlexNet.pth'         #  网络保存的路径
best_acc = 0.0                      # 保存最好准确率的model

for epoch in range(10):     # 训练次数
    net.train()             # 管理dropout 方法,在训练的时候随机失活
    running_loss = 0.0
    for step, data in enumerate(train_loader,start=0):
        images, labels = data
        optimizer.zero_grad()                   # 梯度清零
        outputs = net(images)                   # 前向传播
        loss = loss_function(outputs, labels)   # 计算损失函数
        loss.backward()                         # 反向传播
        optimizer.step()                        # 更新权重

        running_loss += loss.item()

    # validate
    net.eval()      # 关闭dropout
    acc = 0.0  # accumulate accurate number / epoch
    total = 0
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images)                    # 网络预测
            predicted= torch.max(outputs, dim=1)[1]
            acc += (predicted == val_labels).sum().item()    # 计算准确率
            total += val_labels.size(0)

    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / step, 100 * acc / total))

    if (acc / total) > best_acc:                # 保存当前最好的准确率
        best_acc = acc / total
        torch.save(net.state_dict(), save_path)

print('Finished Training')

预测图片部分:

import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = Image.open("./tulips.png")            # 载入图像
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)     # 增加维度,第0维增加1 ,维度(1,C,H,W)

classes = ( "daisy","dandelion","roses","sunflowers","tulips")  # 5个分类label

model = AlexNet(num_classes=5)      # 实例化网络
model.load_state_dict(torch.load("./AlexNet.pth"))  # 读取保存的网络参数
model.eval()                        # 关闭dropout 方法

with torch.no_grad():           # 预测不需要计算梯度
    output = model(img)
    predict = torch.max(output, dim=1)[1]
    print(classes[int(predict)])

训练网络打印的信息为:

pytorch 搭建AlexNet 对花进行分类

输入的预测图片为:

pytorch 搭建AlexNet 对花进行分类

预测的结果为:

pytorch 搭建AlexNet 对花进行分类

如果想要分类自己的分类目标的话,只需要将flower_data 里面的图片改成自己的就行了
然后用split_data 划分一下数据就行
注: 目录就是labels ,顺序不能错。目录的顺序需要和这个保持一致

pytorch 搭建AlexNet 对花进行分类

Original: https://blog.csdn.net/qq_44886601/article/details/127578228
Author: Henry_zs
Title: pytorch 搭建AlexNet 对花进行分类

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/666757/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球