第二十课.DeepGraphLibrary(一)

目录

本篇内容第一部分引用自https://docs.dgl.ai/guide/graph.html#guide-graph,第二部分引用自https://docs.dgl.ai/guide/message.html#guide-message-passing

DGL安装

Deep Graph Library (DGL) 是一个 Python 包,用于在现有 DL 框架(目前支持 PyTorch、MXNet 和 TensorFlow)之上实现图神经网络模型。

安装DGL,先找到合适配置的版本(官方版本查询入口),然后用 pip安装:

pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html

图与图的创建

图是用以表示实体及其关系的结构,记为 G = ( V , E ) G=(V,E)G =(V ,E ) 。图由两个集合组成,一是节点的集合 V V V ,一个是边的集合 E E E 。 在边集 E E E 中,一条边 ( u , v ) (u,v)(u ,v ) 连接一对节点 u u u 和 v v v ,表明两节点间存在关系。关系可以是无向的, 如描述节点之间的对称关系;也可以是有向的,如描述非对称关系。例如,若用图对社交网络中人们的友谊关系进行建模,因为友谊是相互的,则边是无向的; 若用图对Twitter用户的关注行为进行建模,则边是有向的。图可以是有向的或无向的,这取决于图中边的方向性。

图可以是加权的或未加权的。在加权图中,每条边都与一个标量权重值相关联。例如,该权重可以表示长度或连接的强度。图可以是同构的或是异构的 。在同构图中,所有节点表示同一类型的实体,所有边表示同一类型的关系。 例如,社交网络的图由表示同一实体类型的人及其相互之间的社交关系组成。

DGL使用一个唯一的整数来表示一个节点,称为点ID;并用对应的两个端点ID表示一条边。同时,DGL也会根据边被添加的顺序, 给每条边分配一个唯一的整数编号,称为边ID。节点和边的ID都是从0开始构建的。在DGL的图里,所有的边都是有方向的, 即边 ( u , v ) (u,v)(u ,v ) 表示它是从节点 u u u 指向节点 v v v 的。

对于多个节点,DGL使用一个一维的整型张量(如,PyTorch的Tensor类,TensorFlow的Tensor类或MXNet的ndarray类)来保存图的点ID, DGL称之为”节点张量”。为了指代多条边,DGL使用一个包含2个节点张量的元组 ( U , V ) (U,V)(U ,V ) ,其中,用 ( U [ i ] , V [ i ] ) (U[i],V[i])(U [i ],V [i ]) 指代一条 U [ i ] U[i]U [i ] 到 V [ i ] V[i]V [i ] 的边。

创建一个 DGLGraph 对象的一种方法是使用 dgl.graph() 函数。它接受一个边的集合作为输入。下面我们构建一个图:

第二十课.DeepGraphLibrary(一)
构建过程如下:
import dgl
import torch as th

u, v = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])
g = dgl.graph((u, v))
print(g)
"""
Graph(num_nodes=4, num_edges=4,
      ndata_schemes={}
      edata_schemes={})
"""

print(g.nodes())

print(g.edges())

print(g.edges(form='all'))

g = dgl.graph((u, v), num_nodes=8)

对于无向的图,用户需要为每条边都创建两个方向的边。可以使用 dgl.to_bidirected() 函数来实现这个目的(这个函数可以把原图转换成一个包含反向边的图):

bg = dgl.to_bidirected(g)
print(bg.edges())

由于Tensor类内部使用C来存储,且定义了数据类型以及存储的设备信息,DGL推荐使用Tensor作为DGL API的输入。 不过大部分的DGL API也支持Python的可迭代类型(比如列表)或numpy.ndarray类型作为API的输入,方便用户快速进行开发验证

DGL支持使用 32 位或 64 位的整数作为节点ID和边ID。节点和边ID的数据类型必须一致。DGL提供了进行数据类型转换的方法,如下所示:

edges = th.tensor([2, 5, 3]), th.tensor([3, 5, 0])
g64 = dgl.graph(edges)
print(g64.idtype)

g32 = dgl.graph(edges, idtype=th.int32)
print(g32.idtype)

g64_2 = g32.long()
print(g64_2.idtype)

g32_2 = g64.int()
print(g32_2.idtype)

