pytorch 注意力机制

注意力机制:

父母在学校门口接送孩子的时候,可以在人群中一眼的发现自己的孩子,这就是一种注意力机制。
为什么父母可以在那么多的孩子中,找到自己的孩子?
比如现在有100个孩子,要被找的孩子 发型是平头,个子中等,不戴眼镜,穿着红色上衣,牛仔裤
通过对这些特征,就可以对这100个孩子进行筛选,最后剩下的孩子数量就很少了,就是这些特征的存在,使得父母的注意力会主要放在有这些特征的孩子身上,这就是注意力机制。

注意力机制
Query 被找孩子的特征
Key 100个孩子,通过特征进行筛选,得到这100个孩子的可能性
Value 100个孩子中,找到自己孩子的可能性

attention = softmax(Q、K之间进行计算) * V
Q、K之间的计算方式不同,这就导致了不同的注意力机制。

pytorch 注意力机制
最后一种就是Transformer中的一种注意力的计算机制。

; 实际应用中的理解

一般在自然语言处理应用里会把Attention模型看作是输出Target句子中某个单词和输入Source句子每个单词的对齐模型。
目标句子的每个单词 与输入句子中的每个单词 计算权重,计算注意力权重
类似于机器翻译中的短语对齐步骤

pytorch 注意力机制
可以看到里面的 Q K V
QK之间的计算就是计算QK之间的相关性,或者说特征的相似性
这样就可以得到每个key对应的value的权重系数,然后与V相乘
pytorch 注意力机制
Lx=||Source||代表Source的长度

计算过程

1.计算QK之前的相似度

pytorch 注意力机制
2.softmax 归一化
3.对value进行加权求和**

; 代码实现

第一步:根据注意力计算规则,对Q,K,V进行相应的计算.

第二步:根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积,一般是自注意力,Q与V相同,则不需要进行与Q的拼接.

第三步:最后为了使整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做一个线性变换,得到最终对Q的注意力表示

第一步就是使用第一种计算的方式,获取注意力机制的权重,就是上边所说的孩子的特征占100个孩子权重
第三部就是为了获得指定尺寸的输出

import torch
from torch import nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self,query_size, key_size, value_size1, value_size2, output_size):
        super(Attention, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        self.attn = nn.Linear(self.query_size + self.key_size, self.value_size1)

        self.attn_combine = nn.Linear(self.query_size + self.value_size2, self.output_size)

    def forward(self, Q, K, V):

        attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)

        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)

        output = torch.cat((Q[0], attn_applied[0]), 1)

        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights

query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attention(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1, 1, query_size)
print("---")
K = torch.randn(1, 1, key_size)
print(torch.cat((Q[0], K[0]), 1).shape)
V = torch.randn(1, value_size1, value_size2)
out = attn(Q, K, V)

Original: https://blog.csdn.net/qq_39753950/article/details/125791872
Author: yhbetter
Title: pytorch 注意力机制

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/709802/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球