「Analysis」How to Learn torch.einsum() Elegantly

Einsum is a compact notation for summation formulas, so when reading an einsum equation you can reverse-engineer the original computation.
e.g. `C = torch.einsum("ijk,kl->ijl", A, B)` computes $C_{ijl} = \sum_k A_{ijk} B_{kl} = A_{ijk} B_{kl}$
Free indices: $i, j, l$; dummy index: $k$. The free indices appear in the output in the order of the output labels, here $ijl$.
einsum can therefore combine operands of different ranks, e.g. vector & matrix, or matrix & tensor, as long as a dummy index links them;
the output dimensions are iterated in the order given by the output labels.

Conceptually, an einsum evaluation is equivalent to the following four steps (see the sketch after this list):

  1. Dimension alignment: sort all index labels alphabetically, then transpose and pad each input tensor so that every processed tensor carries the same dimension labels in the same order
  2. Broadcast multiply: take the broadcasted elementwise product, indexed in the order of the output labels「if the output labels are empty, all inputs are fully summed and a scalar is returned」
    ⚠️Note: if there is no「-> and output labels」, the output dimensions are arranged automatically in alphabetical order
    ⚠️Note: if -> is present but no output labels follow it, the inputs are fully summed and a scalar is returned
  3. Dimension reduction: sum away the components that correspond to dummy indices「in the order the dummy indices appear in the equation」
  4. Output transpose: if output labels are given, permute the result to their order「if the output labels are empty, full summation returns a scalar」
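
As a minimal sketch of these four steps, assuming small illustrative shapes, the matrix-product einsum can be replayed by hand with broadcasting and a sum:

import torch

# Replay C = einsum('ik,kj->ij', A, B) via the four steps above.
# The shapes (2, 3) and (3, 4) are assumptions chosen for illustration.
A = torch.randn(2, 3)
B = torch.randn(3, 4)

# 1. dimension alignment: sorted labels are i, j, k;
#    A ('ik') becomes (i, 1, k); B ('kj') is transposed to 'jk' -> (1, j, k)
aligned = A[:, None, :] * B.T[None, :, :]   # 2. broadcast multiply -> (i, j, k)
C = aligned.sum(dim=-1)                     # 3. reduce the dummy index k
# 4. the result is already in the output order 'ij'; no final transpose needed

assert torch.allclose(C, torch.einsum('ik,kj->ij', A, B))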

Table of Contents

einsum Method Explained
1. einsum Equation Conventions
2. How torch.einsum() Works
3. Vector Operations
4. Matrix Operations
5. Examples

+ 5.1 MATRIX TRANSPOSE
+ 5.2 SUM
+ 5.3 COLUMN SUM
+ 5.4 ROW SUM
+ 5.5 MATRIX-VECTOR MULTIPLICATION
+ 5.6 MATRIX-MATRIX MULTIPLICATION
+ 5.7 DOT PRODUCT
+ 5.8 HADAMARD PRODUCT
+ 5.9 OUTER PRODUCT
+ 5.10 BATCH MATRIX MULTIPLICATION
+ 5.11 TENSOR CONTRACTION
+ 5.12 BILINEAR TRANSFORMATION

References

einsum Method Explained

Einstein, while working on general relativity, had to handle a huge number of summations; to simplify that tedious bookkeeping he proposed the summation convention, which went on to drive the development of tensor analysis. einsum is a remarkably elegant method available in PyTorch, TensorFlow, and NumPy. It can express vector, matrix, and tensor computations, including transposes, sums, column/row sums, matrix-vector multiplication, and matrix-matrix multiplication. Used well, einsum is a genuine asset for research code and can replace most other matrix-computation routines.

1. einsum Equation Conventions

Einstein summation is a concise and efficient notation for summation formulas.
The principle: when an index variable appears twice in a term, the summation sign over it can be omitted.

For example, the matrix product formula:

$M_{ij} = \sum_{k} A_{ik} B_{kj} = A_{ik} B_{kj}$, i.e. `M = einsum('ik,kj->ij', A, B)`

Dummy index: must appear exactly twice within a term (repeated once, and no more); it is a temporary, "virtual" index that vanishes once the summation is carried out.
Free index: appears exactly once in every term, with the same letter throughout; it indexes the equations or components and is not summed over.

Einsum notation conventions (see the example after this list):

  1. Dimension labels: tensor dimensions are labeled with letters, case-insensitive; e.g. 'ijk' denotes the dimension labels i, j, k
  2. Labels map to operands: label groups are separated by ',' and correspond one-to-one, in order, to the input operands
  3. Broadcast dimensions: an ellipsis ... marks broadcast dimensions; e.g. 'i...j' means every dimension except the first and last is broadcast-aligned
  4. Free and dummy indices: a label appearing exactly once in the inputs is a free index; a repeated label is a dummy index, whose dimension is reduced away
  5. Output: the output dimensions can either be derived automatically from the input labels or customized with an output label string
  6. Auto-derived output
    broadcast dimensions occupy the leading (highest) positions, free-index dimensions follow in alphabetical order in the trailing (lowest) positions, and dummy-index dimensions are not output
  7. Customized output
    if the output contains broadcast dimensions, the output labels must include ...
    a dummy index listed in the output labels is promoted to a free index
    a free index missing from the output labels is demoted to a dummy index (and summed away)
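
A small sketch of conventions 3, 6 and 7, with shapes chosen purely for illustration:

import torch

a = torch.randn(2, 3)
c = torch.randn(4, 2, 3)

# 6. auto-derived output: no '->', so the free indices are sorted
#    alphabetically ('ij'); labeling a as 'ji' therefore yields its transpose
torch.einsum('ji', a).shape             # torch.Size([3, 2])

# 3. '...' broadcast-aligns the leading batch dimension of c
torch.einsum('...ij->...ji', c).shape   # torch.Size([4, 3, 2])

# 7. customized output: 'j' is a dummy index, but listing it after '->'
#    promotes it to a free index, so only 'i' is summed away
torch.einsum('ij,ij->j', a, a).shape    # torch.Size([3])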

For example, the einsum from the introduction, C = torch.einsum("ijk,kl->ijl", A, B), expands into the following loops (the shapes A: (3, 4, 5) and B: (5, 2) are illustrative):

import torch

A = torch.randn(3, 4, 5)
B = torch.randn(5, 2)
M = torch.empty(3, 4, 2)

for i in range(3):
    for j in range(4):
        for l in range(2):
            total = 0
            for k in range(5):          # k is the dummy index being summed away
                total += A[i, j, k] * B[k, l]
            M[i, j, l] = total

2. How torch.einsum() Works

Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.

The einsum method leverages exactly this concise, efficient Einstein notation, which lets it drive arbitrarily complex matrix computations. The basic form is as follows:

C = einsum('ij,jk->ik', A, B)

The call above computes the matrix product of A and B.

The inputs come in two parts:

  • equation (str): the subscript string that specifies the computation
  • operands (Tensor, [Tensor, ...]): the input tensors (their number and dimensions must match the equation)
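
As a quick sanity check of the equation/operands interface (the shapes here are assumptions):

import torch

A = torch.randn(2, 3)
B = torch.randn(3, 4)

C = torch.einsum('ij,jk->ik', A, B)   # matrix product written as an einsum
assert torch.allclose(C, A @ B)       # agrees with the @ operator
print(C.shape)                        # torch.Size([2, 4])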

3. Vector Operations

Let A and B be two 1D arrays of compatible shapes (meaning the lengths of the axes we pair together are either equal, or one of them has length 1):

| Call | Math | Description |
| --- | --- | --- |
| `('i', A)` | $A$ | returns a view of A |
| `('i->', A)` | $sum(A)$ | sum of the elements of A |
| `('i,i->i', A, B)` | $A * B$ | elementwise product of A and B |
| `('i,i', A, B)` | $inner(A, B)$ | dot (inner) product of A and B |
| `('i,j->ij', A, B)` | $outer(A, B)$ | outer product of A and B |
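
A few of these rows, spot-checked against the corresponding PyTorch calls (vectors of length 3 assumed):

import torch

A, B = torch.randn(3), torch.randn(3)

assert torch.allclose(torch.einsum('i,i->i', A, B), A * B)               # elementwise product
assert torch.allclose(torch.einsum('i,i', A, B), torch.dot(A, B))        # inner product
assert torch.allclose(torch.einsum('i,j->ij', A, B), torch.outer(A, B))  # outer product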

4. Matrix Operations

Now let A and B be two 2D arrays with compatible shapes:

| Call | Math | Description |
| --- | --- | --- |
| `('ij', A)` | $A$ | returns a view of A |
| `('ji', A)` | $A^T$ | transpose of A |
| `('ii->i', A)` | $diag(A)$ | main diagonal of A |
| `('ii', A)` | $trace(A)$ | trace of A |
| `('ij->', A)` | $sum(A)$ | sum of all elements of A |
| `('ij->i', A)` | $sum(A, axis=1)$ | row sums of A (sum along the horizontal axis) |
| `('ij->j', A)` | $sum(A, axis=0)$ | column sums of A (sum along the vertical axis) |
| `('ij,ij->ij', A, B)` | $A * B$ | elementwise product of A and B |
| `('ij,ji->ij', A, B)` | $A * B^T$ | elementwise product of A and the transpose of B |
| `('ij,jk', A, B)` | $dot(A, B)$ | matrix product of A and B |
| `('ij,kj->ik', A, B)` | $inner(A, B)$ | inner product of A and B |
| `('ij,kj->ijk', A, B)` | $A[:, None] * B$ | each row of A multiplied by B |
| `('ij,kl->ijkl', A, B)` | $A[:, :, None, None] * B$ | each value of A multiplied by B |
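
Spot checks for a few of these rows (A and B are assumed square, 3×3, so that the diagonal and trace rows apply):

import torch

A, B = torch.randn(3, 3), torch.randn(3, 3)

assert torch.allclose(torch.einsum('ji', A), A.T)                   # transpose
assert torch.allclose(torch.einsum('ii->i', A), torch.diagonal(A))  # main diagonal
assert torch.allclose(torch.einsum('ii', A), torch.trace(A))        # trace
assert torch.allclose(torch.einsum('ij->i', A), A.sum(dim=1))       # row sums
assert torch.allclose(torch.einsum('ij,jk', A, B), A @ B)           # matrix product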

When working with larger numbers of dimensions, keep in mind that einsum allows the ellipsis syntax '...'. This provides a convenient way to label the axes we're not particularly interested in, e.g. np.einsum('...ij,ji->...', a, b) would multiply just the last two axes of a with the 2D array b. There are more examples in the documentation.
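
For instance, a torch version of that ellipsis example might look as follows (shapes are assumptions):

import torch

a = torch.randn(4, 2, 3, 3)   # a batch of 3x3 matrices
b = torch.randn(3, 3)

# multiply only the last two axes of a with b; '...' keeps the batch axes
out = torch.einsum('...ij,ji->...', a, b)
print(out.shape)              # torch.Size([4, 2]); out[m, n] == trace(a[m, n] @ b)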

5. Examples

einsum is built into both NumPy and PyTorch; the examples here use PyTorch. First, define some tensors to work with:

import torch
from torch import einsum

a = torch.rand((3, 4))
b = torch.rand((4, 5))
c = torch.rand((6, 7, 8))
d = torch.rand((3, 4))
s = torch.rand((4, 4))       # square matrix, required by the trace/diagonal examples
x, y = torch.randn(5), torch.randn(5)

einsum('i,j->', x, y)        # '->' with no output labels: full sum, returns a scalar
einsum('i,j', x, y)          # no '->': auto-derived output 'ij', the outer product

einsum('ii->', s)            # trace of s (a square matrix is required)

einsum('ii->i', s)           # main diagonal of s

einsum('i,j->ij', x, y)      # outer product of x and y

einsum('ij,kj->ik', a, d)    # inner product: a @ d.T

einsum('ij,jk->ik', a, b)    # matrix multiplication: a @ b

einsum('ij,ij->ij', a, d)    # elementwise (Hadamard) product

einsum('ijk->ikj', c)        # swap the last two axes of c
einsum('...jk->...kj', c)    # the same, written with an ellipsis

A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)   # a bilinear transformation

5.1 MATRIX TRANSPOSE

$B_{ji} = A_{ij}$

import torch
a = torch.arange(6.).reshape(2, 3)
torch.einsum('ij->ji', a)
tensor([[ 0.,  3.],
        [ 1.,  4.],
        [ 2.,  5.]])

5.2 SUM

$b = \sum_i \sum_j A_{ij} = A_{ij}$

a = torch.arange(6.).reshape(2, 3)
torch.einsum('ij->', a)
tensor(15.)

5.3 COLUMN SUM

$b_j = \sum_i A_{ij} = A_{ij}$

a = torch.arange(6.).reshape(2, 3)
torch.einsum('ij->j', a)
tensor([ 3.,  5.,  7.])

5.4 ROW SUM

$b_i = \sum_j A_{ij} = A_{ij}$

a = torch.arange(6.).reshape(2, 3)
torch.einsum('ij->i', a)
tensor([  3.,  12.])

5.5 MATRIX-VECTOR MULTIPLICATION

$c_i = \sum_k A_{ik} b_k = A_{ik} b_k$

a = torch.arange(6.).reshape(2, 3)
b = torch.arange(3.)
torch.einsum('ik,k->i', a, b)
tensor([  5.,  14.])
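
With the shapes defined above, this should agree with torch.mv:

assert torch.allclose(torch.einsum('ik,k->i', a, b), torch.mv(a, b))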

5.6 MATRIX-MATRIX MULTIPLICATION

$C_{ij} = \sum_k A_{ik} B_{kj} = A_{ik} B_{kj}$

a = torch.arange(6.).reshape(2, 3)
b = torch.arange(15.).reshape(3, 5)
torch.einsum('ik,kj->ij', a, b)
tensor([[  25.,   28.,   31.,   34.,   37.],
        [  70.,   82.,   94.,  106.,  118.]])
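
Again with the shapes above, this matches torch.mm (and the @ operator):

assert torch.allclose(torch.einsum('ik,kj->ij', a, b), torch.mm(a, b))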

5.7 DOT PRODUCT

Vector:
$c = \sum_i a_i b_i = a_i b_i$

a = torch.arange(3.)
b = torch.arange(3., 6.)
torch.einsum('i,i->', a, b)
tensor(14.)

Matrix:
$c = \sum_i \sum_j A_{ij} B_{ij} = A_{ij} B_{ij}$

a = torch.arange(6.).reshape(2, 3)
b = torch.arange(6., 12.).reshape(2, 3)
torch.einsum('ij,ij->', a, b)
tensor(145.)

5.8 HADAMARD PRODUCT

$C_{ij} = A_{ij} B_{ij}$

a = torch.arange(6.).reshape(2, 3)
b = torch.arange(6., 12.).reshape(2, 3)
torch.einsum('ij,ij->ij', a, b)
tensor([[  0.,   7.,  16.],
        [ 27.,  40.,  55.]])

5.9 OUTER PRODUCT

$C_{ij} = a_i b_j$

a = torch.arange(3.)
b = torch.arange(3., 7.)
torch.einsum('i,j->ij', a, b)
tensor([[  0.,   0.,   0.,   0.],
        [  3.,   4.,   5.,   6.],
        [  6.,   8.,  10.,  12.]])

5.10 BATCH MATRIX MULTIPLICATION

$C_{ijl} = \sum_k A_{ijk} B_{ikl} = A_{ijk} B_{ikl}$

a = torch.randn(3,2,5)
b = torch.randn(3,5,3)
torch.einsum('ijk,ikl->ijl', a, b)
tensor([[[ 1.0886,  0.0214,  1.0690],
         [ 2.0626,  3.2655, -0.1465]],

        [[-6.9294,  0.7499,  1.2976],
         [ 4.2226, -4.5774, -4.8947]],

        [[-2.4289, -0.7804,  5.1385],
         [ 0.8003,  2.9425,  1.7338]]])
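
With the (3, 2, 5) and (3, 5, 3) shapes above, this should coincide with torch.bmm:

assert torch.allclose(torch.einsum('ijk,ikl->ijl', a, b), torch.bmm(a, b))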

5.11 TENSOR CONTRACTION

Batch matrix multiplication is a special case of a tensor contraction. Let's say we have two tensors, an order-$n$ tensor $A \in \mathbb{R}^{I_1 \times \cdots \times I_n}$ and an order-$m$ tensor $B \in \mathbb{R}^{J_1 \times \cdots \times J_m}$. As an example, take $n = 4$, $m = 5$ and assume that $I_2 = J_3$ and $I_3 = J_5$. We can multiply the two tensors in these two dimensions (2 and 3 for $A$, 3 and 5 for $B$), resulting in a new tensor $C \in \mathbb{R}^{I_1 \times I_4 \times J_1 \times J_2 \times J_4}$ as follows:

$C_{pstuv} = \sum_q \sum_r A_{pqrs} B_{tuqvr} = A_{pqrs} B_{tuqvr}$

a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
torch.einsum('pqrs,tuqvr->pstuv', a, b).shape
torch.Size([2, 7, 11, 13, 17])
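
Under the same shape assumptions, this contraction can also be written with torch.tensordot, which contracts the listed axis pairs (here q and r) and concatenates the remaining axes of a and b, already in the 'pstuv' order:

out = torch.tensordot(a, b, dims=([1, 2], [2, 4]))   # contract q (axes 1, 2) and r (axes 2, 4)
assert out.shape == torch.Size([2, 7, 11, 13, 17])
assert torch.allclose(out, torch.einsum('pqrs,tuqvr->pstuv', a, b))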

5.12 BILINEAR TRANSFORMATION

As mentioned earlier, einsum can operate on more than two tensors. One example where this is used is the bilinear transformation.

$D_{ij} = \sum_k \sum_l A_{ik} B_{jkl} C_{il} = A_{ik} B_{jkl} C_{il}$

a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
torch.einsum('ik,jkl,il->ij', a, b, c)

import numpy as np

# loop-equivalent of the einsum above, on NumPy copies of a, b, c
an, bn, cn = a.numpy(), b.numpy(), c.numpy()
np_out = np.empty((2, 5), dtype=np.float32)

for i in range(2):
    for j in range(5):
        sum_result = 0.0
        for k in range(3):
            for l in range(7):
                sum_result += an[i, k] * bn[j, k, l] * cn[i, l]
        np_out[i, j] = sum_result

tensor([[ 3.8471,  4.7059, -3.0674, -3.2075, -5.2435],
        [-3.5961, -5.2622, -4.1195,  5.5899,  0.4632]])
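
For reference, with the shapes above this einsum matches torch.nn.functional.bilinear without a bias term, where b plays the role of the (out, in1, in2) weight:

import torch.nn.functional as F

assert torch.allclose(torch.einsum('ik,jkl,il->ij', a, b, c),
                      F.bilinear(a, c, b))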

References

  1. https://rockt.github.io/2018/04/30/einsum
  2. https://ajcr.net/Basic-guide-to-einsum/
  3. https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/einsum_cn.html
  4. https://blog.csdn.net/qq_32768743/article/details/109131936
  5. https://dengbocong.blog.csdn.net/article/details/109566151
  6. https://zhuanlan.zhihu.com/p/46006162

Original: https://blog.csdn.net/ViatorSun/article/details/122710515
Author: ViatorSun