DGL Graph 对象的节点和边可具有多个用户定义的、可命名的特征,以储存图的节点和边的属性。 通过 ndataedata 接口可访问这些特征。 例如,以下代码创建了2个节点特征(分别命名为 'x''y' )和1个边特征(命名为 'x' ):

import dgl
import torch as th
g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))
print(g)
"""
Graph(num_nodes=6, num_edges=4,
      ndata_schemes={}
      edata_schemes={})
"""

g.ndata['x'] = th.ones(g.num_nodes(), 3)
g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32)
print(g)
"""
Graph(num_nodes=6, num_edges=4,
      ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'x' : Scheme(shape=(,), dtype=torch.int32)})
"""

g.ndata['y'] = th.randn(g.num_nodes(), 5)
print(g.ndata['x'][1])

print(g.edata['x'][th.tensor([0, 3])])

关于 ndataedata 接口的重要说明:

  • 每个节点特征具有唯一名称,每个边特征也具有唯一名称。节点和边的特征可以具有相同的名称(如上述示例中的 'x' );
  • 通过张量分配创建特征时,DGL会将特征赋给图中的每个节点和每条边。该张量的第一维必须与图中节点或边的数量一致。 不能将特征赋给图中节点或边的子集;
  • 相同名称下的特征必须具有相同的维度和数据类型;

对于加权图,用户可以将权重储存为一个边特征,如下:


edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])
weights = th.tensor([0.1, 0.6, 0.9, 0.7])
g = dgl.graph(edges)
g.edata['w'] = weights
print(g)
"""
Graph(num_nodes=4, num_edges=4,
      ndata_schemes={}
      edata_schemes={'w' : Scheme(shape=(,), dtype=torch.float32)})
"""

从外部源创建图

可以从外部来源构造一个 DGL Graph 对象,包括:

  • 从用于图和稀疏矩阵的外部Python库(NetworkX 和 SciPy)创建而来;
  • 从磁盘加载图数据;

从SciPy稀疏矩阵创建图:

import dgl
import torch as th
import scipy.sparse as sp

spmat = sp.rand(100, 100, density=0.05)

print(spmat)
"""
  (4, 72)   0.6109147027395163
  : :
  (30, 92)  0.5398895483258674
"""
print(spmat.todense())
"""
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...

 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
"""

print(dgl.from_scipy(spmat))
"""
Graph(num_nodes=100, num_edges=500,
      ndata_schemes={}
      edata_schemes={})
"""

从NetworkX创建图:

import dgl
import torch as th
import networkx as nx

nx_g = nx.path_graph(5)
print(dgl.from_networkx(nx_g))
"""
Graph(num_nodes=5, num_edges=8,
      ndata_schemes={}
      edata_schemes={})
"""

注意,当使用 nx.path_graph(5) 进行创建时, DGLGraph 对象有8条边,而非4条。 这是由于 nx.path_graph(5) 构建了一个无向的NetworkX图 networkx.Graph ,而 DGLGraph 的边总是有向的。 所以当将无向的NetworkX图转换为 DGLGraph 对象时,DGL会在内部将1条无向边转换为2条有向边。 使用有向的NetworkX图 networkx.DiGraph 可避免该现象:

nxg = nx.DiGraph([(2, 1), (1, 2), (2, 3), (0, 0)])
print(dgl.from_networkx(nxg))
"""
Graph(num_nodes=4, num_edges=4,
      ndata_schemes={}
      edata_schemes={})
"""

有多种文件格式可储存图,一般常见的是CSV格式,Pandas可以将该类型数据加载到python对象(如 numpy.ndarray)中, 进而使用这些对象来构建DGLGraph对象

异构图

相比同构图,异构图里可以有不同类型的节点和边。这些不同类型的节点和边具有独立的ID空间和特征。 例如在下图中,”用户”和”游戏”节点的ID都是从0开始的,而且两种节点具有不同的特征:

第二十课.DeepGraphLibrary(一)
在DGL中,一个异构图由一系列子图构成,一个子图对应一种关系。每个关系由一个字符串三元组定义 : (源节点类型, 边类型, 目标节点类型) 。由于这里的关系定义消除了边类型的歧义,DGL称它们为规范边类型。

下面创建一个异构图:

import dgl
import torch as th

graph_data = {
   ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
   ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),
   ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))
}
g = dgl.heterograph(graph_data)

