pytorch迁移学习+ResNet50实现猫十二分类

之前写过一篇实现猫十二分类的文章,写出了大体的流程,但实际效果并不佳。本文采取 微调预训练模型的方式,使 准确率从0.3提升到了0.93。大体流程参考ResNet猫十二分类,本文只给出不同的地方。

迁移学习的两种方式

  • dataset定义

import os
import cv2
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from PIL import Image

class myData(Dataset):
    def __init__(self, kind):
        super(myData, self).__init__()
        self.mode = kind

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
        ])

        if kind == 'test':
            self.imgs = self.load_origin_data()

        elif kind == 'train':
            self.imgs, self.labels = self.load_origin_data()

            print('train size:')
            print(len(self.imgs))

        else:
            self.imgs, self.labels = self.load_origin_data()

    def __getitem__(self, index):
        if self.mode == 'test':
            sample = self.transform(self.imgs[index])
            return sample
        else:
            sample = self.transform(self.imgs[index])
            return sample, torch.tensor(self.labels[index])

    def __len__(self):
        return len(self.imgs)

    def load_origin_data(self):
        filelist = './data/%s_split_list.txt' % self.mode

        imgs = []
        labels = []
        data_dir = os.getcwd()
        if self.mode == 'train' or self.mode == 'val':
            with open(filelist) as flist:
                lines = [line.strip() for line in flist]
                if self.mode == 'train':
                    np.random.shuffle(lines)
                for line in lines:
                    img_path, label = line.split('&')
                    img_path = os.path.join(data_dir, img_path)
                    try:

                        img = Image.fromarray(cv2.imdecode(np.fromfile(img_path, dtype=np.float32), 1))
                        imgs.append(img)
                        labels.append(int(label))
                    except:
                        print(img_path)
                        continue
                return imgs, labels
        elif self.mode == 'test':
            full_lines = os.listdir('data/test/')
            lines = [line.strip() for line in full_lines]
            for img_path in lines:
                img_path = os.path.join(data_dir, "data/test/", img_path)

                img = Image.fromarray(cv2.imdecode(np.fromfile(img_path, dtype=np.float32), 1))
                imgs.append(img)
            return imgs

    def load_data(self, mode, shuffle, color_jitter, rotate):
        '''
        :return : img, label
        img: (channel, w, h)
        '''
        filelist = './data/%s_split_list.txt' % mode

        imgs = []
        labels = []
        data_dir = os.getcwd()
        if mode == 'train' or mode == 'val':
            with open(filelist) as flist:
                lines = [line.strip() for line in flist]
                if shuffle:
                    np.random.shuffle(lines)

                for line in lines:
                    img_path, label = line.split('&')
                    img_path = os.path.join(data_dir, img_path)
                    try:
                        img, label = process_image((img_path, label), mode, color_jitter, rotate)
                        imgs.append(img)
                        labels.append(label)
                    except:

                        continue
                return imgs, labels

        elif mode == 'test':
            full_lines = os.listdir('data/test/')
            lines = [line.strip() for line in full_lines]
            for img_path in lines:
                img_path = os.path.join(data_dir, "data/test/", img_path)

                img = process_image((img_path, 0), mode, color_jitter, rotate)
                imgs.append(img)
            return imgs

首先数据部分有一些改动

img_datasets = {x: myData(x) for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(img_datasets[x]) for x in ['train', 'val', 'test']}

train_loader = DataLoader(
    dataset=img_datasets['train'],
    batch_size=batches,
    shuffle=True
)

val_loader = DataLoader(
    dataset=img_datasets['val'],
    batch_size=1,
    shuffle=False
)

test_loader = DataLoader(
    dataset=img_datasets['test'],
    batch_size=1,
    shuffle=False
)

dataloaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader
}

值得参考的tricks有:


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print('Best val Acc: {:4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

对模型所有层的所有参数都进行目标域的训练。

使用pretrain = True的方式得到预训练模型, 更改全连接层的输出维度,然后训练残差模型


model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 12)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

optimizer_ft = torch.optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=15)

没啥说的

def vali(M ,dataset):
    M.eval()
    with torch.no_grad():
        correct = 0
        for (data, target) in val_loader:
            data, target = data.to(device), target.to(device)

            pred = M(data)
            _, id = torch.max(pred, 1)
            correct += torch.sum(id == target.data)
        print("test accu: %.03f%%" % (100 * correct / len(dataset)))
    return (100 * correct / len(dataset)).item()
test_accu = int(vali(model_ft, img_datasets['val']) * 100)

model_name = 'val_{}.pkl'.format(test_accu)

torch.save(model_ft.state_dict(), os.path.join("./myModels", model_name))

model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 12)
model_ft = model_ft.to(device)

model_ft.load_state_dict(torch.load("./myModels/val_9343.pkl"))
vali(model_ft, img_datasets['val'])

