逐句解析点积注意力pytorch源码(配图解)

前言

结合pytorch源码和原始论文学习Scaled Dot-Product Attention的原理。

原论文链接:Attention Is All You Need
原论文中的网络结构如下图所示。

逐句解析点积注意力pytorch源码(配图解)

; 计算公式

逐句解析点积注意力pytorch源码(配图解)

源码(pytorch)


def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,) -> Tuple[Tensor, Tensor]:

    B, Nt, E = q.shape
    q = q / math.sqrt(E)

    attn = torch.bmm(q, k.transpose(-2, -1))

    if attn_mask is not None:
        attn += attn_mask

    attn = softmax(attn, dim=-1)

    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)

    output = torch.bmm(attn, v)
    return output, attn

下面逐句解读。

第一步:实现scaled

q = q / math.sqrt(E)

其中的 math.sqrt(E)对应计算公式中的d k \sqrt{d_k}d k ​​。

根据源码,E(也就是d k d_k d k ​)代表Q、K、V三个张量中特征(embedding)的维度。

关于Q、K、V的形状的代码注释。

  • q: :math: (B, Nt, E) where B is batch size, Nt is the target sequence length,
    and E is embedding dimension.

  • key: :math: (B, Ns, E) where B is batch size, Ns is the source sequence length,
    and E is embedding dimension.

  • value: :math: (B, Ns, E) where B is batch size, Ns is the source sequence length,
    and E is embedding dimension.

在论文中对Q、K、V维度的说明:

The input consists of queries and keys of dimension d k d_k d k ​ , and values of dimension d v d_v d v ​ .

为什么要除以 d k \sqrt{d_k}d k ​​ ?

原论文中的解释如下:

While for small values of d k d_k d k ​ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of d k d_k d k ​ . We suspect that for large values of d k d_k d k ​, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients . To counteract this effect, we scale the dot products by 1 d k \frac{1}{\sqrt{d_k}}d k ​​1 ​

简单解释就是:当 d k d_k d k ​ 较大时(也就是Q和K的维度较大时),dot-product attention的效果就比加性注意力差。作者推测,对于较大的d k d_k d k ​值,点积(Q和K的转置的点积)的增长幅度很大,进入到了softmax函数梯度非常小的区域。

李沐老师在Transformer论文逐段精读【论文精读】中对这部分的解读摘录如下:

当你的dk不是很大的时候,除不除都没关系。
但是当dk很大的时候,也就是说两个向量比较长的时候,那么你做点积的时候,这些值呢,就可能会比较大,但也可能是比较小。
当你的值相对来说比较大的时候呢,你之间的相对的差距就会变大,就导致说,你值最大的那个值做出来softmax就会更加靠近1。剩下那些值呢就会更加靠近0。就是你的值更加向两端靠拢。当你出现这个样子的时候,你算梯度的时候,你发现梯度比较小。
因为softmax最后的结果是什么?就是我希望我的预测值啊,置信的地方尽量靠近1,不置信的地方尽量靠近0。这样子我的收敛就差不多了。这时候你的梯度就会变得比较小,那你就会跑不动。

softmax公式
s o f t m a x ( x i ) = e x i ∑ j = 1 K e x j f o r i = 1 , 2 , … , K softmax(x_i) = \frac{e^{x_{i}}}{\sum_{j=1}^K e^{x_{j}}} \ \ \ for\ i=1,2,\dots,K s o f t m a x (x i ​)=∑j =1 K ​e x j ​e x i ​​f o r i =1 ,2 ,…,K

更详细计算建议参考https://blog.csdn.net/qq_37430422/article/details/105042303

为什么先除以 d k \sqrt{d_k}d k ​​ ?

为什么不按照公式,先计算Q K T QK^T Q K T矩阵乘法,再除以d k \sqrt{d_k}d k ​​呢?
从数值计算的角度考虑,要尽量控制数值的大小,这样可以保持浮点数的精度。
也就是说,先计算除法,后面计算矩阵乘法的时候,误差就更小。

第二步:通过点积(dot product)计算注意力分数


    attn = torch.bmm(q, k.transpose(-2, -1))

计算q和k T k^T k T的乘积。bmm是批量矩阵乘法。
其中 k.transpose(-2, -1)表示交换k的倒数第1和倒数第2维度,保持批量维度不变。

第三步:mask

    if attn_mask is not None:
        attn += attn_mask

在解码阶段,要限制未来的数据影响,只保留当前时刻之前的数据,所以一般加上一个负无穷大的数,这样后面计算softmax的时候,相应的结果就是0,起到mask的效果。

第四步:计算注意力权重

attn = softmax(attn, dim=-1)

经过softmax之后,attn就称为注意力权重了,因为归一化了。

其中的dim=-1,表示对最后一个维度进行softmax,也就是词向量维度。

第五步:dropout

    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)

注:Transformer原始论文中的 dropout_p=0.1

第六步:加权平均


    output = torch.bmm(attn, v)
    return output, attn

V和注意力权重相乘,并返回结果。

图解QKV矩阵乘法

注意:这里不再使用源码中矩阵shape的符号。

逐句解析点积注意力pytorch源码(配图解)

Q是n乘以dk的矩阵,代表多个Querys,理解为每行一个query向量。
K是m乘以dk的矩阵,代表多个Keys,理解为每行一个key向量。
Q和K的转置相乘,得到的矩阵中每一行(橙色的行)代表某一个query向量和所有key向量之间的相似度,也就是一个query向量和所有key向量点积。
灰色矩阵实际上就是相关度矩阵,注意力分数矩阵。

V是m乘以dv的矩阵,代表多个Values,每行代表一个value向量。
橙色向量左乘V,代表V的行向量按照橙色向量的权重进行线性组合,得到蓝色向量。

Original: https://blog.csdn.net/m0_48742971/article/details/123392039
Author: coder1479
Title: 逐句解析点积注意力pytorch源码(配图解)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710179/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球