print(g.ntypes)

print(g.etypes)

print(g.canonical_etypes)
"""
[('drug', 'interacts', 'drug'),
 ('drug', 'interacts', 'gene'),
 ('drug', 'treats', 'disease')]
"""

注意,同构图和二分图只是一种特殊的异构图,它们只包括一种关系:


dgl.heterograph({('node_type', 'edge_type', 'node_type'): (u, v)})

dgl.heterograph({('source_type', 'edge_type', 'destination_type'): (u, v)})

与异构图相关联的 metagraph 就是图的模式。它指定节点集和节点之间的边的类型约束。 metagraph 中的一个节点 u u u 对应于相关异构图中的一个节点类型。 metagraph 中的边 ( u , v ) (u,v)(u ,v ) 表示在相关异构图中存在从 u u u 型节点到 v v v 型节点的边。

print(g)
"""
Graph(num_nodes={'disease': 3, 'drug': 3, 'gene': 4},
      num_edges={('drug', 'interacts', 'drug'): 2,
                 ('drug', 'interacts', 'gene'): 2,
                 ('drug', 'treats', 'disease'): 1},
      metagraph=[('drug', 'drug', 'interacts'),
                 ('drug', 'gene', 'interacts'),
                 ('drug', 'disease', 'treats')])
"""

print(g.metagraph().edges())
"""
OutMultiEdgeDataView([('drug', 'drug'), ('drug', 'gene'), ('drug', 'disease')])
"""

当引入多种节点和边类型后,用户在调用DGLGraph API以获取特定类型的信息时,需要指定具体的节点和边类型。此外,不同类型的节点和边具有单独的ID:


print(g.num_nodes())

print(g.num_nodes('drug'))

print(g.nodes())

print(g.nodes('drug'))

为了设置/获取特定节点和边类型的特征,DGL提供了两种新类型的语法:

  • g.nodes[‘node_type’].data[‘feat_name’]
  • g.edges[‘edge_type’].data[‘feat_name’]

g.nodes['drug'].data['hv'] = th.ones(3, 1)
print(g.nodes['drug'].data['hv'])
"""
tensor([[1.],
        [1.],
        [1.]])
"""

g.edges['treats'].data['he'] = th.zeros(1, 1)
print(g.edges['treats'].data['he'])

如果图里只有一种节点或边类型,则不需要指定节点或边的类型:

g = dgl.heterograph({
   ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
   ('drug', 'is similar', 'drug'): (th.tensor([0, 1]), th.tensor([2, 3]))
})
print(g.nodes())

g.ndata['hv'] = th.ones(4, 1)

当边类型唯一地确定了源节点和目标节点的类型时,用户可以只使用一个字符串而不是字符串三元组来指定边类型;例如: 对于具有两个关系 ('user', 'plays', 'game')('user', 'likes', 'game') 的异构图, 只使用 'plays''like' 来指代这两个关系是可以的;

一种存储异构图的常见方法是在不同的CSV文件中存储不同类型的节点和边。下面是一个例子:


data/
|-- drug.csv
|-- gene.csv
|-- disease.csv
|-- drug-interact-drug.csv
|-- drug-interact-gene.csv
|-- drug-treat-disease.csv

与同构图的情况类似,用户可以使用像Pandas这样的包先将CSV文件解析为numpy数组或框架张量,再构建一个关系字典,并用它构造一个异构图;

用户可以通过指定要保留的关系来创建异构图的 子图,相关的特征也会被拷贝:

g = dgl.heterograph({
   ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
   ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),
   ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))
})
g.nodes['drug'].data['hv'] = th.ones(3, 1)

eg = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),
                                ('drug', 'treats', 'disease')])

print(eg)
"""
Graph(num_nodes={'disease': 3, 'drug': 3},
      num_edges={('drug', 'interacts', 'drug'): 2, ('drug', 'treats', 'disease'): 1},
      metagraph=[('drug', 'drug', 'interacts'), ('drug', 'disease', 'treats')])
"""

print(eg.nodes['drug'].data['hv'])
"""
tensor([[1.],
        [1.],
        [1.]])
"""

对于异构图的应用场景:

  • 不同类型的节点和边的特征具有不同的数据类型或大小;
  • 用户希望对不同类型的节点和边应用不同的操作;

