PyTorch 介绍 | 使用 TORCH.AUTOGRAD 自动微分

训练神经网络时,最常用的算法就是 反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。

为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎。它支持任意计算图的自动梯度计算。

一个最简单的单层神经网络,输入 x,参数 wb,某个损失函数。它可以用PyTorch这样定义:

import torch

x = torch.ones(5)      # input tensor
y = torch.zeros(3)     # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b    # 矩阵乘法
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

Tensors、Functions and Computational graph

上述代码定义了下面的 computational graph:

PyTorch 介绍 | 使用 TORCH.AUTOGRAD 自动微分

在该网络中, wbparameters,是我们需要优化的。因此,我们需要能够计算损失函数关于这些变量的梯度。因此,我们设置了这些tensor的 requires_grad 属性。

注意:在创建tensor时可以设置 requires_grad 的值,或者创建之后使用 x.requires_grad_(True) 方法。

我们应用到tensor上构成计算图的function实际上是 Function 类的对象。该对象知道如何计算前向的函数,还有怎么计算反向传播步骤中函数的导数。反向传播函数存储在tensor的 grad_fn 属性中。You can find more information of Function in the documentation

print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)

输出:

Gradient function for z = <addbackward0 object at 0x7faea5ef7e10>
Gradient function for loss = <binarycrossentropywithlogitsbackward0 object at 0x7faea5ef7e10>
</binarycrossentropywithlogitsbackward0></addbackward0>

计算梯度

为了优化神经网络的参数权重,我们需要计算损失函数关于参数的导数,即,我们需要利用一些固定的 xy 计算(\frac{\partial loss}{\partial w})和(\frac{\partial loss}{\partial b})。为计算这些导数,可以调用 loss.backward(),然后从 w.gradb.grad

loss.backward()
print(w.grad)
print(b.grad)

输出:

tensor([[0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275]])
tensor([0.0043, 0.2572, 0.3275])

注意:

  • 我们只能在计算图中 requires_grad=True 的叶节点获得 grad 属性。对于其它节点,梯度是无效的。
  • 出于性能原因,我们只能对给定的graph使用 backward 执行梯度计算。如果需要在同一graph调用若干次 backward,在调用时,需要传入 retain_graph=True

禁用梯度跟踪

默认情况下,所有 requires_grad=True 的tensor都会跟踪它们的计算历史,并支持梯度计算。但是在一些情况下并不需要,例如,当我们已经训练了一个模型,并将其用在一些输入数据上,即,仅仅经过网络做前向运算。那么可以在我们的计算代码外包围 torch.no_grad() 块停止跟踪计算。

z = torch.matmul(x, w) + b
print(z.requires_grad())

with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad)

输出:

True
False

在tensor上使用 detach() 也能达到同样的效果

z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)

输出:

False

禁止梯度跟踪的几个原因:

  • 将神经网络的一些参数标记为 frozen parameters。这在finetuning a pretrained network中是非常常见的脚本。
  • 当你只做前向过程,用于 speed up computations,因为tensor计算而不跟踪梯度将会更有效。

More on Coputational Graphs

概念上,autograd在一个由Function对象组成的有向无环图(DAG)中保留了数据(tensors)记录,还有所有执行的操作(以及由此产生的新的tensors)。在DAG中,叶节点是输入tensor,根节点是输出tensors。通过从根到叶跟踪该图,可以使用链式法则自动地计算梯度。

在前向过程中,autograd同时进行两件事:

  • 运行请求的操作计算结果tensor
  • 在DAG中保存操作的梯度函数

当在DAG根部调用 .backward()时,后向过程就会开始。 autograd会:

  • 由每一个 .grad_fn 计算梯度。
  • 在对应tensor的 ‘.grad’ 属性累积梯度
  • 使用链式法则,一直传播到叶tensor

注意: DAGs在PyTorch是动态的,需要注意的一点是,graph是从头开始创建的;在每次调用 .backward() 之后,autograd开始生成一个新的graph。这允许你在模型中使用控制流语句;如果需要,你可以在每次迭代中改变shape,size,and operations。

选读:Tensor梯度和Jacobian Products

延伸阅读

