PYTORCH: 60分钟 | TORCH.AUTOGRAD

torch.autograd 是PyTorch的自动微分引擎,用以推动神经网络训练。在本节,你将会对autograd如何帮助神经网络训练的概念有所理解。

背景

神经网络(NNs)是在输入数据上执行的嵌套函数的集合。这些函数由参数(权重、偏置)定义,并在PyTorch中保存于tensors中。

训练NN需要两个步骤:

  • 前向传播:在前向传播中(forward prop),神经网络作出关于正确输出的最佳预测。它使输入数据经过每一个函数来作出预测。
  • 反向传播:在反向传播中(backprop),神经网络根据其预测中的误差来调整其参数,它通过从输出向后遍历,收集关于函数参数的误差的导数(梯度),并使用梯度下降优化参数。有关更多关于反向传播的细节,参见video from 3Blue1Brownvideo from 3Blue1Brown。

在PyTorch中的使用

让我们来看一下单个训练步骤。对于这个例子,我们从 torchvision 加载了一个预训练的resnet18模型。我们创建了一个随机数据tensor,用以表示一个3通道图片,其高和宽均为64,而其对应的 label 初始化为某一随机值。

import torch, torchvision
model = torchvision.models.resnet18(pretrained=True)
data = torch.rand(1, 3, 64, 64)
labels = torch.rand(1, 1000)

接下来,我们将数据输入模型,经过模型的每一层最后作出预测。这是 前向过程

prediction = model(data) # forward pass

我们使用模型的预测及其对应的标签计算误差( loss)。下一步是通过网络反向传播误差。当在误差tensor上调用 .backward()时,反向传播开始。然后,Autograd计算针对每一个模型参数的梯度,并将其保存在参数的 .grad 属性中。

loss = (prediction - labels).sum()
loss.backward() # backward pass

接下来,我们加载一个优化器,在此案例中是SGD,学习率是0.01,动量参数(momentum)是0.9。我们在优化器中注册所有的模型参数。

optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

最后,我们调用 .step()启动梯度下降。优化器会通过保存在 .grad 的参数梯度调整所有参数。

optim.step() # gradient descent

此时,你已拥有训练神经网络所需的一切。以下部分详细介绍了autograd的工作原理 – 可随意跳过。

Autograd中的微分

让我们来看一下 autograd是如何收集梯度的。创建两个tensor ab,并且 requires_grad=True。这向 autograd 发出信号,跟踪在它们上执行的每一个操作。

import torch
a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)

ab 创建tensor Q

[Q = 3a^2 – b^2 ]

Q = 3*a**2 - b**2

假设 ab 是一个神经网络的参数, Q 是误差。在NN训练中,求解关于参数的梯度,即:

[\frac{\partial Q}{\partial a} = 6a ]

[\frac{\partial Q}{\partial b} = -2b ]

当我们在 Q 上调用 .backward(),autograd计算以上梯度并保存在对应tensor的 .grad 属性中。
Q.backward() 是一个向量,因此我们需要在 Q.backward() 中显示地传递一个 gradient 参数。 gradient 是一个和 Q相同形状的tensor,它表示Q关于其本身的梯度,即:

[\frac{\partial Q}{\partial Q} = 1 ]

等效地,我们还可以将Q聚合为一个标量,并隐式的向后调用,如 Q.sum().backward()

external_grad = torch.tensor([1., 1.])
Q.backward(gradient=external_grad)

梯度现在被保存在 a.gradb.grad

## 检查收集的梯度是否正确
print(9*a**2 == a.grad)
print(-2*b == b.grad)

输出:

tensor([True, True])
tensor([Ture, True])

选读 – 使用 autograd 进行矢量微分

计算图

从概念上来说,autograd在一个由Function对象组成的有向无环图(DAG)中记录了数据(tensors)和所有执行的操作(连同由此产生的新tensors)。在DAG中,叶节点是输入tensors,根节点是输出tensors。通过从根节点到叶节点跟踪此图,你可以使用链式法则自动计算梯度。

