TF-GNN踩坑记录(三)

在Tensorflow-GNN中使用batch size除了需要注意上面的链接问题之外,最近我在调试的发现,使用了merge_batch_to_components() 之后,使用TF-GNN的Readout模块,它会默认merge之后的graph为一张图读出所有节点的数据组成一个矩阵,而不区分batch中的每一张子图,故会导致数据的结构被修改,导致模型的表现与预期的差距较大。

解决方案

使用Pool替代Readout,该代码的具体的作用是从merge之后的图中,读出每一张子图(component)上 的节点数据,并对每个子图的节点数据进行pooling,如这里使用加法做为pooling的方式,并把这些子图pooling之后的数据拼接成一个矩阵存储在context中。这个矩阵的数据和原始输入的graph是一一对应的,如输入的batch size是32,这个矩阵的行数也为32行,每一行对应一张graph。

Original: https://www.cnblogs.com/lovefisho/p/16627062.html
Author: LoveFishO
Title: TF-GNN踩坑记录(三)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/566632/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球