如果是二维的矩阵相乘,那就跟平时咱们做的矩阵乘法一样:
a = torch.tensor([[1,2], [3,4]])
a
Out[31]:
tensor([[1, 2],
[3, 4]])
b = torch.tensor([[2,2], [3,4]])
b
Out[33]:
tensor([[2, 2],
[3, 4]])
torch.matmul(a, b)
Out[34]:
tensor([[ 8, 10],
[18, 22]])
torch.matmul(a, b).shape
Out[35]: torch.Size([2, 2])
如果维度更高呢? 前面的维度必须要相同,然后最里面的两个维度符合矩阵相乘的形状限制:i× j
, j
×k。
a = torch.tensor([[[1,2], [3,4], [5,6]],[[7,8], [9,10], [11,12]]])
a
Out[37]:
tensor([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[ 7, 8],
[ 9, 10],
[11, 12]]])
a.shape
Out[38]: torch.Size([2, 3, 2])
b = torch.tensor([[[1,2], [3,4]],[[7,8], [9,10]]])
b
Out[40]:
tensor([[[ 1, 2],
[ 3, 4]],
[[ 7, 8],
[ 9, 10]]])
b.shape
Out[41]: torch.Size([2, 2, 2])
torch.matmul(a, b)
Out[42]:
tensor([[[ 7, 10],
[ 15, 22],
[ 23, 34]],
[[121, 136],
[153, 172],
[185, 208]]])
torch.matmul(a, b).shape
Out[43]: torch.Size([2, 3, 2])
这里举一个例子,在某一篇论文的代码中,作者使用matmul的场景。
简单地说,就是用过matmul()函数实现subject 的 lookup
假设有下面这么一个矩阵,shape为[batch_size, 1, seq_len],该矩阵的含义是,最里面的每一个[ ] s e q l e n []_{seq_len}[]s e q l e n 表示一个句子的序列,如果元素为1,则表示该下标可以作为subject的head index。并且在每一行中,只有一个1。也就是只有一个subject的head index。
现在有另外一个矩阵,shape为[batch_size, seq_len, bert_dim]。该矩阵的含义是整个batch的text([batch_size, seq_len])经过经过bert encoder之后得到的。
根据前面说的,二者相乘,得到的shape是[batch_size, 1, bert_dim]。
比如第一行 [ 0 , 1 , 0 , . . . . 0 , 0 ] × b e r t e n c o d e 之 后 的 矩 阵 = [ 0.3 , 0.1 , . . . , 0 ] [0,1,0,….0,0]×bert encode之后的矩阵=[0.3, 0.1, …, 0][0 ,1 ,0 ,….0 ,0 ]×b e r t e n c o d e 之后的矩阵=[0 .3 ,0 .1 ,…,0 ]
最后得到的是subject在bert encode之后的空间中look up,或者说嵌入以后的向量。
Original: https://blog.csdn.net/qq_35056292/article/details/115689909
Author: y4ung
Title: torch.matmul() 张量相乘
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/542999/
转载文章受原作者版权保护。转载请注明原作者出处!