Pytorch的类(nn.Module的子类)中的forward函数

使用

直接通过类的实例对象就可以向类中的forward函数进行参数的传递(当然也可以通过调用forward函数进行传参)

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x

data1 = 1
data2 = 2
module = MyModule()
x1 = module(data1)          # 不需要显示调用forward函数就可以传递参数
x2 = module.forward(data2)
print(x1)
print(x2)

>> 1
>> 2

解释

nn.Module() 中包含了 __call__ 函数;

实现了 __call__ 函数的类,其类实例是一个可调用的对象,其可以简化对于类中某些方法的调用(写在 __call__ 中的方法),模糊了实例对象和类成员函数的区别。使用类实例 module() 时 就相当于 module.call(),如果在 call() 中写上函数,就可以直接通过类实例对象传参调用了。

而在 nn.Module() 中的 __call__ 函数中调用了 forward() 函数,

...

例子 #
def __call__(self, param):
    res = self.forward(param)
    return res
...

由于继承关系,对于MyModule(nn.Module) 类 同样具备了 __call__ 函数的功能,即可以通过类实例module 直接 调用 forward 并传参。

Original: https://www.cnblogs.com/jack-nie-23/p/16506630.html
Author: jacknie23
Title: Pytorch的类(nn.Module的子类)中的forward函数

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/582382/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球