遇到此类错误,如:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gather node. Name:’Gather_4445′ Status Message: indices element out of data bounds, idx=8 must be within the inclusive range [-3,2]
RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:’Where’ Status Message…
可以配合Netron工具(安装方法:pip install netron,使用时终端输入netron)查看导出的onnx模型网络图,可以查找相应的Node(如:Where_XXXX),再去代码中找对应代码,将其改为onnx支持的tensor运算方式即可解决相应问题。
根据在ONNX导出时遇到的问题比较麻烦的是和torch.gather、torch.where、torch.split等Tensor运算方法。
- torch.where函数
torch.where(condition,x,y)->tensor
当满足condition,则来自于a,反之来自b
import torch
condition=torch.randn(2,2)
tensor([[ 0.2589, -0.5600],
[ 0.9056, -0.3915]])
a=torch.tensor([[0,0],[0,0]])
b=torch.tensor([[1,1],[1,1]])
torch.where(cond>0.5,a,b)
得到结果
tensor([[1, 1],
[0, 1]])
输出为0的代表来源为a,输出为1的代表来源为b
2. torch.gather(查表的过程)
torch.gather(input,dim,index,out=None)->tensor
就像是给了数据以后,查表得到对应参数,再收集回来进行输出。
gather函数即为gather(对应的参数表,dim,数据表)
import torch
prob=torch.randn(4,4)
#tensor([[-0.9845, 0.5094, -0.5014, -0.5354],
[-1.8514, 0.2640, 0.7895, -0.1660],
[ 0.3955, 0.7571, 0.1451, 0.1970],
[ 0.3674, -0.8006, -0.5625, 1.3455]])
idx=prob.topk(dim=1,k=2)
idx=idx[1]
#tensor([[1, 2],
[2, 1],
[1, 0],
[3, 0]]))
label=torch.arange(4)+100
#tensor([100, 101, 102, 103])
torch.gather(label.expand(4,4),dim=1,index=idx.long())
输出结果:
- torch.split
含义:将一个张量分为几个chunks
torch.split(tensor, split_size_or_sections, dim=0)
- tensor(Tensor) -张量分裂。
- split_size_or_sections(int) 或者 (_list (int )_) -单个块的大小或每个块的大小列表
- dim(int) -沿其分割张量的维度。
如果 split_size_or_sections
是整数类型,那么tensor将被分成大小相等的块(如果可能)。如果沿给定维度 dim
的张量大小不能被 split_size
整除,则最后一个块会更小。
如果 split_size_or_sections
是一个列表,那么 tensor 将根据 split_size_or_sections
被拆分为大小在 dim
中的 len(split_size_or_sections)
块。
示例:
>>> a = torch.arange(8).reshape(4,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
>>> torch.split(a, 3)
(tensor([[0, 1],
[2, 3],
[4, 5]]),
tensor([[6, 7]]))
>>> torch.split(a, [1,3])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7]]))
4. Tensor.scatter_函数
Writes all values from the tensor src
into self
at the indices specified in the index
tensor. For each value in src
, its output index is specified by its index in src
for dimension != dim
and by the corresponding value in index
for dimension = dim
.
For a 3-D tensor, self
is updated as:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
self
, index
and src
(if it is a Tensor) should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d)< code> for all dimensions <code>d</code>, and that <code>index.size(d) <= self.size(d)< code> for all dimensions <code>d != dim</code>. Note that <code>index</code> and <code>src</code> do not broadcast.<!--= self.size(d)<--></code><!--= src.size(d)<-->
Moreover, as for gather(), the values of index
must be between 0
and self.size(dim) - 1
inclusive.
Parameters
- dim (int) – the axis along which to index
-
index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as
src
. When empty, the operation returnsself
unchanged. -
reduce (str , optional) – reduction operation to apply, can be either
'add'
or'multiply'
.
总结:scatter函数就是把src数组中的数据重新分配到output数组当中,index数组中表示了要把src数组中的数据分配到output数组中的位置,若未指定,则填充0.
举例:
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
[2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
Original: https://blog.csdn.net/andrewchen1985/article/details/125332942
Author: AndrewChen1985
Title: ONNXRUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:‘Where‘
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/688714/
转载文章受原作者版权保护。转载请注明原作者出处!