使用Pytorch_Geometric(PyG)时构建DataLoader,从DataLoader获取样本Batch时报错: RuntimeError: Sizes of tensors must match except in dimension 0.
报错原因是数据对齐错误,1个batch是多个样本的集合,在样本拼接成集合时出现错误,其规律如下:
- 使用pytorch-geometric的dataloader时,batch的各个样本合并规则
- 属性edge_index规则特殊,每个样本edge_index为2 × e i 2\times e_i 2 ×e i ,则合并n个样本形成一个batch之后的batch.edge_index大小为2 × ( ∑ i = 1 n e i ) 2\times(\sum_{i=1}^n e_i)2 ×(∑i =1 n e i )
- 其他所有属性如果为tensor,则按照第一个维度扩展,例如对于属性x x x,第一个样本大小为d 1 × d 2 d_1\times d_2 d 1 ×d 2 ,第二个样本大小为d 3 × d 2 d_3\times d_2 d 3 ×d 2 ,则如果有一个batch包含这两个样本,batch.x的大小会是( d 3 + d 1 ) × d 2 (d_3+d_1)\times d_2 (d 3 +d 1 )×d 2 。 这里一个巨坑,要求除了第一个维度之外,其他维度大小都必须要相同!! 否则会报错
RuntimeError: Sizes of tensors must match except in dimension 0.
- 其他属性如果不是tensor,就会正常按照列表返回,batch.x=[ 样本1的x,样本2的x,样本3的x]
如何解决:
- 如果是使用torch tensor引起的,可以考虑想办法对齐除了第一个维度外,其他维度的宽度。
- 如果没办法对齐,使用非tensor数据类型替换,例如列表。
- 最后的选择,指定batch_size=1以规避。
dataloader=DataLoader(MyData,batch_size=1)
2022/06/23原始
2023/02/20更新
https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
这个是官网更详细的描述,直接看这个简单
Original: https://blog.csdn.net/weixin_44839047/article/details/125419476
Author: Deno_V
Title: Pytorch_Geometric(PyG)使用DataLoader报错RuntimeError: Sizes of tensors must match except in dimension 0.
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/707617/
转载文章受原作者版权保护。转载请注明原作者出处!