Torch和Numpy的高级索引,即,方括号中还有一个Tensor或Numpy

最近看代码,看到了一个Tensor的方括号中还有一个Tensor,给爷看懵了。

new_embeddings = new_token_embeddings[input_flags]

其中:各个变量都是Tensor,维度分别是
new_token_embeddings:5 × 1024
input_flags:8 × 512
new_embeddings :8 × 512 × 1024

你需要先明白,方括号里的是索引,那么input_flags怎么索引new_token_embeddings的向量呢,也就是怎么查的呢。

可以把new_token_embeddings当成一个字典,包含5个字,每个字的含义是1024维。然后, input_flags的 每个位置上的值就是去这个字典中查这个字对应的含义。input_flags有多少个字呢?8 × 512,所以最终查询的结果是,大声告诉我,8 × 512 × 1024,对!

深度解释

关键字:整数索引

  • input_flags如果是维度为1,new_token_embeddings维度为1,则取new_token_embeddings对应的 ;new_token_embeddings维度为2,则取new_token_embeddings对应的
  • input_flags如果维度为2,每一行每个值对应从new_token_embeddings中取一行。

深度样例

>>> a
array([[8, 8, 7, 4],
       [7, 8, 0, 9],
       [7, 4, 0, 5]])
>>> a[1]
array([7, 8, 0, 9])
>>> a[[1]]
array([[7, 8, 0, 9]])
>>> a[[1, 2]]
array([[7, 8, 0, 9],
       [7, 4, 0, 5]])
>>> a[np.array([[1, 2], [2, 0]])]
array([[[7, 8, 0, 9],
        [7, 4, 0, 5]],
       [[7, 4, 0, 5],
        [8, 8, 7, 4]]])

>>> a[[[1, 2], [2, 0]]]
array([0, 7])

>>> a[np.array(1)]
array([7, 8, 0, 9])
>>> a[np.array([1])]
array([[7, 8, 0, 9]])
>>> a[np.array([1, 2])]
array([[7, 8, 0, 9],
       [7, 4, 0, 5]])

参考感谢 了解更多👌

Original: https://blog.csdn.net/xiangduixuexi/article/details/124020005
Author: 正门大石狮
Title: Torch和Numpy的高级索引,即,方括号中还有一个Tensor或Numpy

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/759158/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球