引言
这篇是datawhale组队学习之图神经网络第二篇,本笔记主要梳理课程的关键点,以及简单的代码实现。
- 首先我们将学习 图神经网络生成节点表征的范式–消息传递(Message Passing)范式。
- 接着我们将初步分析PyG中的
MessagePassing
基类,通过继承此基类我们可以方便地构造一个图神经网络。 - 然后我们以继承
MessagePassing
基类的GCNConv
类为例,学习如何通过继承MessagePassing
基类来构造图神经网络。 - 再接着我们将对
MessagePassing
基类进行剖析。 - 最后我们将学习在继承
MessagePassing
基类的子类中覆写message(),aggreate(),message_and_aggreate()
和update()
,这些方法的规范。
消息传递(Message Passing)范式
图神经网络 (GNN)主要是靠图卷积操作来完成的。而图卷积操作是一种将 目标节点周围邻接节点的信息进行聚合的一种方法,:
为 层节点 的特征向量, 为 到 的边的特征向量。 为聚合方法(可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数;例如:sum()
函数、 mean()
函数和 max()
函数 )
<img class="mathcode" src="https://latex.csdn.net/eq?%5Cgamma">
和
于是,根据这个公式,我们要做的就变成了三件事:
- A. 邻接节点 信息的变换:
- B. 邻接节点信息聚合:
- C. 自己的信息与聚合后的邻接节点信息的变换:
<img class="mathcode" src="https://latex.csdn.net/eq?%5Cgamma">
在Pytorch Geometric(PyG) 中,这个流程被对应到 self.propagate这个操作中, self.propagate将分别执行上述三件事:
- A. 执行 self.message,对应公式中 ,即邻接节点信息的变换
- B. 执行 self.aggregate,对应公式中 ,即邻接节点信息聚合
- C. 执行 self.update,对应公式中 ,即自己的信息与聚合后的邻接节点 信息的变换
MessagePassing实例
我们以继承 MessagePassing
基类的 GCNConv
类为例,学习如何通过继承 MessagePassing
基类来实现一个简单的图神经网络。
GCNConv
; 的数学定义为
归一化系数计算对应
,求和对应聚合方法,该方法中没有变换<img class="mathcode" src="https://latex.csdn.net/eq?%5Cgamma">
。除此之外,需要在self.forward进行初始特征的一次变换,整体流程如下:
- A. self.message:计算归一化系数
- B. self.aggregate:,选择add
- C. self.update:
<img class="mathcode" src="https://latex.csdn.net/eq?%5Cgamma">
(无)
其中,邻接节点的表征首先通过与权重矩阵相乘进行变换,然后按端点的度进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:
- 向邻接矩阵添加自环边。
- 对节点表征做线性转换。
- 计算归一化系数。
- 归一化邻接节点的节点表征。
- 将相邻节点表征相加(”求和 “聚合)。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # /space 选择聚合方法
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# Step 1: 添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: /theta 初始特征的一次变换
x = self.lin(x)
# Step 5: 相邻节点表征相加("求和 "聚合)
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
def message(self, x_j, edge_index, size):
# Step 3:/phi 计算归一化系数
row, col = edge_index
deg = degree(row, size[0], dtype=x_j.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4: 归一化邻接节点的节点表征
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
# /gamma 无
return aggr_out
完整代码:
Planetoid
数据集类的官方文档为torch_geometric.datasets.Planetoid。
-*- coding: utf-8 -*-
"""
Created on Sat Jun 19 11:37:18 2021
@author: Choi
"""
import os
import torch
from torch_geometric.datasets import Planetoid # PyG处理好的一些数据,如"Cora", "CiteSeer" and "PubMed" ,用Planetoid这个类调用即可
import torch_geometric.nn as pyg_nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
load Cora dataset
def get_data(folder="node_classify/cora", data_name="cora"):
"""
:param folder:保存数据集的根目录。
:param data_name:数据集的名称
:return:返回的是一个对象,就是PyG文档里的Data对象,它有一些属性,如 data.x、data.edge_index等
"""
dataset = Planetoid(root=folder, name=data_name)
return dataset
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # /space 选择聚合方法
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# Step 1: 添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: /theta 初始特征的一次变换
x = self.lin(x)
# Step 5: 相邻节点表征相加("求和 "聚合)
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
def message(self, x_j, edge_index, size):
# Step 3:/phi 计算归一化系数
row, col = edge_index
deg = degree(row, size[0], dtype=x_j.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4: 归一化邻接节点的节点表征
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
# /gamma 无
return aggr_out
def main():
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 配置GPU
dataset = get_data()
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
if __name__ == "__main__":
main()
其他资料:
Original: https://blog.csdn.net/alterxu/article/details/118034359
Author: 淡定V哥
Title: 图神经网络二:消息传递图神经网络
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/544884/
转载文章受原作者版权保护。转载请注明原作者出处!