PyTorch 介绍 | 使用 TORCH.AUTOGRAD 自动微分

训练神经网络时,最常用的算法就是 反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。

为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎。它支持任意计算图的自动梯度计算。

一个最简单的单层神经网络,输入 x,参数 wb,某个损失函数。它可以用PyTorch这样定义:

import torch

x = torch.ones(5)      # input tensor
y = torch.zeros(3)     # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b    # 矩阵乘法
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

Tensors、Functions and Computational graph

上述代码定义了下面的 computational graph:

PyTorch 介绍 | 使用 TORCH.AUTOGRAD 自动微分

在该网络中, wbparameters,是我们需要优化的。因此,我们需要能够计算损失函数关于这些变量的梯度。因此,我们设置了这些tensor的 requires_grad 属性。

注意:在创建tensor时可以设置 requires_grad 的值,或者创建之后使用 x.requires_grad_(True) 方法。

我们应用到tensor上构成计算图的function实际上是 Function 类的对象。该对象知道如何计算前向的函数,还有怎么计算反向传播步骤中函数的导数。反向传播函数存储在tensor的 grad_fn 属性中。You can find more information of Function in the documentation

print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)

输出:

Gradient function for z = <addbackward0 object at 0x7faea5ef7e10>
Gradient function for loss = <binarycrossentropywithlogitsbackward0 object at 0x7faea5ef7e10>
</binarycrossentropywithlogitsbackward0></addbackward0>

计算梯度

为了优化神经网络的参数权重,我们需要计算损失函数关于参数的导数,即,我们需要利用一些固定的 xy 计算(\frac{\partial loss}{\partial w})和(\frac{\partial loss}{\partial b})。为计算这些导数,可以调用 loss.backward(),然后从 w.gradb.grad

loss.backward()
print(w.grad)
print(b.grad)

输出:

tensor([[0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275]])
tensor([0.0043, 0.2572, 0.3275])

注意:

  • 我们只能在计算图中 requires_grad=True 的叶节点获得 grad 属性。对于其它节点,梯度是无效的。
  • 出于性能原因,我们只能对给定的graph使用 backward 执行梯度计算。如果需要在同一graph调用若干次 backward,在调用时,需要传入 retain_graph=True

禁用梯度跟踪

默认情况下,所有 requires_grad=True 的tensor都会跟踪它们的计算历史,并支持梯度计算。但是在一些情况下并不需要,例如,当我们已经训练了一个模型,并将其用在一些输入数据上,即,仅仅经过网络做前向运算。那么可以在我们的计算代码外包围 torch.no_grad() 块停止跟踪计算。

z = torch.matmul(x, w) + b
print(z.requires_grad())

with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad)

输出:

True
False

在tensor上使用 detach() 也能达到同样的效果

z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)

输出:

False

禁止梯度跟踪的几个原因:

  • 将神经网络的一些参数标记为 frozen parameters。这在finetuning a pretrained network中是非常常见的脚本。
  • 当你只做前向过程,用于 speed up computations,因为tensor计算而不跟踪梯度将会更有效。

More on Coputational Graphs

概念上,autograd在一个由Function对象组成的有向无环图(DAG)中保留了数据(tensors)记录,还有所有执行的操作(以及由此产生的新的tensors)。在DAG中,叶节点是输入tensor,根节点是输出tensors。通过从根到叶跟踪该图,可以使用链式法则自动地计算梯度。

在前向过程中,autograd同时进行两件事:

  • 运行请求的操作计算结果tensor
  • 在DAG中保存操作的梯度函数

当在DAG根部调用 .backward()时,后向过程就会开始。 autograd会:

  • 由每一个 .grad_fn 计算梯度。
  • 在对应tensor的 ‘.grad’ 属性累积梯度
  • 使用链式法则,一直传播到叶tensor

注意: DAGs在PyTorch是动态的,需要注意的一点是,graph是从头开始创建的;在每次调用 .backward() 之后,autograd开始生成一个新的graph。这允许你在模型中使用控制流语句;如果需要,你可以在每次迭代中改变shape,size,and operations。

选读:Tensor梯度和Jacobian Products

延伸阅读

