torch.Tensor.index_add_函数,pytorch中的tf.unsorted_segment_sum

ref:

torch.Tensor.index_add_能实现指定行或列的内容相加的功能,类似于tensorflow中 tf.unsorted_segment_sum函数,可以用在比如实例分割中进行特征聚合的步骤。比如一个 N*C的feature根据实例label可以将属于同一实例的点的特征聚合起来,得到 Ins_num*C的聚合特征。

1. 函数的参数

torch.Tensor.index_add_函数,pytorch中的tf.unsorted_segment_sum
  • dim:这个参数表明你要沿着哪个维度索引;
  • index:包含索引的tensor;
  • tensor:被索引出来去相加的tensor;
  • 注意事项: x相加前后的shape保持不变,被索引的tensor在被索引的维度(第dim维)之外的维度上与tensor的对应维度必须保持一致,且 index中的值最大不能超过 x在被索引的维度上的最大维数, index的长度必须和 tensor[dim]相同。假如x的 shape(N, C),索引的维度为第0维( dim=0),那么被索引的tensor的 dim=1的维度也必须为 C,index的值必须介于 0C-1之间, index的长度必须和被索引的tensor的 dim=0的数字相同。

; 2. 使用示例

import torch
x = torch.ones(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)
index = torch.tensor([0, 2, 4, 2])
new_x = x.index_add_(0, index, t)
print('new_x: {}'.format(new_x))
new_x: tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [15., 17., 19.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.]])

解释一下: x.index_add_()表示在x的每一行加上 index从t中索引出来的值,这个例子中x初始为一个5行3列全为1的tensor。

  • 如何确定x的第 i行要加上的值,首先通过 index[j]=i找到所有满足条件的 j,再把所有的 t[j]加上 x[i]就得到新的 x[i]。一行行来看。
  • new_x的第0行:首先去找 index中值为0的的索引,找 index[j]=0,所以 j=0,新的 new_x[0]=t[0]+x[0],即 new_x[0]=[1, 2, 3]+[1, 1, 1]=[2, 3, 4]
  • new_x的第1行:首先去找 index中值为1的的索引,找不到对应的j,所以没有东西可以加上, new_x[1]保持不变。
  • new_x的第2行:首先去找 index中值为2的的索引,找 index[j]=2,所以 j=1, 3,新的 new_x[2]=t[1]+t[3]+x[0],即 new_x[2]=[4, 5, 6]+[10, 11, 12]+[1, 1, 1]==[15, 17, 19]
  • new_x的第3行:首先去找 index中值为3的的索引,找不到对应的j,所以没有东西可以加上, new_x[3]保持不变。
  • new_x的第4行:首先去找 index中值为4的的索引,找 index[j]=4,所以 j=2,新的 new_x[4]=t[2]+x[4],即 new_x[4]=[7, 8, 9]+[1, 1, 1]=[8, 9, 10]

3. 使用(我自己看的)

  • 可以根据自己的工程需要去分配每一个输入的值。
  • 假设我需要聚合属于相同实例的点的特征,我可以把初始的x设为shape为 (Ins_num, C)的全0数组;t设置为需要被索引的feature,其shape为 (N, C);索引就可以为实例label,shape为 (N, )
  • label中为每个点所属的实例类别,其值为 0Ins_num-1。(如果这个当中某些背景点的label为-100或者其他值就要注意了,索引会报错,记得处理一下)
  • 那么最终就可以得到一个shape为 (Ins_num, C)的新tensor,每 i行代表对应实例label为 i的所有点相加后的特征。

Original: https://blog.csdn.net/zyoung17/article/details/116589641
Author: zyoung17
Title: torch.Tensor.index_add_函数,pytorch中的tf.unsorted_segment_sum

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/521254/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球