PyTorch介绍-优化模型参数

既然已经有模型和数据了,是时候在数据上优化模型参数来训练、验证和测试它了。模型训练是一个迭代过程;在每一次迭代( epoch),模型会作出一个预测,计算其预测误差( loss),收集误差关于模型参数的导数(如前一节所述),并使用梯度 优化这些参数。关于这一过程的详细信息,可以观看backpropagation from 3Blue1Brown

先决代码

import torch
from torch import import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = Dataloader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512)
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

输出:

点击查看代码

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

超参数

我们为训练定义了以下超参数:

  • Epoch – 迭代数据集的次数
  • Batch Size – 更新参数前,通过网络传播的数据样本数
  • Learning Rate -每次 batch/epoch,更新模型的程度。较小的值会导致学习速度较慢,而较大的值可能会导致训练过程中不可预测的行为。
learning_rate = 1e-3
batch_size = 64
epochs = 5

优化循环

一旦我们设置好超参数,就可以通过一个optimization loop来训练和优化网络。每次optimization loop的迭代称为一个epoch。

每个epoch包含两部分:

  • 训练Loop – 迭代训练集,尝试收敛到最佳参数
  • 验证\测试Loop – 迭代测试集,检查模型性能是否提高。

Loss Function

给定一些数据,未经训练的网络可能不会给出正确答案。 Loss function衡量了所获结果和目标值的不同程度,训练时正是要最小化损失函数。为了计算loss我们使用给定样本对的输入作出预测,并与其真实标签做对比。

将模型输出的logist传入 nn.CrossEntropyLoss, 该函数将标准化logits并计算预测误差。

Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

优化器

优化是每次训练时调整模型参数,减少模型误差的过程。 优化算法定义了该过程是如何实现的(该例中我们使用了Stochastic Gradient Descent随机梯度下降)。所有的优化逻辑都被封装在了 optimizer 对象。在这里,我们使用SGD优化器;此外,在PyTorch中还有许多不同的优化器,例如ADAM和RMSProp,对不同类型的模型和数据都很有效。

我们通过注册需要训练的模型参数来初始化优化器,并传入学习率超参数。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练循环中,优化分为三个步骤:

  • 调用 optimizer.zero_grad()重置模型参数的梯度。默认情况下梯度相加,为防止重复计数,我们在每次迭代时显示地将它们归零。
  • 调用loss.backwards()反向传播预测误差。PyTorch计算loss关于每个参数的梯度。
  • 调用 optimizer.step(),通过在反向传播中得到的梯度调整参数。

完整实现

我们定义了 train_loop 循环迭代optimization代码, test_loop 评估模型在测试集上的性能。

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
      size = len(dataloader.dataset)
      num_batches = len(dataloader)
      test_loss, correct = 0, 0

      with torch.no_grad():
          for X, y in dataloader:
              pred = model(X)
              test_loss += loss_fn(pred, y).item()
              correct += (pred.argmax(1) == y).type(torch.float).sum().item()

      test_loss /= num_batches
      correct /= size
      print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

初始化损失函数和优化器,传入 train_looptest_loop。随意增加epoch,以跟踪模型不断改进的性能。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

输出:

点击查看代码

Epoch 1
loss: 2.124757  [    0/60000]
loss: 2.107859  [ 6400/60000]
loss: 2.045332  [12800/60000]
loss: 2.061512  [19200/60000]
loss: 2.002954  [25600/60000]
loss: 1.940844  [32000/60000]
loss: 1.962774  [38400/60000]
loss: 1.874285  [44800/60000]
loss: 1.875532  [51200/60000]
loss: 1.802694  [57600/60000]
Test Error:
 Accuracy: 58.7%, Avg loss: 1.794751

Epoch 3
loss: 1.499763  [    0/60000]
loss: 1.472005  [ 6400/60000]
loss: 1.319050  [12800/60000]
loss: 1.399100  [19200/60000]
loss: 1.283040  [25600/60000]
loss: 1.279892  [32000/60000]
loss: 1.300507  [38400/60000]
loss: 1.221794  [44800/60000]
loss: 1.262865  [51200/60000]
loss: 1.173478  [57600/60000]
Test Error:
 Accuracy: 63.9%, Avg loss: 1.193923

Epoch 5
loss: 1.114492  [    0/60000]
loss: 1.130664  [ 6400/60000]
loss: 0.944653  [12800/60000]
loss: 1.083935  [19200/60000]
loss: 0.961972  [25600/60000]
loss: 0.981254  [32000/60000]
loss: 1.033072  [38400/60000]
loss: 0.961604  [44800/60000]
loss: 1.007507  [51200/60000]
loss: 0.948494  [57600/60000]
Test Error:
 Accuracy: 66.0%, Avg loss: 0.956025

Epoch 7
loss: 0.926312  [    0/60000]
loss: 0.987333  [ 6400/60000]
loss: 0.768049  [12800/60000]
loss: 0.943189  [19200/60000]
loss: 0.831892  [25600/60000]
loss: 0.833098  [32000/60000]
loss: 0.916814  [38400/60000]
loss: 0.850216  [44800/60000]
loss: 0.887719  [51200/60000]
loss: 0.846100  [57600/60000]
Test Error:
 Accuracy: 68.5%, Avg loss: 0.844885