Original: https://www.cnblogs.com/DeepRS/p/15743698.html
Author: Deep_RS
Title: PyTorch 介绍 | 使用 TORCH.AUTOGRAD 自动微分

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/714148/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • 4.多元线性回归

    线性模型假定预测(\hat{y})是对应(x={x_1,x_2,\cdots,x_p})的属性的线性组合,即: [\begin{align} \hat{y} &=\thet…

    技术杂谈 2023年7月10日
    044
  • window Tomcat安装教程

    404. 抱歉,您访问的资源不存在。 可能是网址有误,或者对应的内容被删除,或者处于私有状态。 代码改变世界,联系邮箱 contact@cnblogs.com 园子的商业化努力-困…

    技术杂谈 2023年5月31日
    094
  • KL散度(距离)和JS散度(距离)zz

    两者都可以用来衡量两个概率分布之间的差异性。JS散度是KL散度的一种变体形式。 KL散度:也称相对熵、KL距离。对于两个概率分布P和Q之间的差异性(也可以简单理解成相似性),二者越…

    技术杂谈 2023年5月31日
    094
  • NO.1通讯录管理系统+源代码(C++)

    功能描述:显示简单的菜单,供用户选择操作 实现步骤:直接cout输出 功能描述:根据用户不同的操作代码选择,进入不同的功能,我们使用switch分支结构进行搭建 实现步骤:用whi…

    技术杂谈 2023年7月24日
    067
  • 初学者必犯的10个Python错误

    前言 当我们开始学习Python时,我们会养成一些不良编码习惯,而更可怕的是我们连自己也不知道。 我们学习变成的过程中,大概有会这样的经历: 写的代码只能完成了一次工作,但后来再执…

    技术杂谈 2023年6月21日
    095
  • InnoDB中不同SQL语句设置的锁

    锁定读、UPDATE 或 DELETE 通常会给在SQL语句处理过程扫描到的每个索引记录上设置记录锁。语句中是否存在排除该行的WHERE条件并不重要。InnoDB不记得确切的WHE…

    技术杂谈 2023年7月24日
    058
  • 迁移学习

    古语有言:”它山之石可以攻玉”,迁移学习就是这么一种思想,将在其他训练集上训练好的神经网络迁移到目标任务上。自打迁移学习的思想提出后,在工业实践上,就很少有…

    技术杂谈 2023年7月23日
    066
  • 编程技巧│浏览器 Notification 桌面推送通知

    404. 抱歉,您访问的资源不存在。 可能是网址有误,或者对应的内容被删除,或者处于私有状态。 代码改变世界,联系邮箱 contact@cnblogs.com 园子的商业化努力-困…

    技术杂谈 2023年7月11日
    071
  • SLF4J 日志门面

    SLF4J( Simple Logging Facade For Java),即 简单日志门面。主要是为了给 Java 日志访问提供一套标准、规范的 API 框架,其主要意义在于提…

    技术杂谈 2023年7月11日
    076
  • Laravel项目中使用GroupBy时报错

    今天用Laravel做一个新的项目,GroupBy一个字段内容为中文时候,一直报错。 $list = ApCategories::where(‘site_code’, ‘MY’) …

    技术杂谈 2023年7月11日
    062
  • 异步、邮件、定时任务

    异步、邮件、定时任务 14.1 异步任务 编写一个业务测试类 文件路径:com–dzj–service–AsynService.java @Se…

    技术杂谈 2023年6月21日
    0104
  • Filter拦截器从入门到快速上手与Listener监听器概述

    前置内容: 会话跟踪技术 1、 过滤器Filter 1.1 Filter快速入门 1.2 Filter执行流程 1.3 Filter使用细节 1.4 案例 2、 监听器Listen…

    技术杂谈 2023年7月24日
    075
  • Mybatis缓存机制

    MyBatis是常见的 Java数据库访问层框架。在日常工作中,多数情况下是使用 MyBatis的默认缓存配置减轻数据库压力,提高数据库性能,但是 MyBatis缓存机制有一些不足…

    技术杂谈 2023年7月24日
    081
  • LeetCode27.移除元素

    给你一个数组nums和一个值val,你需要 原地 移除所有数值等于val的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用O(1)额外空间并 原地 修改输入数…

    技术杂谈 2023年7月24日
    074
  • Go异步check简单示例

    异步check代码: 测试: Original: https://www.cnblogs.com/-wenli/p/14737981.htmlAuthor: stdTitle: G…

    技术杂谈 2023年5月31日
    087
  • java可变参数

    可变参数 用法: public void test(int… i){} //类型后边加… 本质是数组参考文档: 方法中有多个参数是,可变参数必须放在最后 例…

    技术杂谈 2023年7月11日
    045
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球