【Deep Learning】Image Recognition in Practice: 102-Flower Classification (Flower 102)

Table of Contents

Convolutional networks in practice: classifying flowers

This post works through a classification task on the Oxford Flowers-102 dataset, building a fairly general network pipeline (implemented here with ResNet) that exercises the common PyTorch workflow: preprocessing, training, model saving, and model loading.

The dataset contains 102 flower classes, and our task is to classify them.
Folder structure:

flower_data

  • train
    • 1 (class folder)
    • 2
      • xxx.png / xxx.jpg
  • valid

The pipeline breaks down into the following modules:

Data preprocessing

  • Data augmentation
  • Data preprocessing

Network setup

  • Load a pretrained model, reusing a classic architecture straight from torchvision
  • The pretrained task is usually 1000-way classification (and rarely matches ours), so the head must be adapted to our own task

Model saving and testing

  • Saving can be selective (e.g. keep only the best checkpoint)

Data download:

https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the extracted folder and place it in the project root directory.

Below is my data root directory (screenshot omitted; it matches the folder structure above).

1. Import packages

import os
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2. Data preprocessing

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

Note: `./` is a path relative to the current working directory, while a leading `/` makes a path absolute — see the linked note "python目录点杠的组合与区别" for details on combining dots and slashes.
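As a quick self-contained illustration of that note (the paths here are just examples): `./` means "relative to the current working directory", naive string concatenation can produce double slashes, and the `os.path` helpers normalize things:

```python
import os

data_dir = './flower_data/'                   # relative to the current working directory
train_dir = data_dir + '/train'               # naive concatenation: './flower_data//train'

# Double slashes are usually harmless, but normpath cleans them up,
# and os.path.join is the safer way to build paths in the first place
print(os.path.normpath(train_dir))            # flower_data/train on POSIX systems
print(os.path.join('flower_data', 'valid'))   # flower_data/valid on POSIX systems
```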

  1. Build the data source.

  2. data_transforms specifies all the image preprocessing operations.

  3. ImageFolder assumes images are saved one folder per class, with each folder holding images of the same class.

data_transforms = {

    'train': transforms.Compose([transforms.RandomRotation(45),
                                 transforms.CenterCrop(224),

                                 transforms.RandomHorizontalFlip(p = 0.5),
                                 transforms.RandomVerticalFlip(p = 0.5),

                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
                                 transforms.RandomGrayscale(p = 0.025),

                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ]),

    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ]),
}
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
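Before training, it helps to sanity-check the loader by pulling one batch and inspecting its shape. The sketch below is self-contained — a random TensorDataset stands in for image_datasets['train'], since the real dataset requires the downloaded images:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the real ImageFolder dataset: 32 random "images" with labels 0..101
fake_ds = TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 102, (32,)))
loader = DataLoader(fake_ds, batch_size=8, shuffle=True)

inputs, labels = next(iter(loader))
print(inputs.shape)   # torch.Size([8, 3, 224, 224]) -- batch x channels x height x width
print(labels.shape)   # torch.Size([8])
```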

image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

dataloaders
{'train': <torch.utils.data.dataloader.DataLoader object at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader object at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}

3. Map labels to the actual flower names

Use the JSON file in the data directory to map each class index back to the flower's real name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}

4. Display some sample data

def im_convert(tensor):
    """Convert a normalized tensor back into a displayable image."""
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()

    # CHW -> HWC, the layout matplotlib expects
    image = image.transpose(1, 2, 0)

    # Undo the Normalize() transform: x * std + mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    image = image.clip(0, 1)

    return image

fig = plt.figure(figsize = (20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # .next() was removed in newer PyTorch; use the built-in next()

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])

    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

(Output: a 2 x 4 grid of validation images, each titled with its flower name.)
5. Load a model provided by torchvision.models and initialize it with pretrained weights
model_name = 'resnet'

# Train only the new classification head; freeze the pretrained backbone
feature_extract = True

train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.   Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available.   Training on CPU ...


def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ... (many intermediate layers omitted — for the architecture we only need to look at the first and last blocks) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
We need to adapt this to our own task by changing the 1000-way output to 102.

6. Initialize the model architecture

Steps:

  1. Take the pretrained model and pass pretrained=True to load its trained weights.
  2. Decide whether to freeze certain layers; frozen layers have requires_grad set to False so their parameters are not updated.
  3. Whether the task is classification or regression, replace the final FC layer with one matching our own output size.

Official documentation:
https://pytorch.org/vision/stable/models.html


def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):

    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet152 """
        model_ft = models.resnet152(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.fc.in_features
        # Replace the 1000-way head; LogSoftmax pairs with the NLLLoss used later
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim = 1))
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet """
        model_ft = models.alexnet(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG16 """
        model_ft = models.vgg16(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet """
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet """
        model_ft = models.densenet121(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception V3: expects 299x299 inputs and has an auxiliary output """
        model_ft = models.inception_v3(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
7. Set the parameters to be trained

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)

model_ft = model_ft.to(device)

filename = 'checkpoint.pth'

params_to_update = model_ft.parameters()

print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
     fc.0.weight
     fc.0.bias
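A quick way to confirm that freezing worked is to compare trainable vs. total parameter counts. The sketch below is self-contained, with a tiny two-layer stack standing in for model_ft (only the counting idiom matters):

```python
import torch
from torch import nn

# Tiny stand-in for model_ft: freeze the first layer, keep the head trainable
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
for p in model[0].parameters():
    p.requires_grad = False

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable {trainable} / total {total}")   # trainable 105 / total 325
```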
8. Training and prediction

8.1 Optimizer setup


optimizer_ft = optim.Adam(params_to_update, lr = 1e-2)

# Decay the learning rate by a factor of 10 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# NLLLoss expects log-probabilities, which the LogSoftmax head provides
criterion = nn.NLLLoss()
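Why NLLLoss here? The ResNet head built in initialize_model ends with LogSoftmax, and NLLLoss expects log-probabilities; together they compute exactly what CrossEntropyLoss computes on raw logits:

```python
import torch
from torch import nn

logits = torch.randn(4, 102)                 # raw scores for a batch of 4 images
targets = torch.randint(0, 102, (4,))

nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
ce = nn.CrossEntropyLoss()(logits, targets)  # same value, computed from raw logits
print(torch.allclose(nll, ce))               # True
```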

def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    since = time.time()

    best_acc = 0

    # Uncomment to resume from a saved checkpoint:
    # checkpoint = torch.load(filename)
    # best_acc = checkpoint['best_acc']
    # model.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    # model.class_to_idx = checkpoint['mapping']

    model.to(device)

    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:

                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):

                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc

                best_model_wts = copy.deepcopy(model.state_dict())
                state = {

                  'state_dict': model.state_dict(),
                  'best_acc': best_acc,
                  'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()  # StepLR takes no metric argument; step once per epoch
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

8.2 Train the model

I only ran a handful of epochs here (training takes a very long time, especially on CPU); feel free to increase the number of epochs when experimenting yourself.


model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))

Epoch 0/4
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000
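The intro promised model saving and loading: train_model saves the best checkpoint above, and loading follows the same dictionary format. A self-contained sketch (checkpoint_demo.pth and the Linear stand-in are hypothetical; with the real model you would rebuild it via initialize_model first):

```python
import torch
from torch import nn

# Save in the same format train_model uses: a dict of state_dict + best_acc
model = nn.Linear(4, 2)                              # stand-in for model_ft
torch.save({'state_dict': model.state_dict(), 'best_acc': 0.66}, 'checkpoint_demo.pth')

# Later (or in a fresh script): rebuild the same architecture, then restore weights
model2 = nn.Linear(4, 2)
checkpoint = torch.load('checkpoint_demo.pth', map_location='cpu')
model2.load_state_dict(checkpoint['state_dict'])
model2.eval()                                        # evaluation mode before inference
print(checkpoint['best_acc'])                        # 0.66
```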


Original: https://blog.csdn.net/LeungSr/article/details/126747940
Author: FeverTwice
Title: 【Deep Learning】Image Recognition in Practice: 102-Flower Classification (Flower 102)

Original articles are protected by copyright. Please credit the source when reposting: https://www.johngo689.com/661435/

Reposted articles are protected by the original author's copyright. Please credit the original author when reposting!