在前向过程中,autograd同时进行两件事:

  • 执行请求的操作计算结果tensor,
  • 在DAG中保留操作的 gradient function

在DAG根节点处调用 .backward() 时启动反向过程。然后 autograd

  • 由每个 .grad_fn计算梯度,
  • 将梯度累积在其对应tensor的 .grad 属性中,
  • 使用链式法则,将梯度一直传播到叶节点。

下图是以上例子中DAG的可视化表示。在该图中,箭头表示前向过程的方向。节点表示在前向过程中每一个操作的backward functions。蓝色叶节点表示我们的tensor ab

PYTORCH: 60分钟 | TORCH.AUTOGRAD

注意:DAGs在PyTorch中是动态的。需要重点注意的是:DAG是从头开始重新创建的,在每次 .backward调用时,autograd开始填充一个新图。这正是在模型中允许你使用控制流语句的原因。如果需要,你可以在每次迭代中更改形状、大小和操作。

从DAG中排除

torch.autograd 跟踪所有 requires_grad=True 的tensor上的操作。对于不要求计算梯度的tensor, requires_grad=False,并将其从梯度计算DAG中排除。

当一个操作就算只有一个输入tensor有 requires_grad=True,其输出的tensor仍然要计算梯度。

x = torch.rand(5, 5)
y = torch.rand(5, 5)
z = torch.rand((5, 5), requires_grad=True)

a = x + y
print(f"Does 'a' require gradients? : {a.requires_grad}")
b = x + z
print(f"Does 'b' require gradients? : {b.requires_grad}")

输出:

Does a require gradients? : False
Does b require gradients? : True

在神经网络中,不计算梯度的参数通常成为冻结参数。如果你事先知道不需要这些参数的梯度,那冻结模型的一部分很有用(这通过减少autograd计算量提供了一些性能优势)。

从DAG中排除的另一个重要的常见用法是finetuning a pretrained network

在finetune中,我们冻结模型的大部分参数,并且通常只修改分类层以对新的标签作出预测。让我们通过一个小例子来演示这一点。像之前一样,我们加载一个预训练resnet18模型,并且冻结所有参数。

from torch import nn, optim

model = torchvision.models.resnet18(pretrained=True)

冻结网络中的所有参数
for param in model.parameters():
    param.requires_grad = False

假设我们要在一个10标签数据集上微调模型。在resnet中,分类层是最后的线性层 model.fc。我们可以简单地用一个新的线性层(默认情况下未冻结)替换它作为我们的分类器。

model.fc = nn.Linear(512, 10)

模型中除了 model.fc 的所有参数均被冻结。需要计算梯度的参数仅仅是 model.fc 的权重和偏置

仅优化分类层
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

注意,尽管我们在优化器中注册了所有参数,但是计算梯度(在梯度下降中更新)的参数仅是分类层的权重和偏置。

The same exclusionary functionality is available as a context manager in torch.no_grad().

