5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现

1、seq2seq attention的论文出处

主要是阅读经典论文《Neural Machine Translation by Jointly Learning to Align and Translate》https://arxiv.org/pdf/1409.0473.pdf

这篇论文中机器翻译采用seq2seq的encoder-decoder模型构建,将输入句子encoding成一个固定长度的向量,然后输入到decoder解码生成译文。

attention机制引入要解决的问题

encoder得到的固定向量直接送入decoder,不利于

(1)较长句子的信息传递 (2)提取decoder中目标单词需要关注原文中的语义信息

2、seq2seq decoder的构成

直接截取一下论文中的讲法吧,如下图:

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现

式子(4)的意思就是解码器生成出来的每个token,由

1、

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现就是前一个token的特征向量;

2、

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现 就是当前的隐状态,这里的5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现 也是通过RNN得到的,可以把5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现看成这一步RNN的输入特征向量;

3、

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现 就是上下文的特征,没有用attention之前就是encoder得到的固定向量。现在用了attention机制这篇文章中经典的attention机制就是(5)式得到的。

关于

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现 的计算中,简单来说就是encoder中各层隐状态5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现 与decoder的隐状态5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现做相似度(可以用一个MLP来实现等,下文再具体说明各种算法),然后求softmax得到权重系数5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现,然后做加权和。当然这只是一种做法,实际怎么做可以有别的做法。

3、seq2seq attention的经典形式

Effective Approaches to Attention-based Neural Machine Translation 论文中,将注意力机制大致分为了全局(global)注意力和局部(local)注意力。

全局注意力指的是注意力分布在所有encoder得到隐状态中;局部注意力指注意力只存在于一些隐状态中。例如,图像的average pooling可视为全局注意力;max pooling可视为局部注意力

论文链接:https://arxiv.org/abs/1508.04025

3.1 global attention

直接上论文中的内容好了,这边global attention的意思和上一篇论文的一样,只是论文中用的符号不一样,应该一看就能知道,实际一个意思。

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现

这里decoder的target hidden和encoder中的source hidden的score怎么算呢?文中介绍三种算法:

5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现

这里把score称为基于内容的函数,原因是这里的计算只考虑了隐状态间内容的相关性,并不包含时序信息。

(1)dot:transformer中的Q、K就是用dot计算

(2)general:俗称乘法注意力机制

(3)concat:俗称加法注意力机制

当然文中也提到了location based的global attention不过,似乎(不太确定可能本人孤陋寡闻基本没用过)是历史遗留的产物,大家可以自行阅读。

3.2 local attention

此处,出于这篇文章的完整性,提及一下文章还提到了local attention。提出原因是global attention的计算量大。以机器翻译任务为例,如果原文和译文语种差异较大采用global attention,但如果句法上对应较好,可以采用local attention尝试。

4、pytorch实现Seq2Seq中dot型attention的注意力

这里实现一个3.1中dot型的注意力,输入为encoder的各层隐状态encoder_states以及当前的decoder隐状态decoder_state_t,输出为注意力加权后的上下文状态c

class Seq2SeqAttentionMechanism(nn.Module):

    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()

    def forward(self, decoder_state_t, encoder_states):

        bs, source_length, hidden_size = encoder_states.shape

        decoder_state_t = decoder_state_t.unsqueeze(1)
        decoder_state_t = torch.tile(decoder_state_t, dims = (1, source_length, 1))

        score = torch.sum(decoder_state_t * encoder_states, dim = -1) #[bs, source_length]

        attn_prob = F.softmax(score, dim = -1) #[bs, source_length]

        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1) #[bs, hidden_size]

        return attn_prob, context

Original: https://blog.csdn.net/chen10314/article/details/123981383
Author: 沉睡的小卡比兽
Title: 5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/532075/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球