使用DGL完成节点分类任务

更多图神经网络和深度学习内容请关注:

使用DGL完成节点分类任务

; 节点分类任务概述

节点分类(node classification)任务是在图数据处理中最流行任务之一,一个模型需要预测每个节点属于哪个类别。

在图神经网络出现之前,用于结点分类任务的方法可归为两大类:

  • 仅使用连通性(如DeepWalk或node2vec)
  • 简单地结合连通性和节点自身的特征

相比之下, GNNs是一个通过结合局部邻域(广义上的邻居,包含结点自身)的连通性及其特征来获得节点表征的方法。

Kipf等人将节点分类问题描述为 一个半监督的节点分类任务。图神经网络只需要一小部分已标记的节点,即可准确地预测其他节点的类别。

本文将展示如何在 Cora数据集中(即以论文为节点,以论文引用为边的引文网络)使用少量标签构建半监督节点分类任务的GNN模型。其具体任务为预测给定论文的类别。每个论文节点均包含一个单词计数向量(word count vector)作为它的特征,这些特征进行了归一化(使其总和为1),参考论文第5.2节。

使用DGL完成节点分类

导入相对应的包

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
Using backend: pytorch

加载数据集

dataset = dgl.data.CoraGraphDataset()
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.

DGL数据集对象可以包含一个或多个图。一般情况下, 整图分类任务数据集包含多个图, 边预测节点分类数据集只包含一个图,如节点分类任务中的 Cora数据集只包含一个图。

g = dataset[0]

DGL图将节点特征和边特征分别存储在两个类似字典的属性 ndataedata中,在Cora数据集中,图包含以下节点特征(其他数据集也类似):

  • train_mask:布尔张量,表示节点是否在训练集中。
  • val_mask:布尔张量,表示节点是否在验证集中。
  • test_mask:布尔张量,表示节点是否在测试集中。
  • label:节点类别。
  • feat:节点特征。
print("Node feature")
print(g.ndata)

print("Edge feature")
print(g.edata)
Node feature
{'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}
Edge feature
{}

定义图卷积网络(GCN)

本文将构建一个两层图卷积网络(GCN)。其中每一层都通过聚合邻居信息来计算新的节点表示,若需要构建多层GCN网络,我们可简单地堆叠 dgl.nn.GraphConv模块,这些都模块继承于 torch.nn.Module。(假设DGL使用的后端框架为PyTorch)

from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_class):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_class)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

in_feats = g.ndata["feat"].shape[1]
h_feats = 16
num_class = (torch.max(g.ndata["label"]) + 1).item()

model = GCN(in_feats, h_feats, num_class)

DGL提供了许多流行的邻居聚合模块的实现,我们可以使用一行代码即可轻松调用它们。

训练GCN模型

GCN模型训练过程类似其他PyTorch神经网络训练过程。

def train(g, model, learning_rate=0.01, num_epoch=100):
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    val_mask = g.ndata["val_mask"]

    for epoch in range(num_epoch):
        result = model(g, features)
        pred = result.argmax(1)

        loss = F.cross_entropy(result[train_mask], labels[train_mask])

        train_acc = (pred[train_mask]==labels[train_mask]).float().mean()
        val_acc  = (pred[val_mask]==labels[val_mask]).float().mean()
        test_acc  = (pred[test_mask]==labels[test_mask]).float().mean()

        if best_val_acc < val_acc:
            best_val_acc, best_test_acc = val_acc, test_acc
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            print('In epoch {}, loss: {}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                epoch, loss, val_acc, best_val_acc, test_acc, best_test_acc))

if __name__ == "__main__":
    train(g, model, num_epoch=200, learning_rate=0.002)

In epoch 0, loss: 1.0601081612549024e-06, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 5, loss: 9.979492006095825e-07, val acc: 0.760 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 10, loss: 9.494142432231456e-07, val acc: 0.762 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 15, loss: 9.017308570946625e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 20, loss: 8.557504429518303e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 25, loss: 8.157304023370671e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 30, loss: 7.71452903336467e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 35, loss: 7.322842634494009e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 40, loss: 6.948185955479858e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 45, loss: 6.624618436035234e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 50, loss: 6.292536340879451e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 55, loss: 6.028573125149705e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 60, loss: 5.807185630146705e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 65, loss: 5.534708407139988e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 70, loss: 5.381440359997214e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 75, loss: 5.117477144267468e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 80, loss: 4.913119937555166e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 85, loss: 4.759851037761109e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 90, loss: 4.5640075541086844e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 95, loss: 4.368164354673354e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 100, loss: 4.2319251747358066e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 105, loss: 4.07865627494175e-07, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.764)
In epoch 110, loss: 3.993507107225014e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 115, loss: 3.840238207430957e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 120, loss: 3.755089039714221e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 125, loss: 3.6358795796331833e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 130, loss: 3.5081561122751737e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 135, loss: 3.414492084630183e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 140, loss: 3.363402356626466e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 145, loss: 3.218648316760664e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 150, loss: 3.159043444611598e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 155, loss: 3.0568645570383524e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 160, loss: 2.988745109178126e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 165, loss: 2.895080797316041e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 170, loss: 2.792901625525701e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 175, loss: 2.733296753376635e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 180, loss: 2.673692165444663e-07, val acc: 0.764 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 185, loss: 2.614087861729786e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 190, loss: 2.53745326972421e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 195, loss: 2.486363541720493e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)