Original: https://www.cnblogs.com/DeepRS/p/15743698.html
Author: Deep_RS
Title: PyTorch 介绍 | 使用 TORCH.AUTOGRAD 自动微分

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/714148/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • Tcpdump命令抓包详细分析【转】

    1 起因 前段时间,一直在调线上的一个问题:线上应用接受POST请求,请求body中的参数获取不全,存在丢失的状况。这个问题是偶发性的,大概发生的几率为5%-10%左右,这个概率已…

    技术杂谈 2023年5月31日
    0108
  • 力扣算法题1.两数之和(Java)

    力扣算法题1.两数之和(Java) 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标…

    技术杂谈 2023年7月25日
    085
  • 2022年rhce最新认证—(满分通过)

    RHCE认证 重要配置信息 在考试期间,除了您就坐位置的台式机之外,还将使用多个虚拟系统。您不具有台式机系统的 root 访问权,但具有对虚拟系统的完整 root 访问权。 系统信…

    技术杂谈 2023年6月21日
    097
  • 成大事者,必精读也!!!

    一:沉稳 (1)不要随便显露你的情绪。 (2)不要逢人就诉说你的困难和遭遇。 (3)在征询别人的意见之前,自己先思考,但不要先讲。 (4)不要一有机会就唠叨你的不满。 (5)重要的…

    技术杂谈 2023年7月23日
    090
  • cocos 动画系统

    前面的话 cocos 动画系统支持任意组件属性和用户自定义属性的驱动,再加上可任意编辑的时间曲线和移动轨迹编辑功能,就可以制作出各种动态效果 概述 Animation 组件可以以动…

    技术杂谈 2023年5月30日
    0116
  • Perl安装教程

    1.Perl下载地址 https://platform.activestate.com/tangxing806/ActivePerl-5.28/distributions 2.下载…

    技术杂谈 2023年5月31日
    0112
  • react新手demo——TodoList

    今天我们就使用 react 来实现一个简易版的 todolist ,我们可以使用这个 demo 进行 list 的增删改差,实际效果如上图所示。大家可以 clone下来查看:rea…

    技术杂谈 2023年5月31日
    0130
  • phpcms如何在前台文章列表前显示所属类别名称

    最近做单位网站模版遇到的问题,欲实现的效果: 但是phpcms中自带的文章列表标签没有这个功能,数据库中文章表中也只有类别id的字段,因此不能通过简单的{$r[catname]}读…

    技术杂谈 2023年7月11日
    099
  • C# Task的用法详解 z

    1、Task的优势ThreadPool相比Thread来说具备了很多优势,但是ThreadPool却又存在一些使用上的不方便。比如:◆ ThreadPool不支持线程的取消、完成、…

    技术杂谈 2023年6月1日
    099
  • Graphviz

    Graphviz 是一款由 AT&T Research 和 Lucent Bell 实验室开源的可视化图形工具,可以很方便的用来绘制结构化的图形网络,支持多种格式输出 使用…

    技术杂谈 2023年5月31日
    0121
  • 8 月份全球 Wi-Fi6 技术标准更新

    1.巴林 TRA 启用 Wi-Fi6 2022 年 8 月 17 日,巴林电信管理局 (TRA) 批准了 5470-5725 MHz 和 5925-6425 MHz 频段用于 Wi…

    技术杂谈 2023年6月21日
    076
  • IO流–创建文件夹,复制移动文件

    创建多级文件夹 final String ROOTPATH = "/Users/mac/Downloads"; // 默认文件下载的位置 @Test //创建多…

    技术杂谈 2023年7月24日
    083
  • 奇安信服务端一二面面经(来源牛客)

    一.一面 应用层——HTTP: ​ 当输入URL后,对URL进行解析。​ URL解析方式如下:​ https://www.baidu.com/​ https:代表访问数组的协议(h…

    技术杂谈 2023年7月11日
    0107
  • 【Qt+VS】Qt图标不显示|Qt程序运行时图标不显示

    4、配置属性-》自定义生成工具-》常规命令行:【”你自己的rcc.exe路径” -name “%(Filename)” -no-co…

    技术杂谈 2023年5月30日
    0115
  • 机器学习(6)K近邻算法

    k-近邻,通过离你最近的来判断你的类别 例子: 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近的样本中大多数属于某一类别),则该样本属于这个类别 K近邻需要做标准化…

    技术杂谈 2023年7月23日
    077
  • RTC 系统音视频传输弱网对抗技术

    qq: 517712484 wx: ldbgliet Original: https://www.cnblogs.com/lidabo/p/16501720.htmlAuthor:…

    技术杂谈 2023年5月31日
    0100
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球