AlexNet模型及代码详解

Alex在2012年提出的alexnet网络结构模型引爆了神经网络的应用热潮,并赢得了2012届图像识别大赛的冠军,使得CNN成为在图像分类上的核心算法模型。

AlexNet模型及代码详解

该网络的亮点在于:
(1)首次使用了GPU进行网络加速训练。
(2)使用了ReLU激活函数,而不是传统的Sigmoid激活函数以及Tanh激活函数。
(3)使用了LRN局部响应归一化。
(4)在全连接层的前两层中使用了Droupout随机失活神经元操作,以减少过拟合。

模型组成

  • 输入层
  • 5个卷积层
  • 3个全链接层

AlexNet模型及代码详解

第1层:卷积层(卷积、池化)

Conv1

输入:input_size = [224, 224, 3]
卷积层:
kernels = 48 * 2 = 96 组卷积核
kernel_size = 11
padding = [1, 2] (左上围加半圈0,右下围加2倍的半圈0)
stride = 4
输出:output_size = [55, 55, 96]

AlexNet模型及代码详解

Maxpool1

  • 输入:input_size = [55, 55, 96]
  • 池化层:(只改变尺寸,不改变深度channel)
  • kernel_size = 3
  • padding = 0
  • stride = 2
  • 输出:output_size = [27, 27, 96]

AlexNet模型及代码详解

Conv2

  • 输出:output_size = [27, 27, 256]

AlexNet模型及代码详解

Maxpool2

  • 输出:output_size = [13, 13, 256]

AlexNet模型及代码详解

Conv3

  • 输出:output_size = [13, 13, 384]

AlexNet模型及代码详解

Conv4

  • 输出:output_size = [13, 13, 384]

AlexNet模型及代码详解

Conv5

  • 输出:output_size = [13, 13, 256]

AlexNet模型及代码详解

Maxpool3

  • 输出:output_size = [6, 6, 256]

AlexNet模型及代码详解

FC1、FC2、FC3

Maxpool3 → (66256) → FC1 → 4096 → FC2 → 4096 → FC3 → 1000

代码:

1. model.py

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        # 用nn.Sequential()将网络打包成一个模块,精简代码
        self.features = nn.Sequential(   # 卷积层提取图像特征
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),                                  # 直接修改覆盖原值,节省运算内存
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(   # 全连接层对图像分类
            nn.Dropout(p=0.5),             # Dropout 随机失活神经元,默认比例为0.5
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    # 前向传播过程
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)   # 展平后再传入全连接层
        x = self.classifier(x)
        return x

    # 网络权重初始化,实际上 pytorch 在构建网络时会自动初始化权重
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):                            # 若是卷积层
                nn.init.kaiming_normal_(m.weight, mode='fan_out',   # 用(何)kaiming_normal_法初始化权重
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)                    # 初始化偏重为0
            elif isinstance(m, nn.Linear):            # 若是全连接层
                nn.init.normal_(m.weight, 0, 0.01)    # 正态分布初始化
                nn.init.constant_(m.bias, 0)          # 初始化偏重为0

2. train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

if __name__ == '__main__':
    main()

3. predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

预处理
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

load image
img = Image.open("rose.jpg")
plt.imshow(img)
[N, C, H, W]
img = data_transform(img)
expand batch dimension
img = torch.unsqueeze(img, dim=0)

read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

create model
model = AlexNet(num_classes=5)
load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))

关闭 Dropout
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

