最近看代码,看到了一个Tensor的方括号中还有一个Tensor,给爷看懵了。
new_embeddings = new_token_embeddings[input_flags]
其中:各个变量都是Tensor,维度分别是
new_token_embeddings:5 × 1024
input_flags:8 × 512
new_embeddings :8 × 512 × 1024
你需要先明白,方括号里的是索引,那么input_flags怎么索引new_token_embeddings的向量呢,也就是怎么查的呢。
可以把new_token_embeddings当成一个字典,包含5个字,每个字的含义是1024维。然后, input_flags的 每个位置上的值就是去这个字典中查这个字对应的含义。input_flags有多少个字呢?8 × 512,所以最终查询的结果是,大声告诉我,8 × 512 × 1024,对!
深度解释
关键字:整数索引
- input_flags如果是维度为1,new_token_embeddings维度为1,则取new_token_embeddings对应的 值;new_token_embeddings维度为2,则取new_token_embeddings对应的 行。
- input_flags如果维度为2,每一行每个值对应从new_token_embeddings中取一行。
深度样例
>>> a
array([[8, 8, 7, 4],
[7, 8, 0, 9],
[7, 4, 0, 5]])
>>> a[1]
array([7, 8, 0, 9])
>>> a[[1]]
array([[7, 8, 0, 9]])
>>> a[[1, 2]]
array([[7, 8, 0, 9],
[7, 4, 0, 5]])
>>> a[np.array([[1, 2], [2, 0]])]
array([[[7, 8, 0, 9],
[7, 4, 0, 5]],
[[7, 4, 0, 5],
[8, 8, 7, 4]]])
>>> a[[[1, 2], [2, 0]]]
array([0, 7])
>>> a[np.array(1)]
array([7, 8, 0, 9])
>>> a[np.array([1])]
array([[7, 8, 0, 9]])
>>> a[np.array([1, 2])]
array([[7, 8, 0, 9],
[7, 4, 0, 5]])
参考感谢 了解更多👌
Original: https://blog.csdn.net/xiangduixuexi/article/details/124020005
Author: 正门大石狮
Title: Torch和Numpy的高级索引,即,方括号中还有一个Tensor或Numpy
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/759158/
转载文章受原作者版权保护。转载请注明原作者出处!