完整代码为

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.nn import GraphConv

class GCN(nn.Module):
"""
    GCN network
"""
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

def train(g, model, num_epoch = 100, learning_rate =  0.001):
"""
    train function
"""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val_accurate = 0
    best_test_accurate = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    val_mask = g.ndata["val_mask"]

    for e in range(num_epoch):

        result = model(g, features)

        pred = result.argmax(dim=1)

        loss = F.cross_entropy(result[train_mask], labels[train_mask])

        train_accurate = (pred[train_mask]==labels[train_mask]).float().mean()
        test_accurate = (pred[test_mask]==labels[test_mask]).float().mean()
        val_accurate = (pred[val_mask]==labels[val_mask]).float().mean()
        if best_val_accurate < val_accurate:
            best_val_accurate, best_test_accurate = val_accurate, test_accurate

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_accurate, best_val_accurate, test_accurate, best_test_accurate))

def main():
    dataset = CoraGraphDataset()
    g = dataset[0]

    in_feats = g.ndata["feat"].shape[1]
    h_feats = 16
    num_classes = dataset.num_classes

    model = GCN(in_feats, h_feats, num_classes)
    train(g, model)

if __name__ == "__main__":
    main()
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.

In epoch 0, loss: 1.946, val acc: 0.104 (best 0.104), test acc: 0.114 (best 0.114)
In epoch 5, loss: 1.942, val acc: 0.276 (best 0.276), test acc: 0.314 (best 0.314)
In epoch 10, loss: 1.936, val acc: 0.452 (best 0.452), test acc: 0.452 (best 0.452)
In epoch 15, loss: 1.929, val acc: 0.546 (best 0.546), test acc: 0.549 (best 0.549)
In epoch 20, loss: 1.921, val acc: 0.612 (best 0.612), test acc: 0.631 (best 0.631)
In epoch 25, loss: 1.913, val acc: 0.640 (best 0.640), test acc: 0.647 (best 0.647)
In epoch 30, loss: 1.904, val acc: 0.654 (best 0.654), test acc: 0.670 (best 0.670)
In epoch 35, loss: 1.895, val acc: 0.684 (best 0.684), test acc: 0.692 (best 0.692)
In epoch 40, loss: 1.886, val acc: 0.690 (best 0.692), test acc: 0.695 (best 0.693)
In epoch 45, loss: 1.876, val acc: 0.700 (best 0.700), test acc: 0.694 (best 0.694)
In epoch 50, loss: 1.866, val acc: 0.706 (best 0.708), test acc: 0.701 (best 0.699)
In epoch 55, loss: 1.855, val acc: 0.710 (best 0.710), test acc: 0.698 (best 0.698)
In epoch 60, loss: 1.844, val acc: 0.708 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 65, loss: 1.833, val acc: 0.704 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 70, loss: 1.821, val acc: 0.702 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 75, loss: 1.809, val acc: 0.704 (best 0.712), test acc: 0.705 (best 0.699)
In epoch 80, loss: 1.796, val acc: 0.706 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 85, loss: 1.783, val acc: 0.702 (best 0.712), test acc: 0.706 (best 0.699)
In epoch 90, loss: 1.769, val acc: 0.694 (best 0.712), test acc: 0.703 (best 0.699)
In epoch 95, loss: 1.755, val acc: 0.692 (best 0.712), test acc: 0.706 (best 0.699)

使用GPU进行训练

在GPU上进行训练需要使用 to方法将模型和图都放到GPU上,PyTorch训练其他神经网络模型类似。

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

参考

翻译整理自Node Classification with DGL

Original: https://blog.csdn.net/huanghelouzi/article/details/116430387
Author: huanghelouzi
Title: 使用DGL完成节点分类任务

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/688971/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球