图像分类:AlexNet网络、五分类 flower 数据集、pytorch

文章目录

*
一、代码结构
二、数据集的处理

+ 2.1 数据集的下载和切分:split_data.py
+ 2.2 数据集的加载:dataset.py
+ 2.3 数据集图片可视化:imgs_vasual.py
三、AlexNet介绍及网络搭建:model.py

+ 3.1 AlexNet网络结构
+ 3.2 AlexNet网络的亮点
+ 3.3 网络搭建
四、训练及保存精度最高的网络参数:train.py
五、用数据集之外的图片进行测试:predict.py

代码来源:
使用pytorch搭建AlexNet并训练花分类数据集

一、代码结构

图像分类:AlexNet网络、五分类 flower 数据集、pytorch

; 二、数据集的处理

2.1 数据集的下载和切分:split_data.py

"""
视频教程:https://www.bilibili.com/video/BV1p7411T7Pc/?spm_id_from=333.788
flower数据集为5分类数据集,共有 {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} 5个分类。

该程序用于将数据集切分为训练集和验证集,使用步骤如下:
(1)在"split_data.py"的同级路径下创建新文件夹"flower_data"
(2)点击链接下载花分类数据集 http://download.tensorflow.org/example_images/flower_photos.tgz
(3)解压数据集到flower_data文件夹下
(4)执行"split_data.py"脚本自动将数据集划分为训练集train和验证集val

切分后的数据集结构:
├── split_data.py
├── flower_data
       ├── flower_photos.tgz (下载的未解压的原始数据集)
       ├── flower_photos(解压的数据集文件夹,3670个样本)
       ├── train(生成的训练集,3306个样本)
       └── val(生成的验证集,364个样本)
"""""

import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    if os.path.exists(file_path):

        rmtree(file_path)
    os.makedirs(file_path)

def main():
    random.seed(0)

    split_rate = 0.1

    cwd = os.getcwd()
    data_path = os.path.join(cwd, "flower_data/flower_photos/flower_photos")
    data_root=os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_path, "")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:

        mk_file(os.path.join(train_root, cla))

    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:

        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)

        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:

                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:

                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")
        print()

    print("processing done!")

if __name__ == '__main__':
    main()

2.2 数据集的加载:dataset.py

import os
import json
import torch
from torchvision import transforms, datasets

def dataset(batch_size):
    train_path = "flower_data/train"
    val_path = "flower_data/val"
    assert os.path.exists(train_path), "{} path does not exist.".format(train_path)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

"""
    数据预处理,训练集做随机裁剪和随机翻转用来数据增强
    RandomResizedCrop(224) 表示先随机裁剪为不同的大小和宽高比,然后缩放为(224,224)大小
    RandomHorizontalFlip() 表示随机水平翻转(即左右翻转),默认概率为 0.5
"""

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

"""
    torchvision.datasets.ImageFolder 适用于加载特定存储格式的数据集,具体使用可参考博客:
    https://blog.csdn.net/qq_39507748/article/details/105394808
"""

    train_dataset = datasets.ImageFolder(root=train_path,transform=data_transform["train"])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=val_path, transform=data_transform["val"])
    valid_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)
    train_num = len(train_dataset)
    val_num = len(validate_dataset)
    print(f"using {train_num} images for training, {val_num} images for valid.")

    flower_class_id = train_dataset.class_to_idx

    cla_dict = dict((val, key) for key, val in flower_class_id.items())

    json_str = json.dumps(cla_dict, indent=4)
"""
    json.dumps() 将 python对象转换成 json对象,生成一个字符串。
    indent=4 表示缩进4个空格,方便阅读。
    json_str的内容为:
        {
            "0": "daisy",
            "1": "dandelion",
            "2": "roses",
            "3": "sunflowers",
            "4": "tulips"
        }
"""

    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    return train_loader,valid_loader,val_num

2.3 数据集图片可视化:imgs_vasual.py

"""
图片可视化函数,用于imshow多张图片,并输出每张图片对应的label
"""""

import os
import torch
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np

def imgs_imshow(batch_size):

    train_path = "flower_data/train"
    assert os.path.exists(train_path), "{} path does not exist.".format(train_path)
    tramsform=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = datasets.ImageFolder(root=train_path, transform=tramsform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=0)

    data_iter = iter(train_loader)
    image, label = data_iter.next()

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    print('   '.join('%5s' % cla_dict[label[j].item()] for j in range(batch_size)))

    img = utils.make_grid(image)
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

