PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别

问题描述:

提示:刚入门深度学习,记录一些犯下的小错误:

由于本周开始试图复现华为的CTR库以增加记忆,熟悉代码细节,没想到第一天看基础模块的时候就遇到了麻烦,在torch.utils类中,有如下获取损失函数的代码块:

def get_loss_fn(loss):
    if isinstance(loss, str):
        if loss in ["bce", "binary_crossentropy", "binary_cross_entropy"]:
            loss = "binary_cross_entropy"
    try:
        loss_fn = getattr(torch.functional.F, loss)
    except:
        try:
            from . import losses
            loss_fn = getattr(losses, loss)
        except:
            raise NotImplementedError("loss={} is not supported.".format(loss))
    return loss_fn

其中getattr()函数是用于返回一个对象属性值(Tip: class中的方法也是一种对象属性),因此可以看出第6行代码的作用就是返回 torch.functional.F这个类中的 loss函数,那么问题来了:上面代码片中的 torch.functional.F是哪个类呢,或者说是哪个模块呢?之前在学习PyTorch的过程中只接触过其中的:

import torch.nn.functional as F

那么这个 torch.functional.Ftorch.nn.functional有何区别?

解惑

因此抱着分辨清楚的目的,查看PyTorch官方文档,我发现只有 torch.nn.functional才有一系列的loss函数的实现,而输入关键词 torch.functional在搜索引擎上基本找不到相关的资料,返回的搜索结果都是与前者相关的文档。于是我决定去看源码弄清楚:

PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别
可以看到这两个模块显然是不同的模块!!!而后我打开torch.functional.py文件,出现了我无语的一幕,原来在torch.functional.py的第一行就是这么写的:

PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别
问题解决了, torch.functional .F指向的就是 torch.nn.functional,可能刚开始试图复现这个CTR库吧,实在搞不懂作者为什么不直接直接使用torch.nn.functional来指代?

END~

Original: https://blog.csdn.net/weixin_42234922/article/details/123937528
Author: kumare
Title: PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/697134/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球