如果上述情况不适用,并且用户不希望在建模中区分节点和边的类型,则DGL允许使用 dgl.DGLGraph.to_homogeneous() API将异构图转换为同构图。 具体行为如下:

  • 用从0开始的连续整数重新标记所有类型的节点和边;
  • 对所有的节点和边合并用户指定的特征;
g = dgl.heterograph({
   ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
   ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))})
g.nodes['drug'].data['hv'] = th.zeros(3, 1)
g.nodes['disease'].data['hv'] = th.ones(3, 1)
g.edges['interacts'].data['he'] = th.zeros(2, 1)
g.edges['treats'].data['he'] = th.zeros(1, 2)

hg = dgl.to_homogeneous(g)
print('hv' in hg.ndata)

hg = dgl.to_homogeneous(g, edata=['he'])

hg = dgl.to_homogeneous(g, ndata=['hv'])
print(hg.ndata['hv'])
"""
tensor([[1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.]])
"""

在GPU上运行DGL

用户可以将DGL中的对象迁移到GPU上,达到加速计算的效果:

import dgl
import torch as th

u, v = th.tensor([0, 1, 2]), th.tensor([2, 3, 4])
g = dgl.graph((u, v))
g.ndata['x'] = th.randn(5, 3)
print(g.device)

cuda_g = g.to('cuda:0')
print(cuda_g.device)
print(cuda_g.ndata['x'].device)

u, v = u.to('cuda:0'), v.to('cuda:0')
g = dgl.graph((u, v))
print(g.device)

消息传递范式

消息传递是GNN的通用框架和编程范式,它从聚合与更新的角度总结GNN;

假设节点v v v上的特征为x v ∈ R d 1 x_{v}\in R^{d_{1}}x v ​∈R d 1 ​,边( u , v ) (u,v)(u ,v )上的特征为w e ∈ R d 2 w_{e}\in R^{d_{2}}w e ​∈R d 2 ​,消息传递范式定义了逐节点(node-wise)和边(edge-wise)的计算:

  • Edge-wise:m e ( t + 1 ) = ϕ ( x v ( t ) , x u ( t ) , w e ( t ) ) , ( u , v , e ) ∈ ε m_{e}^{(t+1)}=\phi(x_{v}^{(t)},x_{u}^{(t)},w_{e}^{(t)}),(u,v,e)\in\varepsilon m e (t +1 )​=ϕ(x v (t )​,x u (t )​,w e (t )​),(u ,v ,e )∈ε
  • Node-wise:x v ( t + 1 ) = ψ ( x v ( t ) , ρ ( { m e ( t + 1 ) : ( u , v , e ) ∈ ε } ) ) x_{v}^{(t+1)}=\psi(x_{v}^{(t)},\rho(\left{m_{e}^{(t+1)}:(u,v,e)\in\varepsilon\right}))x v (t +1 )​=ψ(x v (t )​,ρ({m e (t +1 )​:(u ,v ,e )∈ε}))

其中,ϕ \phi ϕ是定义在每条边上的 消息函数,它通过将边上特征与其两端节点的特征相结合生成消息m e m_{e}m e ​, 聚合函数ρ \rho ρ聚合节点接收到的消息, 更新函数ψ \psi ψ结合聚合后的消息和节点本身的特征更新节点的特征;

内置函数和消息传递API

在DGL中, 消息函数 接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL 在内部生成以表示一批边。 edgessrcdstdata 共3个成员属性, 分别用于访问源节点、目标节点和边的特征;

聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL 在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来保存节点收到的消息。 一些最常见的聚合操作包括 summaxmin 等;

更新函数 接受一个如上所述的参数 nodes。此函数对 聚合函数 的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。

如果用户的消息传递函数无法用内置函数实现,则可以实现自己的消息或聚合函数,也称为 用户定义函数

消息函数

内置消息函数可以是一元函数或二元函数。对于一元函数,DGL支持 copy 函数。对于二元函数, DGL现在支持 addsubmuldivdot 函数。消息的内置函数的 命名约定u 表示 源节点, v 表示 目标节点, e 表示 边。这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名。例如,要对源节点的 hu 特征和目标节点的 hv 特征求和, 然后将结果保存在边的 he 特征上( he可以理解为一个临时变量,即消息),用户可以使用内置函数:

