PyTorch介绍-使用 TORCH.AUTOGRAD 自动微分

训练神经网络时,最常用的算法就是 反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。

为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎。它支持任意计算图的自动梯度计算。

一个最简单的单层神经网络,输入 x,参数 wb,某个损失函数。它可以用PyTorch这样定义:

import torch

x = torch.ones(5)      # input tensor
y = torch.zeros(3)     # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b    # 矩阵乘法
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

Tensors、Functions and Computational graph

上述代码定义了下面的 computational graph:

PyTorch介绍-使用 TORCH.AUTOGRAD 自动微分

在该网络中, wbparameters,是我们需要优化的。因此,我们需要能够计算损失函数关于这些变量的梯度。因此,我们设置了这些tensor的 requires_grad 属性。

注意:在创建tensor时可以设置 requires_grad 的值,或者创建之后使用 x.requires_grad_(True) 方法。

我们应用到tensor上构成计算图的function实际上是 Function 类的对象。该对象知道如何计算前向的函数,还有怎么计算反向传播步骤中函数的导数。反向传播函数存储在tensor的 grad_fn 属性中。You can find more information of Function in the documentation

print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)

输出:

Gradient function for z = <addbackward0 object at 0x7faea5ef7e10>
Gradient function for loss = <binarycrossentropywithlogitsbackward0 object at 0x7faea5ef7e10>
</binarycrossentropywithlogitsbackward0></addbackward0>

计算梯度

为了优化神经网络的参数权重,我们需要计算损失函数关于参数的导数,即,我们需要利用一些固定的 xy 计算(\frac{\partial loss}{\partial w})和(\frac{\partial loss}{\partial b})。为计算这些导数,可以调用 loss.backward(),然后从 w.gradb.grad

loss.backward()
print(w.grad)
print(b.grad)

输出:

tensor([[0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275],
        [0.0043, 0.2572, 0.3275]])
tensor([0.0043, 0.2572, 0.3275])

注意:

  • 我们只能在计算图中 requires_grad=True 的叶节点获得 grad 属性。对于其它节点,梯度是无效的。
  • 出于性能原因,我们只能对给定的graph使用 backward 执行梯度计算。如果需要在同一graph调用若干次 backward,在调用时,需要传入 retain_graph=True

禁用梯度跟踪

默认情况下,所有 requires_grad=True 的tensor都会跟踪它们的计算历史,并支持梯度计算。但是在一些情况下并不需要,例如,当我们已经训练了一个模型,并将其用在一些输入数据上,即,仅仅经过网络做前向运算。那么可以在我们的计算代码外包围 torch.no_grad() 块停止跟踪计算。

z = torch.matmul(x, w) + b
print(z.requires_grad())

with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad)

输出:

True
False

在tensor上使用 detach() 也能达到同样的效果

z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)

输出:

False

禁止梯度跟踪的几个原因:

  • 将神经网络的一些参数标记为 frozen parameters。这在finetuning a pretrained network中是非常常见的脚本。
  • 当你只做前向过程,用于 speed up computations,因为tensor计算而不跟踪梯度将会更有效。

More on Coputational Graphs

概念上,autograd在一个由Function对象组成的有向无环图(DAG)中保留了数据(tensors)记录,还有所有执行的操作(以及由此产生的新的tensors)。在DAG中,叶节点是输入tensor,根节点是输出tensors。通过从根到叶跟踪该图,可以使用链式法则自动地计算梯度。

在前向过程中,autograd同时进行两件事:

  • 运行请求的操作计算结果tensor
  • 在DAG中保存操作的梯度函数

当在DAG根部调用 .backward()时,后向过程就会开始。 autograd会:

  • 由每一个 .grad_fn 计算梯度。
  • 在对应tensor的 ‘.grad’ 属性累积梯度
  • 使用链式法则,一直传播到叶tensor

注意: DAGs在PyTorch是动态的,需要注意的一点是,graph是从头开始创建的;在每次调用 .backward() 之后,autograd开始生成一个新的graph。这允许你在模型中使用控制流语句;如果需要,你可以在每次迭代中改变shape,size,and operations。

选读:Tensor梯度和Jacobian Products

延伸阅读

