CIFAR-10基础优化一(加入标准化和激活函数)

一、须知

1.在基础框架上加入激活函数和标准化,基础框架参考:

基于Pytorch的卷积神经网络代码(CIFAR图像分类)及基本构架_百炼成丹的博客-CSDN博客

2.数据集读取路径更改为本地运行,故暂不支持kaggle直接训练,若需要Kaggle服务器run,则要自行更改读取路径,读取方式在基础框架中。

3.本轮优化只对比加入激活函数和标准化后的提升

4.学习率均设置为0.001对比不同优化方案效果

5.优化前,基础框架中,测试集准确率68.5,epoch = 165 ,学习率 = 0.001

二、优化过程

方案一:在池化后添加标准化并加入relu激活函数

网络构建代码:

class Model(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 64),# 182528
            nn.Linear(64, 10)
        )

效果:

CIFAR-10基础优化一(加入标准化和激活函数)

测试集准确率,最高0.72

CIFAR-10基础优化一(加入标准化和激活函数)

测试集损失

CIFAR-10基础优化一(加入标准化和激活函数)

训练集损失

CIFAR-10基础优化一(加入标准化和激活函数)

第十次epoch时间

结论:训练集准确率在100次epoch时,已经达到了99.8%,但此时训练集的正确率已经出现了下降。十代时间为79.76s

在41次epoch时,测试集正确率达到了72%的最高值,但训练集仅有84.6%

分析:100次时出现了明显的过拟合现象,训练后大致稳定在0.7附近

方案二,在池化前添加标准化,加入relu激活函数

网络构建代码:

class Model(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2), nn.BatchNorm2d(32),
            nn.MaxPool2d(2),  nn.ReLU(),
            nn.Conv2d(32, 32, 5, padding=2), nn.BatchNorm2d(32),
            nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(32, 64, 5, padding=2), nn.BatchNorm2d(64),
            nn.MaxPool2d(2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 64),# 182528
            nn.Linear(64, 10)
        )

效果

CIFAR-10基础优化一(加入标准化和激活函数)

测试集准确率,最高73.1%

CIFAR-10基础优化一(加入标准化和激活函数)

测试集损失

CIFAR-10基础优化一(加入标准化和激活函数)

训练集准确率

CIFAR-10基础优化一(加入标准化和激活函数)

训练集损失

结论:由上可看出当把标准化BN层加在池化前面,则会有更好的表现效果,峰值73.1%,相较方案一多1%左右的准确率。测试集损失函数先降低后升高,十代epoch时间为80.8s,方案二为79.76s

分析:仍然出现了过拟合现象,随池化前就对数据进行标准化,但时间几乎没有明显增长

方案三、在方案二的基础上将relu改为prelu激活

网络构建代码

class Model(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2), nn.BatchNorm2d(32),
            nn.MaxPool2d(2), nn.PReLU(),
            nn.Conv2d(32, 32, 5, padding=2), nn.BatchNorm2d(32),
            nn.MaxPool2d(2), nn.PReLU(),
            nn.Conv2d(32, 64, 5, padding=2), nn.BatchNorm2d(64),
            nn.MaxPool2d(2), nn.PReLU(),
            nn.Flatten(),
            nn.Linear(1024, 64),# 182528
            nn.Linear(64, 10)
        )

效果:峰值71.6%,收敛慢于方法一和方法二。

结论:仍然出现过拟合现象,效果差于方法一和方法二

三、全代码(以方案二为例)

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time

#../input/cifar10-python
train_data = torchvision.datasets.CIFAR10("../dataset", train=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor())

train_dataloader = DataLoader(train_data, batch_size=64, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=64, drop_last=True)
print(len(train_dataloader)) #781
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_data_size = len(test_dataloader) * 64
train_data_size = len(train_dataloader) * 64
print(f'测试集大小为:{test_data_size}')
print(f'训练集大小为:{train_data_size}')
writer = SummaryWriter("../model_logs")

loss_fn = nn.CrossEntropyLoss(reduction='mean')
loss_fn = loss_fn.to(device)
time_able = False # True

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2), nn.BatchNorm2d(32),
            nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(32, 32, 5, padding=2), nn.BatchNorm2d(32),
            nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(32, 64, 5, padding=2), nn.BatchNorm2d(64),
            nn.MaxPool2d(2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 64),# 182528
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)

        return x

model = Model()
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
epoch = 100
running_loss = 0
total_train_step = 0
total_test_step = 0
if time_able:
    str_time = time.time()