dgl.function.u_add_v('hu', 'hv', 'he')
"""
程序会自动去edges的源节点中找到字段'hu'的数据, 目标节点中找到字段'hv'的数据,
求和后保存到临时字段'he'的数据中, 这就是消息, 它会被存到nodes.mailbox下
"""

而以下用户定义消息函数与此内置函数等价:


def message_func(edges):

     return {'he': edges.src['hu'] + edges.dst['hv']}

聚合函数

DGL支持内置的聚合函数 summaxminmean 操作。 聚合函数通常有两个参数,它们的类型都是字符串。一个用于指定 mailbox 中的字段名( nodes 的成员 mailbox 可以用来保存节点收到的消息, nodesNodeBatch 的实例),另一个用于指示节点特征的字段名, 例如:


dgl.function.sum('m', 'h')

等价于如下所示的对接收到消息求和的用户定义函数:

import torch
def reduce_func(nodes):

     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

更新函数

update_all() 的参数是一个消息函数、一个聚合函数和一个更新函数。 更新函数是一个可选择的参数,用户也可以不使用它,而是在 update_all 执行完后直接对节点特征进行操作。比如:

import dgl.function as fn
def updata_all_example(graph):

    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))

    final_ft = graph.ndata['ft'] * 2
    return final_ft

此调用通过将源节点字段 ft 的特征与边字段 a 的特征相乘生成消息 m, 然后对所有节点( 本质应该是目标节点,但对于同构图,源节点和目标节点都是图中的所有节点)对应的消息( mailbox['m'])求和来更新节点字段 ft 的特征,再将特征 ft 乘以2得到最终结果 final_ft

调用后,中间消息 m 将被清除。上述函数的数学公式为:f i n a l f t i = 2 ∑ j ∈ N ( i ) ( a j i f t j ) final\, ft_{i}=2\sum_{j\in N(i)}(a_{ji}ft_{j})f i n a l f t i ​=2 j ∈N (i )∑​(a j i ​f t j ​)实例如下,我们采用Graph:

第二十课.DeepGraphLibrary(一)
依次按照节点ID赋予特征 'x':[1,2,3,4],按照边ID赋予特征 'x':[1,1,1,1],采用上面的消息传递机制,模拟过程为:
import dgl
import torch
import dgl.function as fn

u,v=torch.tensor([0,0,0,1]),torch.tensor([1,2,3,3])
g=dgl.graph((u,v))

g.ndata['x']=torch.tensor([1,2,3,4],dtype=torch.float)
g.edata['x']=torch.tensor([1,1,1,1],dtype=torch.float)

print(g)
print(g.nodes())
print(g.edges())
"""
Using backend: pytorch
Graph(num_nodes=4, num_edges=4,
      ndata_schemes={'x': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={'x': Scheme(shape=(), dtype=torch.float32)})
tensor([0, 1, 2, 3])
(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))
"""

def updata_all_example(graph):

    graph.update_all(fn.u_mul_e('x', 'x', 'm'),
                     fn.sum('m', 'x'))

    final_ft = graph.ndata['x'] * 2
    return final_ft

print(updata_all_example(g))
"""
tensor([0., 2., 2., 6.])
"""

比如节点ID=3的邻居有ID=0和ID=1,这两个节点都是相对ID=3的源节点,所以更新后,它的特征为6;

根据邻居的定义,节点ID=0没有邻居,所以特征更新后为0;

邻居的定义:节点v i v_{i}v i ​的邻居为:{v j ∈ V , ( v j , v i ) ∈ E v_{j}\in V,(v_{j},v_{i})\in E v j ​∈V ,(v j ​,v i ​)∈E}

额外思考:注意到在聚合函数中,参数有:消息 'm'与节点的特征字段 'ft'

关于其中的,消息的内容与特征字段的对应关系,它们之间的对应关系可能是依靠节点的编号进行连接的,从而把消息正确地分配到 nodes.mailbox

单独调用Edge-wise更新边特征

在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算(Edge-wise), apply_edges() 的参数是一个消息函数。并且在默认情况下, apply_edges()将更新所有的边。例如:

import dgl.function as fn

graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

依然采用前面提到的Graph,利用 apply_edges,我们只更新了边的数据,实例如下:

