PyTorch——自注意力(self-attention)机制实现(代码详解)

参考链接

  1. https://www.bilibili.com/video/BV1JE411g7XF?p=54
  2. https://arxiv.org/abs/1706.03762
  3. https://blog.csdn.net/qq_36653505/article/details/83375160

简述自注意力机制(self-attention)

self-attention可以视为一个特征提取层,给定输入特征a 1 , a 2 , ⋅ ⋅ ⋅ a n a^{1},a^{2},\cdot \cdot \cdot a^{n}a 1 ,a 2 ,⋅⋅⋅a n,经过self-attention layer,融合每个输入特征,得到新的特征b 1 , b 2 , ⋅ ⋅ ⋅ b n b^{1},b^{2},\cdot \cdot \cdot b^{n}b 1 ,b 2 ,⋅⋅⋅b n。具体如下:

设输入特征为I I I,分别将其乘以三个矩阵W q W^{q}W q、W k W^{k}W k和W v W^{v}W v得到Q Q Q(query)、K K K(key)和V V V(value)三个矩阵;接下来使用矩阵Q Q Q和K K K的乘积得到注意力矩阵A A A,归一化得到A ^ \hat{A}A ^;最后,将归一化后的注意力矩阵A ^ \hat{A}A ^乘上V V V,得到最后的输出特征O O O。

PyTorch——自注意力(self-attention)机制实现(代码详解)

; 多头自注意力机制(multi-head self-attention)

上述的self-attention中,每个输入特征a i a^{i}a i乘上矩阵W q W^{q}W q、W k W^{k}W k和W v W^{v}W v后,分别得到一个向量q i q^{i}q i、k i k^{i}k i和v i v^{i}v i,称为单头自注意力机制。如果将这些向量q i q^{i}q i、k i k^{i}k i和v i v^{i}v i分裂为n n n个就得到n n n头自注意力机制了。公认多头自注意力机制的效果好于单头的,因为前者可以捕获更多维度的信息。示意图如下:

PyTorch——自注意力(self-attention)机制实现(代码详解)

代码实现

设超参数num_attention_heads为自注意力机制的头数,如此,计算出每个头的维度attention_head_size。

self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = hidden_size

定义W q W^{q}W q、W k W^{k}W k和W v W^{v}W v三个矩阵。

self.query = nn.Linear(input_size, self.all_head_size)
self.key = nn.Linear(input_size, self.all_head_size)
self.value = nn.Linear(input_size, self.all_head_size)

下面开始逐步计算,需要主要的是计算过程中张量维度的变化。
将输入特征乘以三个矩阵W q W^{q}W q、W k W^{k}W k和W v W^{v}W v,输出的张量此时还没有区分出多个头。维度变化为:input_tensor ( b a t c h , n , i n p u t _ s i z e ) \left ( batch,n,input_size\right )(b a t c h ,n ,i n p u t _s i z e )到mixed_query_layer ( b a t c h , n , a l l _ h e a d _ s i z e ) \left ( batch,n,all_head_size\right )(b a t c h ,n ,a l l _h e a d _s i z e )

mixed_query_layer = self.query(input_tensor)
mixed_key_layer = self.key(input_tensor)
mixed_value_layer = self.value(input_tensor)

切分为num_attention_heads个头,并变换维度。维度变化为:mixed_query_layer ( b a t c h , n , a l l _ h e a d _ s i z e ) \left ( batch,n,all_head_size\right )(b a t c h ,n ,a l l _h e a d _s i z e )到query_layer ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,num_attention_heads,n,attention_head_size\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,a t t e n t i o n _h e a d _s i z e )

def transpose_for_scores(self, x):
   new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
   x = x.view(*new_x_shape)
   return x.permute(0, 2, 1, 3)

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)

矩阵Q Q Q和K K K相乘,得到注意力矩阵,并除以向量的维度的开方,防止注意力分数随维度增大而增大。维度变化为:query_layer ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,num_attention_heads,n,attention_head_size\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,a t t e n t i o n _h e a d _s i z e )到attention_scores ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , n ) \left ( batch,num_attention_heads,n,n\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,n )

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

attention_scores = attention_scores / math.sqrt(self.attention_head_size)

注意力矩阵归一化。维度变化为:attention_scores ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , n ) \left ( batch,num_attention_heads,n,n\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,n )到attention_probs ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , n ) \left ( batch,num_attention_heads,n,n\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,n )

attention_probs = nn.Softmax(dim=-1)(attention_scores)

将注意力矩阵乘以矩阵V V V。维度变化为:ttention_probs ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , n ) \left ( batch,num_attention_heads,n,n\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,n )乘以value_layer ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,num_attention_heads,n,attention_head_size\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,a t t e n t i o n _h e a d _s i z e )到context_layer ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,num_attention_heads,n,attention_head_size\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,a t t e n t i o n _h e a d _s i z e )。

context_layer = torch.matmul(attention_probs, value_layer)

变换context_layer维度,为了后面将各头得到的结果拼接。这里的contiguous()是将tensor的内存变成连续的,为后面的view()做准备。维度变化为:context_layer ( b a t c h , n u m _ a t t e n t i o n _ h e a d s , n , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,num_attention_heads,n,attention_head_size\right )(b a t c h ,n u m _a t t e n t i o n _h e a d s ,n ,a t t e n t i o n _h e a d _s i z e )到context_layer ( b a t c h , n , n u m _ a t t e n t i o n _ h e a d s , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,n,num_attention_heads,attention_head_size\right )(b a t c h ,n ,n u m _a t t e n t i o n _h e a d s ,a t t e n t i o n _h e a d _s i z e )

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

将各头的结果拼接起来。维度变化为:context_layer ( b a t c h , n , n u m _ a t t e n t i o n _ h e a d s , a t t e n t i o n _ h e a d _ s i z e ) \left ( batch,n,num_attention_heads,attention_head_size\right )(b a t c h ,n ,n u m _a t t e n t i o n _h e a d s ,a t t e n t i o n _h e a d _s i z e )到context_layer ( b a t c h , n , a l l _ h e a d _ s i z e ) \left ( batch,n,all_head_size\right )(b a t c h ,n ,a l l _h e a d _s i z e )

new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)

完整代码

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).

"""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

class SelfAttention(nn.Module):
    def __init__(self, num_attention_heads, input_size, hidden_size, hidden_dropout_prob):
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = hidden_size

        self.query = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.all_head_size)

        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_tensor):
        mixed_query_layer = self.query(input_tensor)
        mixed_key_layer = self.key(input_tensor)
        mixed_value_layer = self.value(input_tensor)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        attention_probs = self.attn_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states

Original: https://blog.csdn.net/beilizhang/article/details/115282604
Author: cqu_shuai
Title: PyTorch——自注意力(self-attention)机制实现(代码详解)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/627709/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球