nn.linear()函数

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFC(nn.Module):

    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, input):
        out = self.fc(input)
        return out

Net = LinearFC()
x = torch.randint(10, (2, 3)).float()
Net.train()
output = Net(x)
print(output)

创建了一个最简单的 LinearFC模型,里面有一个线性函数 nn.Linear(3, 2),线性变换公式为:y = x W T + b y=x W^T + b y =x W T +b。

通过Debug,一步一步查看运行情况:

nn.linear()函数

当前这一步可以看到模型给我们随机初始化了权重W 2 × 3 W_{2 \times 3}W 2 ×3 ​和偏置b 2 × 3 b_{2 \times 3}b 2 ×3 ​,为什么权重W W W的shape是2 × 3 2\times3 2 ×3,因为公式里需要转置。

x x x随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据。

nn.linear()函数
可以看出使用模型算出来的output,与手动使用公式算出来的结果一致。
nn.linear()函数

Net.train()的作用

当网络中有 dropout,Batch Normalization 的时候。训练的要记得 Net.train(), 测试 要记得 Net.eval()。

在训练模型时会在前面加上:

Net.train()

在测试模型时在前面使用:

model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

Original: https://blog.csdn.net/vincent_duan/article/details/119934349
Author: vincent_hahaha
Title: nn.linear()函数

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710626/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球