Epoch 9
loss: 0.814177  [    0/60000]
loss: 0.904296  [ 6400/60000]
loss: 0.667563  [12800/60000]
loss: 0.862825  [19200/60000]
loss: 0.764706  [25600/60000]
loss: 0.750034  [32000/60000]
loss: 0.848550  [38400/60000]
loss: 0.794559  [44800/60000]
loss: 0.821466  [51200/60000]
loss: 0.785530  [57600/60000]
Test Error:
 Accuracy: 70.9%, Avg loss: 0.780144

Done!

Original: https://www.cnblogs.com/DeepRS/p/15753763.html
Author: Deep_RS
Title: PyTorch介绍-优化模型参数

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/609934/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • shell 下载aliplayer 的视频

    #!/bin/bash url="http://v.example.com/8dedaec32ca9415eaa8ccd423ee33bf3/" #下载视频索引…

    Linux 2023年5月28日
    0104
  • DML

    用来对数据库中的表的数据进行增删改 添加数据 给指定列添加数据 insert into <表名> (&#x5217;&#x540D;1, &#x…

    Linux 2023年6月7日
    096
  • redis 订阅与发布

    Reference: https://redisbook.readthedocs.io/en/latest/feature/pubsub.html Redis 的 SUBSCRIB…

    Linux 2023年5月28日
    0112
  • Redis主从复制的配置和实现原理

    Redis的持久化功能在一定程度上保证了数据的安全性,即便是服务器宕机的情况下,也可以保证数据的丢失非常少。通常,为了避免服务的单点故障,会把数据复制到多个副本放在不同的服务器上,…

    Linux 2023年5月28日
    084
  • jarwarSpringBoot加载包内外资源的方式,告别FileNotFoundException吧

    工作中常常会用到文件加载,然后又经常忘记,印象不深,没有系统性研究过,从最初的war包项目到现在的springboot项目,从加载外部文件到加载自身jar包内文件,也发生了许多变化…

    Linux 2023年6月6日
    0108
  • 自动化集成:Pipeline整合Docker+K8S

    前言:该系列文章,围绕持续集成:Jenkins+Docker+K8S相关组件,实现自动化管理源码编译、打包、镜像构建、部署等操作; 本篇文章主要描述流水线集成K8S用法。 一、背景…

    Linux 2023年5月27日
    0169
  • docker 容器大小查看及清理docker磁盘空间

    这篇文章最初是由博主创作的。请注明转载的来源: [En] This article is originally created by the blogger. Please ind…

    Linux 2023年5月27日
    098
  • VMware 虚拟机图文安装和配置 Rocky Linux 8.5 教程

    前言这是《VMware 虚拟机图文安装和配置 AlmaLinux OS 8.6 教程》一文的姐妹篇教程,如果你需要阅读它,请点击这里。2020 年,CentOS 宣布:计划未来将重…

    Linux 2023年6月7日
    0224
  • linux应急响应具体操作

    第一件事情应该是切断网络,但是有些环境不允许网络断开,就只能跳过这一步。 1、查看历史命令 ​发现Linux 服务器被攻击,要做应急响应,登录主机后的第一件事,就是查看主机的历史命…

    Linux 2023年6月14日
    0100
  • SSM 集成 Freemarker 模板引擎

    在前后端分离的大趋势下,项目开发过程中,应尽量减少前端和后台的依赖和耦合,前端和后台尽可能采用 ajax 进行交互;但是全站 ajax,不利于网站 SEO,所以引入模板引擎,尽量减…

    Linux 2023年6月14日
    095
  • 分布式事务一站式解决方案与实现

    1 本地事务 1.1 事务的概述 事务指逻辑上的一组操作,组成这组操作的各个单元,要么全部成功,要么全部不成功。从而确保了数据的准确与安全。 1.2 事务的四大特性 原子性(Ato…

    Linux 2023年6月13日
    0140
  • 渣画质视频秒变清晰,“达芬奇”工具集帮你自动搞定

    https://www.msra.cn/zh-cn/news/features/davinci 2022-06-23 | 作者:微软亚洲研究院 编者按:是否时常”考古&…

    Linux 2023年6月13日
    0184
  • Java — 反射

    程序在运行中也可以获取类的变量和方法信息,并通过获取到的信息来创建对象。程序不必再编译期就完成确定,在运行期仍然可以扩展。 示例:学生类 public class Student …

    Linux 2023年6月8日
    0137
  • C盘空间怎么不够了?原来杀毒软件隔离区太大了

    打开小红伞的隔离区,文件全选,删除,居然一下多出20个G. C盘空间怎么不够了?原来杀毒软件隔离区太大了 麻了。打开小红伞的隔离区,文件全选,删除,一下多出20个G. 还有一个比较…

    Linux 2023年6月6日
    0115
  • Nginx基础入门篇(1)—优势及安装

    一、Nginx 的优势 1.1发展趋势: 2016年: 1.2、简介 Nginx (engine x) 是一个高性能的HTTP(解决C10k的问题)和反向代理服务器,也是一个IMA…

    Linux 2023年6月6日
    096
  • 数据结构和算法的关系

    针对Python数据结构与算法(裘宗燕版)中的第一章绪论最后的问题 数据结构 概念 数据与数据之间的结构关系(数组、队列、树、图等结构) 类别 分为 逻辑数据结构和 存储数据结构两…

    Linux 2023年6月7日
    0135
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球