Original: https://www.cnblogs.com/DeepRS/p/15743698.html
Author: Deep_RS
Title: PyTorch介绍-使用 TORCH.AUTOGRAD 自动微分

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/609932/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • CentOS 压缩解压

    打包:将多个文件合成一个总的文件,这个总的文件通常称为 “归档”。 压缩:将一个大文件通过某些压缩算法变成一个小文件。 1.1、tar 压缩格式: tar …

    Linux 2023年6月8日
    092
  • python入门基础知识四(字典与集合)

    dict_name = {key1:value1,key2,value2,…} 空字典:dict_name = {} or dict_name = dict() 字典的…

    Linux 2023年6月7日
    074
  • VMware服务关闭后一定要重启

    重要的事情说三遍:服务暂时关闭记得重启,服务暂时关闭记得重启,服务暂时关闭记得重启!!! VMware服务由于安装补丁的需要我暂时把服务关闭了,于是我遇到了尴尬的一幕,于是乎发现上…

    Linux 2023年6月7日
    0121
  • 上篇:Go函数的骚包玩法有哪些

    1. 用type关键字可以定义函数类型,函数类型变量可以作为函数的参数或返回值。 package main import "fmt" func add(a, b…

    Linux 2023年6月7日
    094
  • SSH免密登录

    SSH免密登录实现三步: 客户端生成公钥和私钥 上传公钥到服务端 SSH免密登录 (1) 客户端生成和公钥和私钥 ssh-keygen 一路回车即可,默认会在~/.ssh/目录下创…

    Linux 2023年6月7日
    099
  • Linux动静分离与Rewrite

    一、动静分离 1.1 单台机器动静分离 1、创建NFS挂载点(NFS服务端) mkdir /static vim /etc/exports /static 172.16.1.0/2…

    Linux 2023年6月14日
    088
  • 19-TCP、UDP的区别和应用场景

    可靠性TCP 提供交付保证,这意味着一个使用TCP协议发送的消息是保证交付给客户端的,如果消息在传输过程中丢失,那么它将重发。UDP是不可靠的,它不提供任何交付的保证,一个数据包在…

    Linux 2023年6月7日
    084
  • MySQL里的那些日志们

    该系列博文会告诉你如何从入门到进阶,从sql基本的使用方法,从MySQL执行引擎再到索引、事务等知识,一步步地学习MySQL相关技术的实现原理,更好地了解如何基于这些知识来优化sq…

    Linux 2023年6月14日
    0107
  • 【Leetcode】64. 最小路径和

    给定一个包含非负整数的 m&#xA0;x&#xA0;n网格 grid,请找出一条从左上角到右下角的路径,使得路径上的数字总和为最小。 说明:每次只能向下或者向右移动…

    Linux 2023年6月6日
    0103
  • LVM逻辑卷与磁盘配额

    一、LVM逻辑卷 1、LVM概述 LVM(Logical Volume Manager,逻辑卷管理)重点在于可以弹性地调整文件系统的容量,需要文件的读写性能或是数据的可靠性,LVM…

    Linux 2023年6月6日
    0106
  • oracle 怎么查看用户对应的表空间

    oracle 怎么查看用户对应的表空间? 查询用户: 查看数据库里面所有用户,前提是你是有 dba 权限的帐号,如 sys,system: select * from dba_us…

    Linux 2023年6月6日
    0114
  • SharePoint 2007 Full Text Searching PowerShell and CS file content with SharePoint Search

    Ensure your site or shared folder in one Content Source. Add file types. The second step i…

    Linux 2023年5月28日
    078
  • mysql字符串拼接

    Mysql数据库中的字符串 CONCAT()CONCAT_WS()GROUP_CONCAT() CONCAT() CONCAT(string1,string2)最常用的字符串拼接方…

    Linux 2023年6月6日
    080
  • Redis多线程原理详解

    从上图中可以看出只有以下3个地方用的是多线程,其他地方都是单线程: 1:接收请求参数 2:解析请求参数 3:请求响应,即将结果返回给client 很明显以上3点各个请求都是互相独立…

    Linux 2023年5月28日
    086
  • redis配置systemctl

    [Unit]Description=redisAfter=network.target [Service]Type=forkingPIDFile=/var/run/redis_63…

    Linux 2023年5月28日
    0110
  • DMA 与零拷贝技术

    原文链接:DMA 与零拷贝技术 注意事项:除了 Direct I/O,与磁盘相关的文件读写操作都有使用到 page cache 技术。 1. 数据的四次拷贝与四次上下文切换 很多应…

    Linux 2023年6月16日
    0137
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球