import dgl
import torch
import dgl.function as fn

u,v=torch.tensor([0,0,0,1]),torch.tensor([1,2,3,3])
g=dgl.graph((u,v))

g.ndata['x']=torch.tensor([1,2,3,4],dtype=torch.float)
g.edata['y']=torch.tensor([1,1,1,1],dtype=torch.float)

g.apply_edges(fn.u_add_v('x','x','y'))
print(g.ndata['x'])
print(g.edata['y'])

在子图上进行消息传递

如果用户只想更新图中的部分节点,可以先通过想要囊括的节点编号创建一个子图, 然后在子图上调用 update_all() 方法。例如:

nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)

这是小批量训练的常用方法,常用于在大规模Graph上训练

在消息传递中使用边的权重

一类常见的图神经网络建模的做法是在消息聚合前使用边的权重, 比如图注意力网络GAT,通常,DGL的处理方法为:

  • 将权重保存为边的特征;
  • 在消息函数中用边的特征与源节点的相乘;

比如:

import dgl.function as fn

graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))

异构图上的消息传递

在DGL中,一个异构图由一系列子图构成,一个子图对应一种关系。每个关系由一个字符串三元组定义 : (源节点类型, 边类型, 目标节点类型) 。由于这里的关系定义消除了边类型的歧义,DGL称它们为规范边类型;

我们创建下面的异构图:

import dgl
import torch as th
import dgl.function as fn

graph_data = {
   ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
   ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),
   ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))
}
g = dgl.heterograph(graph_data)
print(g)
"""
Graph(num_nodes={'disease': 3, 'drug': 3, 'gene': 4},
      num_edges={('drug', 'interacts', 'drug'): 2, ('drug', 'interacts', 'gene'): 2, ('drug', 'treats', 'disease'): 1},
      metagraph=[('drug', 'drug', 'interacts'), ('drug', 'gene', 'interacts'), ('drug', 'disease', 'treats')])
"""

对于异构图的metagraph,我们可以查看其 meta nodesmeta edges 的顺序,以及 规范边类型

print(g.ntypes)

print(g.etypes)

print(g.canonical_etypes)

为异构图的节点和边添加特征:

g.nodes['drug'].data['drug_x']=th.tensor([1,2,3],dtype=th.float).view(-1,1)
g.nodes['disease'].data['disease_x']=th.tensor([1,2,3],dtype=th.float).view(-1,1)
g.nodes['gene'].data['gene_x']=th.tensor([1,2,3,4],dtype=th.float).view(-1,1)

g.edges['treats'].data['treats_x']=th.ones(1,1)
g.edges[('drug', 'interacts', 'drug')].data['ddi_x']=th.ones(2,1)
g.edges[('drug', 'interacts', 'gene')].data['dgi_x']=th.ones(2,1)

异构图上的消息传递可以分为以下两部分:

  • 对每个关系计算和聚合消息;
  • 对节点聚合来自不同类型上的消息;

在DGL中,对异构图进行消息传递的接口是 multi_update_all()

multi_update_all() 接受一个字典。这个字典的每一个键值对里,键是一种关系, 值是这种关系对应 update_all() 的参数。 multi_update_all() 还接受一个字符串来表示跨类型整合函数,来指定整合不同关系聚合结果的方式。 这个整合方式可以是 summinmaxmeanstack 中的一个。下面是对上文异构图的操作实例:

funcs = {}

for c_etype in g.canonical_etypes:

    srctype, etype, dsttype = c_etype

    funcs[c_etype] = (fn.copy_u('%s_x' % srctype, 'm'),
                      fn.mean('m', 'h'))

g.multi_update_all(funcs, 'sum')

result={ntype: g.nodes[ntype].data['h'] for ntype in g.ntypes}
print(result)
"""
{'disease': tensor([[0.],
                    [0.],
                    [2.]]),
 'drug': tensor([[0.],
                [1.],
                [2.]]),
 'gene': tensor([[0.],
                [0.],
                [1.],
                [2.]])}
"""

聚合结果都保存在同一个字段 'h'下,说明异构图的聚合是逐类操作的

Original: https://blog.csdn.net/qq_40943760/article/details/120658591
Author: tzc_fly
Title: 第二十课.DeepGraphLibrary(一)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/555653/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球