self-attention和rnn计算复杂度的对比

Attention is all you need论文中的实验分析部分罗列了self-attention和rnn的复杂度对比,特此记录一下自己对二者复杂度的分析。

self-attention和rnn计算复杂度的对比

注意:n表示序列长度,d表示向量维度。
1、self-attention的复杂度为O ( n 2 ⋅ d ) O(n^{2} \cdot d)O (n 2 ⋅d ),其来源自self-attention计算公式:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) V Attention(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V A t t e n t i o n (Q ,K ,V )=S o f t m a x (d k ​​Q K T ​)V
其中,Q 、 K 、 V ∈ R n × d Q、K、V\in \mathbb{R}^{n \times d}Q 、K 、V ∈R n ×d,
Q K T QK^{T}Q K T是两个矩阵的乘法[ n , d ] × [ d , n ] = [ n , n ] [n,d] \times [d,n]=[n,n][n ,d ]×[d ,n ]=[n ,n ],计算复杂度为n 2 ⋅ d n^{2} \cdot d n 2 ⋅d;
其结果再乘V V V,即[ n , n ] × [ n , d ] = [ n , d ] [n,n] \times [n,d]=[n,d][n ,n ]×[n ,d ]=[n ,d ],计算复杂度也为n 2 ⋅ d n^{2} \cdot d n 2 ⋅d;

2、RNN的复杂度为O ( n ⋅ d 2 ) O(n \cdot d^{2})O (n ⋅d 2 ),其来源自计算公式:
h t = f ( W x h x t + b x h + W h h h t − 1 + b h h ) h_{t}=f(W_{xh}x_{t}+b_{xh}+W_{hh}h_{t-1}+b_{hh})h t ​=f (W x h ​x t ​+b x h ​+W h h ​h t −1 ​+b h h ​) y t = g ( W h y h t + b h t ) y_{t}=g(W_{hy}h_{t}+b_{ht})y t ​=g (W h y ​h t ​+b h t ​)
W x h ∈ R e m b × d W_{xh}\in \mathbb{R}^{emb \times d}W x h ​∈R e m b ×d,W h h ∈ R d × d W_{hh}\in \mathbb{R}^{d \times d}W h h ​∈R d ×d,
从W h h h t − 1 W_{hh}h_{t-1}W h h ​h t −1 ​来看,虽然W h h W_{hh}W h h ​在前边,但是做矩阵乘法的时候是 h t − 1 × W h h T h_{t-1} \times W_{hh}^{T}h t −1 ​×W h h T ​,即[ 1 , d ] × [ d , d ] = [ 1 , d ] [1,d] \times [d,d]=[1,d][1 ,d ]×[d ,d ]=[1 ,d ],计算复杂度为d ⋅ d d \cdot d d ⋅d;
以上是一个输入的计算复杂度,n个输入的计算复杂度为n ⋅ d 2 n \cdot d^{2}n ⋅d 2。

Original: https://blog.csdn.net/tailonh/article/details/123889034
Author: 想念@思恋
Title: self-attention和rnn计算复杂度的对比

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/527715/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球