Original: https://www.cnblogs.com/DeepRS/p/15715297.html
Author: Deep_RS
Title: PYTORCH: 60分钟 | TORCH.AUTOGRAD

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/621346/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • shell脚本调试方法

    set -x bash -x test.sh 作者:习惯沉淀 如果文中有误或对本文有不同的见解,欢迎在评论区留言。 如果觉得文章对你有帮助,请点击文章右下角【推荐】一下。您的鼓励是…

    Linux 2023年5月28日
    0114
  • [apue] linux 文件访问权限那些事儿

    说到 linux 上的文件权限,其实我们在说两个实体,一是文件,二是进程。一个进程能不能访问一个文件,其实由三部分内容决定: 下面先简单说明一下这些基本概念,最后再说明它们是如何相…

    Linux 2023年6月6日
    0123
  • WEB自动化-11-数据驱动

    11 数据驱动 数据驱动是测试框架中一个非常好的功能,使用数据驱动,可以在不增加代码量的情况下生成不同的测试策略。下面我们来看看在Cypress中的数据驱动使用方法。 11.1 数…

    Linux 2023年6月7日
    0136
  • Danskin’s Theorem

    Statement 1 假设 (\phi(x,z)) 为含有两个变量的连续函数: (\phi : \mathbb{R}^n \times Z \rightarrow \mathbb…

    Linux 2023年6月7日
    0106
  • git reset 命令删除本地文件怎么恢复

    执行 git reflog命令可以看到曾经执行过的操作,还有版本序号。 执行 git reset –hard HEAD@{【填那个序号】}就可以恢复本地删除的文件了! …

    Linux 2023年6月14日
    0130
  • 大数据之Hadoop的HDFS存储优化—异构存储(冷热数据分离)

    异构存储主要解决,不同的数据,储存在不同类型的硬盘中,达到最佳性能的问题 1)存储类型 RAM_DISK:内存镜像文件系统 SSD:SSD固态硬盘 DISK:普通磁盘,在HDFS中…

    Linux 2023年6月8日
    0114
  • logstash写入文件慢的问题排查记录

    终于找到根本原因了!!!!! logstash部署到k8s集群内部的,当所在节点的CPU资源被其他应用抢占时,logstash的处理速度就会降低 问题现象 logstash从kaf…

    Linux 2023年6月14日
    0190
  • 【Example】C++ Vector 内存预分配的良好习惯

    为什么要对 Vector 进行内存预分配? 1,Vector 本身是一个内存只会增长不会减小的容器。 2,Vector 存在 size 和 capacity 两种计数,size 即…

    Linux 2023年6月13日
    0130
  • Windows+VSCode编译在Linux-x86_64环境上运行的程序

    一、简介 本文主要介绍在Windows平台上使用VSCode,从而可以一键编译出运行在Linux-x86_64环境中的程序或库。 二、实现方式 ① 交叉编译 ② WSL(Windo…

    Linux 2023年6月7日
    0115
  • ASP.NET Core 2.2 : 二十二. 多样性的配置方式

    大多数应用都离不开配置,本章将介绍ASP.NET Core中常见的几种配置方式及系统内部实现的机制。(ASP.NET Core 系列目录) 说到配置,第一印象可能就是”…

    Linux 2023年6月7日
    0132
  • 三种移除list中的元素(可靠)

    /** * 直接使用foreach方法移除list中的元素会抛异常 * Exception in thread "main" java.util.Concurr…

    Linux 2023年6月7日
    0111
  • 聊聊消息中心的设计与实现逻辑

    厌烦被消息打扰,又怕突然间的安静; 一、业务背景 微服务的架构体系中,会存在很多基础服务,提供一些大部分服务都可能需要的能力,比如文件管理、MQ队列、缓存机制、消息中心等等,这些服…

    Linux 2023年6月14日
    0129
  • Linux高可用之Keepalived

    注意: 各节点时间必须同步 确保各节点的用于集群服务的接口支持MULTICAST通信(组播); 安装 从CentOS 6.4开始keepalived随系统base仓库提供,可以使用…

    Linux 2023年5月27日
    0155
  • VMware vSphere 7 Update 3 下载

    请访问原文链接:https://sysin.org/blog/vmware-vsphere-7-u3/,查看最新版。原创作品,转载请保留出处。 vSphere 7 Update 3…

    Linux 2023年5月27日
    0136
  • [云原生]Kubernetes-实战入门(第4章)

    一、Namespace 二、Pod 三、Label 四、Deployment 五、Service 参考: Kubernetes(K8S) 入门进阶实战完整教程,黑马程序员K8S全套…

    Linux 2023年6月13日
    0141
  • MTSP问题的简单介绍

    1. TSP问题与MTSP问题 1.1 TSP与MTSP问题的介绍: TSP:是指旅行家(1名)要旅行n个城市,要求各个城市经历且仅经历一次然后回到出发城市,并要求所走的 路程最短…

    Linux 2023年6月14日
    0252
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球