图神经网络实践之图节点分类(一)

本文主要以Deep Graph Library(DGL)为基础,利用图神经网络来进行图节点分类任务。本篇针对的图为同构图。

DGL是一个python包,用以在现有的深度学习框架上(包括Pytorch、MXNet和TensorFlow)来实现图神经网络系列模型。它提供了对消息传递的通用控制,通过自动批处理和高度调整的稀疏矩阵内核进行速度优化,以及多 GPU/CPU 训练以扩展到数亿个节点和边缘的图形。
DGL拥有丰富的文档及相关接口,而且文档有中文版本,十分容易学习和上手。
DGL的github链接:https://github.com/dmlc/dgl

2.1 数据集加载

本文使用的数据集为DGL中已经有的Cora数据集,该数据集为论文引用数据集,包含论文节点和论文之间的引用关系,通过论文本身的特征和引用关系来对论文进行分类,其共包括以下七类:

  • 基于案例
  • 遗传算法
  • 神经网络
  • 概率方法
  • 强化学习
  • 规则学习
  • 理论
import dgl.data
from dgl.nn import GraphConv
import torch.nn as nn
from dgl.nn.pytorch.conv import SAGEConv
import torch
import torch.nn.functional as F

dataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
g = dataset[0]
print("结点信息",g.ndata)
print("边信息",g.edata)

通过上述代码,可以加载Cora数据集,并能够看到数据集的基本情况,数据集共包含2708个节点,10556条边。

2.2 图神经网络模块定义

简单的GCN构建:
以下代码构建了一个两层图卷积网络(GCN),每一层通过聚合邻居信息来计算新的节点表示。
如果想要构建多层 GCN,您可以简单地堆叠dgl.nn.GraphConv 模块,这些模块继承自torch.nn.Module.

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

GraphSAGE构建:
GraphSAGE 是图神经网络中比较经典的模型,GraphSAGE 包含采样和聚合 (Sample and aggregate),首先使用节点之间连接信息,对邻居进行采样,然后通过多层聚合函数不断地将相邻节点的信息融合在一起。本文参照DGL中的例子来实现的GraphSAGE,代码如下:

class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))

        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))

        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type))

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

2.3 评价函数

def evaluate(model, graph, features, labels, nid):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[nid]
        labels = labels[nid]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

2.4 图神经网络的训练

全图(使用所有的节点和边的特征)上的训练只需要使用上面定义的模型进行前向传播计算,并通过在训练节点上比较预测和真实标签来计算损失,从而完成后向传播。
节点特征和标签存储在其图上,训练、验证和测试的分割也以布尔掩码的形式存储在图上。

features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze()
test_nid = test_mask.nonzero().squeeze()

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0
    for e in range(100):

        logits = model(g, features)

        pred = logits.argmax(1)

        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = evaluate(model, g, features, labels, val_nid)

        print("Epoch {:05d}  | Loss {:.4f} | Accuracy {:.4f} | ".format(e, loss.item(), acc))

双层GNN训练:

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
print()
acc = evaluate(model, g, features, labels, test_nid)
print("Test Accuracy {:.4f}".format(acc))

GraphSAGE训练:

modeSAGE = GraphSAGE(g.ndata['feat'].shape[1],
                      16,
                      dataset.num_classes,
                      2,
                      F.relu,
                      0.5,
                      "gcn")
train(g, modeSAGE)
acc = evaluate(modeSAGE, g, features, labels, test_nid)
print("Test Accuracy {:.4f}".format(acc))

运行上述代码即可得到分类的效果,一般来说GraphSAGE的效果会略好于双层的GCN,但差距并不太大。

本文主要在DGL包自带的同构图数据集上进行了一个简单的图节点分类的尝试,之后会尝试在其他数据集(异构图/知识图谱)上进行图节点分类的任务。

Original: https://blog.csdn.net/sjx674749057/article/details/123558509
Author: nrcc
Title: 图神经网络实践之图节点分类(一)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/662506/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球