Original: https://blog.csdn.net/weixin_42457110/article/details/124980914
Author: 工藤新三
Title: AlexNet模型及代码详解

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/785626/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • OpenCV-Python实战(番外篇)——OpenCV、NumPy和Matplotlib直方图比较

    OpenCV-Python实战(番外篇)——OpenCV、NumPy和Matplotlib直方图比较 * – 前言 – OpenCV、NumPy和Matpl…

    Python 2023年8月28日
    055
  • Python设计模式-单例模式

    ; 设计思想 通过上面的介绍,我们可以知道单例模式最重要的就是要保证一个类只有一个实例并且这个类易于被访问,那么要怎么做才能保证一个类具有一个实例呢?一个全局变量使得一个对象可以被…

    Python 2023年8月4日
    058
  • python pytest setupclass_pytest:如何将类参数传递给setup类

    我正在使用pytest的参数化注释将参数传递到类中。我可以在测试方法中使用参数,但是,我不知道如何在setup_class方法中使用参数。import pytest params …

    Python 2023年9月11日
    057
  • 波动方程数值求解(一)

    波动方程数值解是波动方程正演、逆时偏移和全波形反演的核心技术之一。本文采用二阶有限差分对波动方程进行了离散,进而实现了对波动方程的数值求解,模拟出其在介质中的传播过程。1、二维声波…

    Python 2023年8月26日
    055
  • Python编程—pytest自动化测试部署

    1、pytest介绍 pytest是一个非常成熟的自动化测试框架。 pytest主要features如下: 容易学习 支持简单的单元测试和复杂的功能测试 具有大量第三方插件:pyt…

    Python 2023年9月10日
    059
  • Python实现WOA智能鲸鱼优化算法优化支持向量机分类模型(SVC算法)项目实战

    说明:这是一个机器学习实战项目(附带 数据+代码+文档+视频讲解),如需 数据+代码+文档+视频讲解可以直接到文章最后获取。 1.项目背景 鲸鱼优化算法 (whale optimi…

    Python 2023年9月3日
    062
  • 1-出租车数据的基础处理,由gps生成OD(pandas)

    import pandas as pd data = pd.read_csv(r’data-sample/TaxiData-Sample’,header = None) data….

    Python 2023年8月7日
    054
  • pytest

    pytest学习总结 1、介绍 单元测试框架——软件开发过程中针对软件的最小单位(函数、方法)进行正确性的检查测试Java测试框架:junit和testngpython测试框架:u…

    Python 2023年9月14日
    036
  • 深度学习笔记:07神经网络之手写数字识别的经典实现

    神经网络之手写数字识别的经典实现 上一节完成了简单神经网络代码的实现,下面我们将进行最终的实现:输入一张手写图片后,网络输出该图片对应的数字。由于网络需要用0-9一共十个数字中挑选…

    Python 2023年8月28日
    059
  • python带你采集西瓜无水印美女舞蹈视频数据~

    Original: https://www.cnblogs.com/Qqun261823976/p/16700068.htmlAuthor: python倩Title: pytho…

    Python 2023年6月9日
    071
  • pytest单元测试框架

    单元测试是指在软件开发中,针对软件的最小单位(函数、方法)进行正确性的检查测试 java:Junit python: unittest和pytest 自动化测试框架的作用: * &…

    Python 2023年9月13日
    038
  • 设计模式之建造者模式

    builder desigin pattern 建造者模式的概念、建造者模式的结构、建造者模式的优缺点、建造者模式的使用场景、建造者模式的实现示例、建造者模式的源码分析 1、建造者…

    Python 2023年10月10日
    062
  • numba学习一

    numba 编译型语言和解释型语言 首先了解一下编译型语言和解释型语言(也经常叫脚本语言): 1、编译型语言,C、C++、Fortran、Pascal、Ada,由编译型语言编写的源…

    Python 2023年6月6日
    075
  • pytest实战练习

    pytest是单元测试框架,用作代码层测试的框架。简单、易用,很多大型开源测试框架如appium、httprunner框架也基于它实现。网页、手机应用以及接口等测试都支持,也就是p…

    Python 2023年9月9日
    040
  • 【Pandas数据处理100例目录】Python数据分析玩转Excel表格数据

    ### 回答1: Python_是一种功能强大的编程语言,可以用于各种 _数据分析_任务。而在 _Python_的 _数据分析_工具库中, _pandas_是最受欢迎和广泛使用的工…

    Python 2023年11月8日
    047
  • Python自学教程12-面向对象编程

    几乎所有的现代编程语言都支持面向对象编程,而面向对象编程是最有效的软件编写方法之一。您可以使用类和对象来表示现实中的任何内容和行为。 [En] Almost all modern …

    Python 2023年5月24日
    080
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球