输出结果:

test accu: 93.433%

Original: https://blog.csdn.net/m0_56945333/article/details/123260627
Author: KimJuneJune
Title: pytorch迁移学习+ResNet50实现猫十二分类

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/662165/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • 使用KNN进行分类和回归

    一般情况下k-Nearest Neighbor (KNN)都是用来解决分类的问题,其实KNN是一种可以应用于数据分类和预测的简单算法,本文中我们将它与简单的线性回归进行比较。 KN…

    人工智能 2023年6月16日
    097
  • 无监督学习

    现实生活中常常会有这样的问题:缺乏足够的先验知识,因此难以人工标注类别或进行人工类别标注的成本太高。很自然地,在建模的过程中希望计算机能代我们完成这些工作,或至少提供一些帮助。根据…

    人工智能 2023年5月31日
    071
  • pytorch框架实现BI-LSTM模型进行情感分类

    总述 本文的目标是针对一个句子,给出其情感二分类,正向/负向。代码存放地址: https://github.com/stay-leave/BI-LSTM-sentiment-cla…

    人工智能 2023年6月30日
    090
  • Autoware 障碍物车辆意图与轨迹预测

    Autoware 障碍物车辆意图与轨迹预测yeye.liu@foxmail.comAutoware 中预测道路中其他车辆的意图与轨迹部分在一篇博士毕业论文的第四章中进行了详细解释,…

    人工智能 2023年6月10日
    073
  • pandas教程06—DataFrame的实用操作

    文章目录 欢迎关注公众号【Python开发实战】,免费领取Python学习电子书! 工具-pandas * Dataframe对象 – 自动对齐 处理缺失值 使用gro…

    人工智能 2023年7月8日
    092
  • 【机器学习】笔记1:回归与误差分析

    回归与误差分析 regression * step 1:model step 2:Goodness of Function step 3:Best Function(Gradien…

    人工智能 2023年6月17日
    071
  • 随机森林可视化

    今天看到别人的文章,说到了随机森林可视化,于是尝试了下。 window安装 windows版本安装:1.在下面去下载window的exe安装包,安装graphviz。 http:/…

    人工智能 2023年6月28日
    065
  • 【图像识别】基于HSV和RGB模型水果分类matlab代码

    1 简介 图像识别主要是研究用计算机代替人去处理大量的物理信息,从而帮助人们建华劳动。机械分类耗时段的特点很符合水果的时间特性。本设计针对多种常见水果混合的图像,利用Matlab软…

    人工智能 2023年7月2日
    077
  • yolov5s-5.0网络模型结构图

    看了很多yolov5方面的东西,最近需要yolov5得模型结构图,但是网上的最多的是大白老师的,但是大白老师的yolov5得模型结构图不知道是哪个版本得,肯定不是5.0和6.0版本…

    人工智能 2023年7月25日
    062
  • CLion配置opencv环境

    工具准备 1.clion官网链接:clion2.cmake官网链接:cmake下载红框标记的压缩包,免安装。3.mingw官网链接:mingw安装红框标记下载免安装版本,解压可用。…

    人工智能 2023年7月28日
    075
  • python二值化

    import cv2 as cv import numpy as np def threshold_demo(image): #全局阈值 gray = cv.cvtColor(im…

    人工智能 2023年7月19日
    055
  • YOLO算法创新改进系列项目汇总(入门级教程指南)

    抵扣说明: 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。 Original: https://blo…

    人工智能 2023年7月30日
    087
  • git 记录

    git 工作区介绍 workspace:工作区,就是平时存放项目代码的地方。 Index/Stage:暂存区,用于临时存放你的改动,事实上只是一个文件,保存即将提交到文件列表信息。…

    人工智能 2023年6月27日
    091
  • MAC M1:解决在jupyter中引入tensorflow内核似乎挂掉的问题

    背景 :在使用jupyter进行tensorflow学习的过程中,遇到import tensorflow就出现内核似乎挂掉的提示,查阅与实践了好几种解决方法依然没能解决,最终结合a…

    人工智能 2023年5月23日
    086
  • 存储器的分类及层次结构

    存储器分类 1.按存储介质分类(1)半导体存储器:TTL、MOS(2)磁表面存储器:磁头、载磁体(3)磁芯存储器:硬磁材料、环状元件(4)光盘存储器:激光、磁光材料半导体存储器在断…

    人工智能 2023年7月3日
    085
  • 盘点5个C#开发的、可用于个人博客的系统

    很多程序员在业务时间,都会选择写博客。写技术博客对于程序员,对于程序员是非常有好处的。一篇博客的完成,需要作者思考、总结、整理、然后在把他变成文字、最后还需要学点排版,非常锻炼程序…

    人工智能 2023年5月30日
    088
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球