for i in range(epoch):
    print(f'第{i + 1}次epoch')
    total_accuracy1 = 0
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        output = model(imgs)
        loss = loss_fn(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 200 == 0:
            if time_able:
                end_time = time.time()
                print(f'{end_time-str_time}')
            print(f'第{total_train_step}次训练,loss = {loss.item()}')
            writer.add_scalar("train_loss", loss.item(), total_train_step)
        accuracy1 = (output.argmax(1) == targets).sum()
        total_accuracy1 += accuracy1

    # 测试
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    total_test_loss = total_test_loss / test_data_size
    print(f'整体测试集上的loss = {total_test_loss}')
    print(f'整体测试集正确率 = {total_accuracy / test_data_size}')
    print(f'整体训练集正确率 = {total_accuracy1 / train_data_size}')
    writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    writer.add_scalar("train_accuracy", total_accuracy1 / train_data_size, total_test_step) # test_step == epoch
    total_test_step += 1

writer.close()

四、优化空间,一下文章更换为adamw优化器优化性能并处理过拟合、加网络层数

Original: https://blog.csdn.net/weixin_42037511/article/details/124099537
Author: 百炼成丹
Title: CIFAR-10基础优化一(加入标准化和激活函数)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/663156/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • mysql 锁机制与原理详解

    前言 不管是数据库,还是很多后端编程语言,都存在锁的机制,锁的存在有效解决了并发情况下对共同资源的抢占,保证了数据的稳定性和一致性,在mysql中,锁是如何工作的呢?其底层的工作原…

    人工智能 2023年7月30日
    066
  • python中df head_解决Python spyder显示不全df列和行的问题

    python中有的df列比较长head的时候会出现省略号,现在数据分析常用的就是基于anaconda的notebook和sypder,在spyder下head的时候就会比较明显的遇…

    人工智能 2023年7月9日
    075
  • SimpleITK使用——1. 进行Resample/Resize操作

    抵扣说明: 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。 Original: https:…

    人工智能 2023年5月26日
    062
  • Tensorflow简明教程

    一直以来没有系统学习过TF,为了更加高效的投入到近期的关系抽取比赛中,所以准备系统学习一下,系统学习的内容是李理老师的《Tensorflow简明教程》,地址为 http://fan…

    人工智能 2023年5月25日
    069
  • 【机器学习实验五】基于多分类线性SVM实现简易人机猜拳游戏

    文章目录 基于多分类线性SVM&mediapipe手势关键点实现简易人机猜拳游戏 * 基于SMO优化的SVM分类算法完整实现版本 SVM决策结果与数据集可视化 多分类SVM…

    人工智能 2023年7月2日
    0131
  • 数字图像处理 位平面切片/压缩

    一、位平面切片 1、概述 位平面切片是在图像处理中使用的众所周知的技术。在图像压缩中使用位平面切片。位平面切片是将图像转换为多级二值图像。 然后使用不同的算法压缩这些二进制图像。使…

    人工智能 2023年6月20日
    079
  • 部署农业知识图谱开源项目

    项目是上海市《农业信息服务平台及农业大数据综合利用研究》子课题《上海农业农村大数据共享服务平台建设和应用》的研究成果。 该课题是由上海市农业委员会信息中心主持,以”致富…

    人工智能 2023年6月1日
    099
  • 李航《统计学习方法》第2版 第6章课后习题答案

    ; 习题6.1 题目:确认逻辑斯谛分布属于指数分布族. 习题6.2 题目:写出逻辑斯谛回归模型学习的梯度下降算法. ; 习题6.3 题目:写出最大熵模型学习的DFP算法.(关于一般…

    人工智能 2023年6月18日
    076
  • 基于人机协作的无人集群搜索方法研究

    基于人机协作的无人集群搜索方法研究 人工智能技术与咨询 点击蓝字 关注我们 关键词: 无人集群 ; 人机协作 ; 动态规划 ; 多Agent系统 摘要: 人与机器人交互是当前一项研…

    人工智能 2023年6月1日
    097
  • 最大熵模型详解

    最大熵模型学习过程 前言 在将最大熵模型之前,先学习一下准备知识。 ①拉格朗日乘子法 ②贝叶斯定理 Bayes定理用来描述两个条件概率之间的关系。若计P(A)和P(B)分别表示事件…

    人工智能 2023年6月16日
    086
  • 数据处理笔记

    1.利用read_csv读取txt文档 例1:原始数据是txt文档,格式如下所示: 导入必备…

    人工智能 2023年7月7日
    089
  • 安卓隐藏摄像_一款可以隐藏录像的app

    一款可以隐藏录像的app是一款专业的屏幕录制软件,一款可以隐藏录像的app可以帮助用户方便又快捷的录制屏幕视频,一款可以隐藏录像的app还提供悬浮窗功能,能够自由控制软件开始录屏和…

    人工智能 2023年5月27日
    0122
  • pandas学习

    以下是我使用pansas时的纪录,没有过多的解释,只是自己的练习,如果想详细学习可以参考pandsd相关文档,想要快速了解的,我可以推荐一下另一位博主的介绍,很详细,可供参考学习 …

    人工智能 2023年7月8日
    074
  • 目标检测算法评价指标之mAP

    随着计算机技术的发展和计算机视觉原理的广泛应用,利用计算机图像处理技术对目标进行实时跟踪研究越来越热门,对目标进行动态实时跟踪定位在智能化交通系统、智能监控系统、军事目标检测及医学…

    人工智能 2023年7月12日
    071
  • 关于回归分析分类

    目的:当需要用一个数学表达式(模型)表示多个因素(原因)与另外一个因素(因素)之间关系时,可选用回归分析法。 应用:1)分析哪些自变量对因变量存在显著影响作用,R方值可以不要求大于…

    人工智能 2023年7月2日
    068
  • NLP发展大事记:顶会,预训练大模型,BERT系列

    文章目录 * – 1. NLP发展重要时间线 – + 时间线 – 2. NLP以BERT发展的延伸 – 3. NLP领域顶会 1. …

    人工智能 2023年5月31日
    0118
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球