04 Transformer 中的位置编码的 Pytorch 实现

1:10 点赞

16:00

04 Transformer 中的位置编码的 Pytorch 实现

我爱你

你爱我

1401

04 Transformer 中的位置编码的 Pytorch 实现
class PositionalEncoding(nn.Module):

    def __init__(self, dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()

        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))

"""
        构建位置编码pe
        pe公式为:
        PE(pos,2i/2i+1) = sin/cos(pos/10000^{2i/d_{model}})
"""
        pe = torch.zeros(max_len, dim)  # max_len 是解码器生成句子的最长的长度,假设是 10
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))

        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)
        self.drop_out = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):

        emb = emb * math.sqrt(self.dim)

        if step is None:
            emb = emb + self.pe[:emb.size(0)]
        else:
            emb = emb + self.pe[step]
        emb = self.drop_out(emb)
        return emb

Original: https://www.cnblogs.com/nickchen121/p/16529997.html
Author: 二十三岁的有德
Title: 04 Transformer 中的位置编码的 Pytorch 实现

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/566927/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球