ref:
- https://pytorch.org/docs/1.4.0/tensors.html?highlight=index_add_#torch.Tensor.index_add_
- https://blog.csdn.net/weixin_44289071/article/details/103882658
torch.Tensor.index_add_能实现指定行或列的内容相加的功能,类似于tensorflow中 tf.unsorted_segment_sum
函数,可以用在比如实例分割中进行特征聚合的步骤。比如一个 N*C
的feature根据实例label可以将属于同一实例的点的特征聚合起来,得到 Ins_num*C
的聚合特征。
1. 函数的参数
- dim:这个参数表明你要沿着哪个维度索引;
- index:包含索引的tensor;
- tensor:被索引出来去相加的tensor;
- 注意事项:
x
相加前后的shape保持不变,被索引的tensor在被索引的维度(第dim维)之外的维度上与tensor的对应维度必须保持一致,且index
中的值最大不能超过x
在被索引的维度上的最大维数,index
的长度必须和tensor[dim]
相同。假如x的shape
为(N, C)
,索引的维度为第0维(dim=0
),那么被索引的tensor的dim=1
的维度也必须为C
,index的值必须介于0
和C-1
之间,index
的长度必须和被索引的tensor的dim=0
的数字相同。
; 2. 使用示例
import torch
x = torch.ones(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)
index = torch.tensor([0, 2, 4, 2])
new_x = x.index_add_(0, index, t)
print('new_x: {}'.format(new_x))
new_x: tensor([[ 2., 3., 4.],
[ 1., 1., 1.],
[15., 17., 19.],
[ 1., 1., 1.],
[ 8., 9., 10.]])
解释一下: x.index_add_()
表示在x的每一行加上 index
从t中索引出来的值,这个例子中x初始为一个5行3列全为1的tensor。
- 如何确定x的第
i
行要加上的值,首先通过index[j]=i
找到所有满足条件的j
,再把所有的t[j]
加上x[i]
就得到新的x[i]
。一行行来看。 - new_x的第0行:首先去找
index
中值为0的的索引,找index[j]=0
,所以j=0
,新的new_x[0]=t[0]+x[0]
,即new_x[0]=[1, 2, 3]+[1, 1, 1]=[2, 3, 4]
。 - new_x的第1行:首先去找
index
中值为1的的索引,找不到对应的j,所以没有东西可以加上,new_x[1]
保持不变。 - new_x的第2行:首先去找
index
中值为2的的索引,找index[j]=2
,所以j=1, 3
,新的new_x[2]=t[1]+t[3]+x[0]
,即new_x[2]=[4, 5, 6]+[10, 11, 12]+[1, 1, 1]==[15, 17, 19]
。 - new_x的第3行:首先去找
index
中值为3的的索引,找不到对应的j,所以没有东西可以加上,new_x[3]
保持不变。 - new_x的第4行:首先去找
index
中值为4的的索引,找index[j]=4
,所以j=2
,新的new_x[4]=t[2]+x[4]
,即new_x[4]=[7, 8, 9]+[1, 1, 1]=[8, 9, 10]
。
3. 使用(我自己看的)
- 可以根据自己的工程需要去分配每一个输入的值。
- 假设我需要聚合属于相同实例的点的特征,我可以把初始的x设为shape为
(Ins_num, C)
的全0数组;t设置为需要被索引的feature,其shape为(N, C)
;索引就可以为实例label,shape为(N, )
。 - label中为每个点所属的实例类别,其值为
0
到Ins_num-1
。(如果这个当中某些背景点的label为-100或者其他值就要注意了,索引会报错,记得处理一下) - 那么最终就可以得到一个shape为
(Ins_num, C)
的新tensor,每i
行代表对应实例label为i
的所有点相加后的特征。
Original: https://blog.csdn.net/zyoung17/article/details/116589641
Author: zyoung17
Title: torch.Tensor.index_add_函数,pytorch中的tf.unsorted_segment_sum
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/521254/
转载文章受原作者版权保护。转载请注明原作者出处!