if __name__ == '__main__':
    imgs_imshow(batch_size=6)

三、AlexNet介绍及网络搭建:model.py

3.1 AlexNet网络结构

图像分类:AlexNet网络、五分类 flower 数据集、pytorch
本程序中输入图片的尺寸是 224*224,输出为5分类而不是1000分类,其他数据均为图中的数据。

; 3.2 AlexNet网络的亮点

(1)首次利用GPU进行网络加速训练,作者用了两块GPU进行并行训练。

(2)使用了ReLU激活函数,而不是传统的Sigmoid激活函数以及Tanh激活函数。

(3) 使用了LRN局部响应归一化(Local Response Normalization)。本程序中没有用LRN,因为这个方法现在已经用的很少了。

(4)在全连接层的前两层中使用了Dropout随机失活神经元操作,以减少过拟合。

3.3 网络搭建

import torch.nn as nn

"""
本程序中没有使用LRN归一化,因为这个方法现在已经用的很少了。
"""

class AlexNet(nn.Module):
    def __init__(self,class_num=1000,init_weights=False):
        super(AlexNet,self).__init__()
        self.dropout=0.1

        self.features=nn.Sequential(
            nn.ZeroPad2d((2, 1, 2, 1)),

            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4),

            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=3,stride=2),

            nn.Conv2d(96,256,kernel_size=5,padding=2),

            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.Conv2d(256,384,kernel_size=3,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,256,kernel_size=3,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2),
        )

        self.classifier=nn.Sequential(
            nn.Dropout(p=self.dropout),
            nn.Linear(in_features=9216,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=class_num),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        x=self.features(x)
        x=x.view(-1,256*6*6)
        x=self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

"""
    _initialize_weights()方法的解释:
    self.modules():  Returns an iterator over all modules in the network,即遍历网络中的所有层,并返回一个迭代器。
    for m in self.modules(): 遍历网络中的每一层
    if isinstance(m, nn.Conv2d): 判断m是否是 nn.Conv2d层
    其实并不需要用_initialize_weights()方法进行初始化,因为pytorch会默认以 nn.init.kaiming_normal_() 进行初始化。
"""

四、训练及保存精度最高的网络参数:train.py

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet
from dataset import dataset

def train(batch_size, epochs, lr=0.001):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_loader, valid_loader, val_num = dataset(batch_size=batch_size)
    model = AlexNet(class_num=5, init_weights=True)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):

        model.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, (images, labels) in enumerate(train_bar):
            optimizer.zero_grad()
            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = f"train epoch [{epoch+1}/{epochs}]   loss= {loss:.3f}"

        model.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(valid_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d]   train_loss= %.3f   val_accuracy= %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_path)

    print('Finished Training')

if __name__ == '__main__':
    train(batch_size=16, epochs=10, lr=0.0002)

训练结果(没有跑完):

图像分类:AlexNet网络、五分类 flower 数据集、pytorch

五、用数据集之外的图片进行测试:predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet

def predict():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    img_path = "./tulip.png"
    assert os.path.exists(img_path), f"file: '{img_path}' dose not exist."
    img = Image.open(img_path)
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    json_file = open(json_path, "r")
    class_indict = json.load(json_file)

    model = AlexNet(class_num=5).to(device)
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), f"file: '{weights_path}' dose not exist."
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()

        predict = torch.softmax(output, dim=0)

        predict_cla = torch.argmax(predict).item()

    img_class = class_indict[str(predict_cla)]
    img_preb=predict[predict_cla].item()
    print_res = f"class: {img_class}    prob: {img_preb:.3}"
    plt.title(print_res)
    for i in range(len(predict)):
        print(f"class: {class_indict[str(i)]:12}   prob: {predict[i].item():.3}")
    plt.show()

if __name__ == '__main__':
    predict()

测试结果:

class: daisy          prob: 0.00238
class: dandelion      prob: 0.000163
class: roses          prob: 0.199
class: sunflowers     prob: 0.00173
class: tulips         prob: 0.797

测试图片及类别预测:

图像分类:AlexNet网络、五分类 flower 数据集、pytorch

Original: https://blog.csdn.net/qq_43799400/article/details/123555090
Author: ctrl A_ctrl C_ctrl V
Title: 图像分类:AlexNet网络、五分类 flower 数据集、pytorch

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/662386/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球