Pytorch:利用torch.nn.Modules.parameters修改模型参数

1. 关于parameters()方法

Pytorch中继承了 torch.nn.Module的模型类具有 named_parameters()/parameters()方法,这两个方法都会返回一个用于迭代模型参数的迭代器( named_parameters还包括参数名字):

import torch

net = torch.nn.LSTM(input_size=512, hidden_size=64)
print(net.parameters())
print(net.named_parameters())
#
#

我们可以将 net.parameters()迭代器和将 net.named_parameters()转化为列表类型,前者列表元素是模型参数,后者是包含参数名和模型参数的元组。

当然,我们更多的是对迭代器直接进行迭代:

for param in net.parameters():
    print(param.shape)
torch.Size([256, 512])
torch.Size([256, 64])
torch.Size([256])
torch.Size([256])
for name, param in net.named_parameters():
    print(name, param.shape)
weight_ih_l0 torch.Size([256, 512])
weight_hh_l0 torch.Size([256, 64])
bias_ih_l0 torch.Size([256])
bias_hh_l0 torch.Size([256])

我们知道,Pytorch在进行优化时需要给优化器传入这个参数迭代器,如:

from torch.optim import RMSprop
optimizer = RMSprop(net.parameters(), lr=0.01)

2. 关于参数修改

那么底层具体是怎么对参数进行修改的呢?

我们在博客《Python对象模型与序列迭代陷阱》中介绍过,Python序列中本身存放的就是对象的引用,而迭代器返回的是序列中的对象的二次引用,如果序列的引用指向基础数据类型,则是不可以通过遍历序列进行修改的,如:

my_list = [1, 2, 3, 4]
for x in my_list:
    x += 1
print(my_list) #[1, 2, 3, 4]

而序列中的引用指向复合数据类型,则可以通过遍历序列来完成修改操作,如:

my_list = [[1, 2],[3, 4]]
for sub_list in my_list:
    sub_list[0] += 1
print(my_list)
[1, 2, 3, 4]
[[2, 2], [4, 4]]

具体原理可参照该篇博客,此处我就不在赘述。这里想提到的是,用 net.parameters()/net.named_parameters()来迭代并修改参数,本质上就是上述第二种对复合数据类型序列的修改。我们可以如下写:

for param in net.parameters():
    with torch.no_grad():
        param += 1

with torch.no_grad():表示将将所要修改的张量关闭梯度计算。所增加的1会广播到 param张量的中的每一个元素上。上述操作本质上为:

for param in net.parameters():
    with torch.no_grad():
        param += torch.ones(param.shape)

但是需要注意,如果我们想让参数全部置为0,切不可像下列这样写:

for param in net.parameters():
    with torch.no_grad():
        param = torch.zeros(param.shape)

param是二次引用, param=0操作再语义上会被解释为让 param这个二次引用去指向新的全0张量对象,但是对参数张量本身并不会产生任何变动。该操作实际上类似下列这种操作:

list_1 = [1, 2]
list_2 = list_1
list_2 = [0, 0]
print(list_1) # [1, 2]

修改二次引用 list_2自然不会影响到 list_1引用的对象。

下面让我们纠正这种错误,采用下列方法直接来将参数张量中的所有数值置0:

for param in net.parameters():
    with torch.no_grad():
        param[:] = 0 #张量类型自带广播操作,等效于param[:] = torch.zeros(param.shape)

这时语义上就类似

list_1 = [1, 2]
list_2 = list_1
list_2[:] = [0, 0]
print(list_1) # [0, 0]

自然就能完成修改的操作了。

