ONNXRUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where'

When you run into errors like the following:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gather node. Name:'Gather_4445' Status Message: indices element out of data bounds, idx=8 must be within the inclusive range [-3,2]

RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where' Status Message…

you can use the Netron tool (install it with pip install netron, then run netron in a terminal) to inspect the exported ONNX model graph, locate the offending node (e.g. Where_XXXX), trace it back to the corresponding code, and rewrite that code using tensor operations that ONNX supports.
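For instance, a quick way to open the exported model in Netron directly from Python (the file name model.onnx below is just a placeholder for your own export):

import netron

# Launch Netron's local viewer on the exported model;
# "model.onnx" is a placeholder path for your own file.
netron.start("model.onnx")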

Based on the problems I have hit during ONNX export, the trickiest cases involve tensor operations such as torch.gather, torch.where, and torch.split.

1. The torch.where function

torch.where(condition, x, y) -> Tensor

Where the condition is true, the output element is taken from x; otherwise it is taken from y.

import torch
condition = torch.randn(2, 2)
# e.g. tensor([[ 0.2589, -0.5600],
#              [ 0.9056, -0.3915]])
a = torch.tensor([[0, 0], [0, 0]])
b = torch.tensor([[1, 1], [1, 1]])
torch.where(condition > 0.5, a, b)

Result:

tensor([[1, 1],
        [0, 1]])

Elements equal to 0 were taken from a; elements equal to 1 were taken from b.
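Written with plain tensor arguments like this, torch.where exports as a single ONNX Where node. A minimal export sketch (the module and output file name are illustrative, not from the original post):

import torch

class WhereDemo(torch.nn.Module):
    def forward(self, x):
        # Elementwise select; the exporter lowers this to an ONNX Where node
        return torch.where(x > 0.5, torch.zeros_like(x), torch.ones_like(x))

# "where_demo.onnx" is a placeholder output path
torch.onnx.export(WhereDemo(), torch.randn(2, 2), "where_demo.onnx", opset_version=11)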

2. torch.gather (a table-lookup operation)

torch.gather(input, dim, index, out=None) -> Tensor

It works like a table lookup: given indices, gather reads the corresponding entries from the input and collects them into the output.

In other words, gather(lookup table, dim, index table) reads the lookup table at the positions given by the index tensor along dim.

import torch
prob = torch.randn(4, 4)
# e.g. tensor([[-0.9845,  0.5094, -0.5014, -0.5354],
#              [-1.8514,  0.2640,  0.7895, -0.1660],
#              [ 0.3955,  0.7571,  0.1451,  0.1970],
#              [ 0.3674, -0.8006, -0.5625,  1.3455]])
idx = prob.topk(dim=1, k=2)[1]  # topk returns (values, indices); keep the indices
# e.g. tensor([[1, 2],
#              [2, 1],
#              [1, 0],
#              [3, 0]])
label = torch.arange(4) + 100
# tensor([100, 101, 102, 103])
torch.gather(label.expand(4, 4), dim=1, index=idx.long())

Output (for the indices above):

tensor([[101, 102],
        [102, 101],
        [101, 100],
        [103, 100]])
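The Gather error quoted at the top of this article ("indices element out of data bounds") means exactly this kind of lookup received an index outside the valid range. One common defensive fix, shown here as a sketch rather than the original author's code, is to clamp the indices before gathering:

import torch

data = torch.randn(4, 3)
idx = torch.tensor([[0, 2, 8]])            # 8 is out of bounds for a size-3 dim
safe_idx = idx.clamp(0, data.size(1) - 1)  # force indices into the valid range [0, 2]
torch.gather(data, 1, safe_idx)            # no longer raises at runtime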

3. torch.split

Meaning: splits a tensor into several chunks.

torch.split(tensor, split_size_or_sections, dim=0)

  • tensor (Tensor) – the tensor to split.
  • split_size_or_sections (int or list(int)) – the size of a single chunk, or a list of sizes for each chunk.
  • dim (int) – the dimension along which to split the tensor.

If split_size_or_sections is an integer, the tensor is split into equally sized chunks (when possible). If the tensor's size along the given dimension dim is not divisible by split_size, the last chunk will be smaller.

If split_size_or_sections is a list, the tensor is split into len(split_size_or_sections) chunks along dim, with sizes given by the list.

Example:

>>> a = torch.arange(8).reshape(4,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
>>> torch.split(a, 3)
(tensor([[0, 1],
         [2, 3],
         [4, 5]]),
 tensor([[6, 7]]))
>>> torch.split(a, [1,3])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7]]))
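For ONNX export, constant split sizes like the ones above map to a single Split node; problems usually appear only when the sizes are computed dynamically at runtime. A minimal export sketch (the module and file name are illustrative, not from the original post):

import torch

class SplitDemo(torch.nn.Module):
    def forward(self, x):
        # Constant section sizes export as a single ONNX Split node
        return torch.split(x, [1, 3], dim=0)

# "split_demo.onnx" is a placeholder output path
torch.onnx.export(SplitDemo(), torch.arange(8.).reshape(4, 2), "split_demo.onnx", opset_version=11)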

4. The Tensor.scatter_ function

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

self, index and src (if it is a Tensor) should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive.

Parameters

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.

  • src (Tensor or float) – the source element(s) to scatter.

  • reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.

In short: scatter_ redistributes the values of src into the output tensor, and index specifies where each element of src lands. Positions that index never touches keep the original values of self (the zeros in the examples below, since self was created with torch.zeros).

Examples:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])
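When exporting, scatter_ lowers to the ONNX ScatterElements op, which only exists in opset 11 and later (the reduce variants need an even newer opset), so passing a recent opset_version to torch.onnx.export is the usual fix. A minimal sketch (the module and file name are illustrative, not from the original post):

import torch

class ScatterDemo(torch.nn.Module):
    def forward(self, src, index):
        # Scatter into a fresh zero buffer; lowers to ONNX ScatterElements
        out = torch.zeros(3, 5, dtype=src.dtype)
        return out.scatter_(0, index, src)

src = torch.arange(1., 11.).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0, 1]])
# "scatter_demo.onnx" is a placeholder output path
torch.onnx.export(ScatterDemo(), (src, index), "scatter_demo.onnx", opset_version=11)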

Original: https://blog.csdn.net/andrewchen1985/article/details/125332942
Author: AndrewChen1985
Title: ONNXRUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where'