Original: https://www.cnblogs.com/orion-orion/p/16293822.html
Author: orion-orion
Title: Pytorch:利用torch.nn.Modules.parameters修改模型参数

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/805160/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • Scrapy_redis框架的概念作用和流程

    当爬取的网站的数据量非常庞大的时候,再使用之前的Scrapy框架速度就会有点偏慢,这时可以使用分布式来快速的爬取大量的数据。 1. 分布式是什么 分布式就是不同的节点(服务器,ip…

    Python 2023年10月5日
    033
  • 小公司的应用服务部署历程

    先声明一下:我所在的公司是一个小团队,做物联网相关的,前后端、硬件、测试加起来也就五六十个人左右;本人的岗位是Java开发(兼DBA、运维);我进公司时整个项目的部署架构为 简单j…

    Python 2023年10月14日
    028
  • 深度强化学习-Dueling DQN算法原理与代码

    Dueling Deep Q Network(Dueling DQN)是对DQN算法的改进,有效提升了算法的性能。如果对DQN算法还不太了解的话,可以参考我的这篇博文:深度强化学习…

    Python 2023年10月8日
    050
  • 【python 游戏】闲的无聊?那就和博主一起来滑雪吧~

    前言 滑雪运动(特别是现代竞技滑雪)发展到当今,项目不断在增多,领域不断在扩展。 世界比赛正规的大项目分为:高山滑雪、北欧滑雪(Nordic Skiing,越野滑雪、跳台滑雪)、自…

    Python 2023年9月22日
    043
  • 《奇迹笨小孩》、《误杀》观后感

    前天晚上看的《奇迹笨小孩》,一部以讲述故事为主要线索,观看时情绪比较平缓,但是主角表现出了当代年轻人要有敢做,敢拼的劲。 印象深刻的情节简述:(主角我就以”他&#822…

    Python 2023年6月10日
    065
  • 使用vscode编辑markdown文件(可粘贴截图)

    使用markdown粘贴截图时,操作步骤比较多: 1)截取图片; 2)将图片存在特定位置; 3)记住图片路径,在markdown文件中编写代码; 4)预览效果; 而word之类的文…

    Python 2023年6月12日
    078
  • python命令行安装包

    1、单个包安装 pip指定软件源安装命令格式:pip install -i [ source_url ] [ package_name ] source_url:是软件源地址 pa…

    Python 2023年9月18日
    049
  • nn.BatchNorm讲解,nn.BatchNorm1d, nn.BatchNorm2d代码演示

    1 nn.BatchNorm BatchNorm是深度网络中经常用到的加速神经网络训练,加速收敛速度及稳定性的算法,是深度网络训练必不可少的一部分,几乎成为标配; BatchNor…

    Python 2023年8月1日
    0133
  • 1W+字概括精髓,Pandas中必知必会50例

    本篇我们继续前面 pandas系列教程的探讨,今天我会介绍 pandas库当中一些非常实用的方法与函数,希望大家看了之后会有所收获, 喜欢本文点赞支持,欢迎收…

    Python 2023年8月21日
    054
  • matplotlib

    本篇博客将从易到难的阐述matplotlib的各种用法。首先通过举一个例子,对matplotlib的用法有一个大致的了解。 import matplotlib.pyplot as …

    Python 2023年9月5日
    041
  • pygame音乐相关的功能实现

    pygame.mixer.music.load() —— 载入一个音乐文件用于播放pygame.mixer.music.play() —— 开始播放音乐流pygame.mixer….

    Python 2023年9月19日
    037
  • pip常用命令

    文章目录 一、pip是什么? 二、pip常见命令 * 1.升级 2.安装和卸载 3.查看 4.requirement相关 5.使用wheel文件安装库 三、pip换源 * 1.临时…

    Python 2023年8月2日
    067
  • selenium之使用POM模式设计PO类,将POM模式运用到项目中

    终极目标就是运用DDT思想+POM思想+pytest框架来最终实现项目但是这样虽然松耦合了但是pom思想多维护了一个类,视情况而定,可以不用pom思想 最基本的逻辑就是:test_…

    Python 2023年9月11日
    058
  • 教你用canvas打造一个炫酷的碎片切图效果

    前言 今天分享一个炫酷的碎片式切图效果,这个其实在自己的之前的博客上有实现过,本人觉得这个效果还是挺炫酷的,这次还是用我们的canvas来实现,代码量不多,但有些地方还是需要花点时…

    Python 2023年10月18日
    069
  • docker 容器中 ip addr 出现 bash: ip: commandnot found

    [root@linux-local /] bash: ip: command not found 问题出现的原因:**我们下载的某个镜像(例如Nginx镜像)是精简版的,使用此镜像…

    Python 2023年11月8日
    045
  • python 使用pandas 读写excel文件

    现在本地创建一个excel表,以及两个sheet,具体数据如下: sheet1: sheet2: 读取excel文件 pandas.read_excel(io, sheet_nam…

    Python 2023年8月1日
    058
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球