SimplE: SimplE Embedding for Link Prediction in Knowledge Graphs + Code

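This post walks through a knowledge graph completion paper, covering its principles, approach, and code in detail.

1 Introduction

1.1 Knowledge graphs

The real world can be described with triples of the form entity-relation-entity or entity-relation-attribute, usually written (h, r, t), where h is the head entity, r is the relation between the two entities, and t is the tail entity. A knowledge graph stores real-world relations and attributes in this form. Because the real world is so large, many relations are missing from any given graph, which is what motivates knowledge graph completion. SimplE, the paper discussed here, addresses this task.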
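As a small illustration of the (h, r, t) representation, the sketch below stores a few facts as tuples and maps entity and relation names to integer ids, the form in which embedding models usually consume them. The facts and the helper name `build_vocab` are made up for this example.

```python
# Toy facts as (head, relation, tail) triples; the facts are illustrative only.
triples = [
    ("paris", "capitalOf", "france"),
    ("rome", "capitalOf", "italy"),
    ("france", "locatedIn", "europe"),
]

def build_vocab(triples):
    """Assign consecutive integer ids to entities and relations (hypothetical helper)."""
    ent2id, rel2id = {}, {}
    for h, r, t in triples:
        for ent in (h, t):
            # setdefault only inserts when the name has not been seen before.
            ent2id.setdefault(ent, len(ent2id))
        rel2id.setdefault(r, len(rel2id))
    return ent2id, rel2id

ent2id, rel2id = build_vocab(triples)
id_triples = [(ent2id[h], rel2id[r], ent2id[t]) for h, r, t in triples]
print(id_triples)  # [(0, 0, 1), (2, 0, 3), (1, 1, 4)]
```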

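1.2 Knowledge graph completion methods

Basic knowledge graph completion methods include:

  • Bilinear models (DistMult)
  • Neural network models (ConvKB)
  • Translation models (the Trans series)
  • Graph convolutional models (GCN)

1.3 Knowledge Graph Completion (KGC)

Knowledge graph completion is currently framed mainly as a prediction problem: predict the missing part of a triple. It therefore splits into three subtasks:

  • Head entity prediction: (?, r, t)
  • Relation prediction: (h, ?, t)
  • Tail entity prediction: (h, r, ?)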
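The subtasks above can be sketched as ranking problems: fill the missing slot with every candidate and keep the best-scoring triple. The function names, the 4-entity toy graph, and the lookup-table scores below are assumptions for illustration; a real model would supply the scoring function.

```python
def tail_candidates(h, r, num_ent):
    """All candidate triples (h, r, ?) obtained by trying every entity id as the tail."""
    return [(h, r, e) for e in range(num_ent)]

def head_candidates(r, t, num_ent):
    """All candidate triples (?, r, t) obtained by trying every entity id as the head."""
    return [(e, r, t) for e in range(num_ent)]

def predict(candidates, score):
    """Return the candidate with the highest score under the given scoring function."""
    return max(candidates, key=score)

# Hypothetical scores for a 4-entity toy graph; triples not listed score 0.0.
toy_scores = {(0, 0, 1): 0.9, (0, 0, 2): 0.2, (3, 0, 1): 0.4}

def score(triple):
    return toy_scores.get(triple, 0.0)

best_tail = predict(tail_candidates(0, 0, num_ent=4), score)
best_head = predict(head_candidates(0, 1, num_ent=4), score)
print(best_tail, best_head)  # (0, 0, 1) (0, 0, 1)
```

Relation prediction follows the same pattern, with the candidates ranging over relation ids instead of entity ids.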

1.4 Relation types

$\mathcal{E}$ and $\mathcal{R}$ denote the sets of entities and relations. A triple is written $(h, r, t)$, where $h \in \mathcal{E}$ is the head entity, $r \in \mathcal{R}$ is the relation, and $t \in \mathcal{E}$ is the tail entity. $\zeta$ denotes the set of true triples (e.g., (paris, capitalOf, france)), and $\zeta'$ denotes the set of false triples (e.g., (paris, capitalOf, italy)). A knowledge graph $\mathcal{KG}$ is a subset of $\zeta$. A relation $r$ is:

  • Symmetric: $(e_1, r, e_2) \in \zeta \Longleftrightarrow (e_2, r, e_1) \in \zeta$ for all $e_1, e_2 \in \mathcal{E}$
  • Reflexive: $(e, r, e) \in \zeta$ for all entities $e \in \mathcal{E}$
  • Anti-symmetric: $(e_1, r, e_2) \in \zeta \Longleftrightarrow (e_2, r, e_1) \in \zeta'$
  • Transitive: $(e_1, r, e_2) \in \zeta \wedge (e_2, r, e_3) \in \zeta \Rightarrow (e_1, r, e_3) \in \zeta$ for all $e_1, e_2, e_3 \in \mathcal{E}$
  • Inverse (of a relation $r^{-1}$): $(e_i, r, e_j) \in \zeta \Longleftrightarrow (e_j, r^{-1}, e_i) \in \zeta$
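The definitions above can be checked mechanically on a small set of facts. The sketch below is a minimal illustration that treats the given set as the complete set of true triples $\zeta$ (an assumption; real knowledge graphs are incomplete). The facts and function names are made up for this example.

```python
from itertools import product

def is_symmetric(facts, r):
    """(e1, r, e2) in facts implies (e2, r, e1) in facts."""
    return all((t, r, h) in facts for (h, rel, t) in facts if rel == r)

def is_reflexive(facts, r, entities):
    """(e, r, e) in facts for every entity e."""
    return all((e, r, e) in facts for e in entities)

def is_transitive(facts, r):
    """(e1, r, e2) and (e2, r, e3) in facts imply (e1, r, e3) in facts."""
    pairs = {(h, t) for (h, rel, t) in facts if rel == r}
    return all((a, d) in pairs
               for (a, b), (c, d) in product(pairs, repeat=2) if b == c)

# Toy facts, assumed to be the full set of true triples.
facts = {
    ("alice", "marriedTo", "bob"), ("bob", "marriedTo", "alice"),
    ("a", "ancestorOf", "b"), ("b", "ancestorOf", "c"), ("a", "ancestorOf", "c"),
}

print(is_symmetric(facts, "marriedTo"))    # True
print(is_symmetric(facts, "ancestorOf"))   # False
print(is_transitive(facts, "ancestorOf"))  # True
```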

2 Model

2.1 Bilinear models

In a bilinear model, entity and relation embeddings interact multiplicatively. The product used here is $\langle v, w, x \rangle \doteq \sum_{j=1}^{d} v[j] * w[j] * x[j]$: the element-wise (Hadamard) product of the three vectors, summed over all $d$ dimensions. DistMult scores triples in the same way.
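A worked example of the three-way product may help. The sketch below is plain Python with made-up 3-dimensional vectors; the function name `triple_product` is an assumption for this illustration.

```python
def triple_product(v, w, x):
    """<v, w, x> = sum_j v[j] * w[j] * x[j] for equal-length vectors."""
    if not (len(v) == len(w) == len(x)):
        raise ValueError("vectors must have the same length")
    return sum(a * b * c for a, b, c in zip(v, w, x))

v = [1.0, 2.0, 3.0]
w = [0.5, 0.0, 2.0]
x = [4.0, 1.0, -1.0]

# 1*0.5*4 + 2*0*1 + 3*2*(-1) = 2 + 0 - 6
print(triple_product(v, w, x))  # -4.0
```

The product does not change when v and x are swapped, so a model that scores (h, r, t) with a single embedding per entity assigns (h, r, t) and (t, r, h) the same score.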

2.2 Core formula

Each entity $e$ has two embedding vectors $h_{e}, t_{e} \in \mathbb{R}^{d}$, and each relation $r$ has two embedding vectors $v_{r}, v_{r^{-1}} \in \mathbb{R}^{d}$.
The core of the model is the score function for a triple $(e_i, r, e_j)$: $\frac{1}{2}\left(\langle h_{e_i}, v_{r}, t_{e_j} \rangle + \langle h_{e_j}, v_{r^{-1}}, t_{e_i} \rangle\right)$.
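To make the score function concrete, here is a minimal sketch in plain Python with made-up 2-dimensional embeddings. The function and variable names are assumptions for illustration and are not taken from the code later in this post.

```python
def triple_product(v, w, x):
    """<v, w, x> = sum_j v[j] * w[j] * x[j]."""
    return sum(a * b * c for a, b, c in zip(v, w, x))

def simple_score(h_i, t_i, h_j, t_j, v_r, v_r_inv):
    """0.5 * (<h_ei, v_r, t_ej> + <h_ej, v_r_inv, t_ei>) for the triple (e_i, r, e_j)."""
    return 0.5 * (triple_product(h_i, v_r, t_j) + triple_product(h_j, v_r_inv, t_i))

# Made-up embeddings: head-role and tail-role vectors for two entities,
# plus the forward and inverse vectors of one relation.
h_1, t_1 = [1.0, 0.0], [0.0, 1.0]
h_2, t_2 = [0.0, 2.0], [3.0, 0.0]
v_r, v_r_inv = [1.0, 1.0], [1.0, -1.0]

forward = simple_score(h_1, t_1, h_2, t_2, v_r, v_r_inv)   # score of (e1, r, e2)
backward = simple_score(h_2, t_2, h_1, t_1, v_r, v_r_inv)  # score of (e2, r, e1)
print(forward, backward)  # 0.5 2.5
```

With these values the two directions receive different scores, which shows that the score is not forced to be symmetric in the two entities.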

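2.3 Negative sampling

The model draws negative samples by randomly corrupting either the head or the tail entity of a triple: it picks a random entity id from [0, num_ent-1] that differs from the original one, where num_ent is the total number of entities. Correct triples are labeled 1; corrupted triples, i.e. the negative samples, are labeled -1.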
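The sampling procedure described above can be sketched in a few lines of plain Python. This is a simplified stand-alone version under assumed names (`corrupt`, `rand_ent_except` as a free function, a seeded `random.Random`), not the batch-based implementation shown later.

```python
import random

def rand_ent_except(ent, num_ent, rng):
    """Uniformly sample an entity id in [0, num_ent - 1] that differs from `ent`."""
    candidate = rng.randint(0, num_ent - 1)  # randint bounds are inclusive
    while candidate == ent:
        candidate = rng.randint(0, num_ent - 1)
    return candidate

def corrupt(triple, num_ent, rng):
    """Return a negative sample (h, r, t, -1) with either the head or the tail replaced."""
    h, r, t = triple
    if rng.random() < 0.5:
        h = rand_ent_except(h, num_ent, rng)
    else:
        t = rand_ent_except(t, num_ent, rng)
    return (h, r, t, -1)

rng = random.Random(0)  # fixed seed so the sketch is reproducible
positive = (3, 1, 7)
negatives = [corrupt(positive, num_ent=10, rng=rng) for _ in range(5)]
print(negatives)
```

Note that this scheme only guarantees that the replaced entity differs from the original one; it does not check whether the corrupted triple happens to be a true fact.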

2.4 Loss function

$\min_{\theta} \sum_{((h, r, t), l) \in \mathbf{LB}} \operatorname{softplus}(-l \cdot \phi(h, r, t)) + \lambda \|\theta\|_{2}^{2}$, where $\theta$ denotes the model parameters (the embeddings), $\mathbf{LB}$ is a labelled batch of triples, $l \in \{-1, +1\}$ is the label (+1 for a correct triple, -1 for a corrupted one), $\phi(h, r, t)$ is the score of the triple $(h, r, t)$, $\lambda$ is the regularization weight, and $\operatorname{softplus}(x) = \log(1 + \exp(x))$.
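The objective can be evaluated by hand on a tiny batch. The sketch below follows the formula as written, in plain Python; the scores, labels, parameter values, and the name `batch_loss` are made up for illustration. The training code later in this post scales the regularization term differently, so the numbers are not meant to match it.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def batch_loss(scores, labels, params, reg_lambda):
    """Sum of softplus(-l * phi) over the batch plus lambda * ||theta||_2^2."""
    data_term = sum(softplus(-l * s) for s, l in zip(scores, labels))
    l2_term = reg_lambda * sum(p * p for p in params)
    return data_term + l2_term

scores = [2.0, -1.0]   # phi(h, r, t) for one positive and one negative triple
labels = [1, -1]       # +1 = correct triple, -1 = corrupted triple
params = [0.5, -0.5]   # stand-in for the flattened embedding parameters
loss = batch_loss(scores, labels, params, reg_lambda=0.1)
print(loss)
```

A positive triple with a high score and a negative triple with a low score both contribute little to the loss, since softplus(x) approaches 0 as x becomes very negative.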

2.5 Evaluation

Knowledge graph completion is evaluated with metrics such as Hits@n, MRR, and MR. The descriptions below follow the blog post "KGE性能指标:MRR,MR,HITS@1,HITS@3,HITS@10" (KGE performance metrics).

  • MRR: short for Mean Reciprocal Rank. It is computed as:
    $\mathrm{MRR} = \frac{1}{|S|} \sum_{i=1}^{|S|} \frac{1}{\operatorname{rank}_{i}} = \frac{1}{|S|}\left(\frac{1}{\operatorname{rank}_{1}} + \frac{1}{\operatorname{rank}_{2}} + \ldots + \frac{1}{\operatorname{rank}_{|S|}}\right)$, where $S$ is the set of triples, $|S|$ is the number of triples in it, and $\operatorname{rank}_{i}$ is the link prediction rank of the $i$-th triple. Higher is better.
  • MR: short for Mean Rank. It is computed as:
    $\mathrm{MR} = \frac{1}{|S|} \sum_{i=1}^{|S|} \operatorname{rank}_{i} = \frac{1}{|S|}\left(\operatorname{rank}_{1} + \operatorname{rank}_{2} + \ldots + \operatorname{rank}_{|S|}\right)$
    The symbols are the same as in the MRR formula. Lower is better.
  • HITS@n: the fraction of triples whose link prediction rank is at most n. It is computed as:
    $\operatorname{HITS}@n = \frac{1}{|S|} \sum_{i=1}^{|S|} \mathbb{I}\left(\operatorname{rank}_{i} \leqslant n\right)$, where the symbols are the same as in the MRR formula and $\mathbb{I}(\cdot)$ is the indicator function (1 if the condition holds, 0 otherwise). Typical values of n are 1, 3, and 10. Higher is better.
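The three metrics above can be computed directly from a list of ranks. The sketch below is a minimal plain-Python illustration; the function name `link_prediction_metrics` and the example ranks are made up.

```python
def link_prediction_metrics(ranks, ns=(1, 3, 10)):
    """Compute MR, MRR, and HITS@n from a list of 1-based ranks."""
    size = len(ranks)
    metrics = {
        "MR": sum(ranks) / size,
        "MRR": sum(1.0 / r for r in ranks) / size,
    }
    for n in ns:
        # Fraction of test triples ranked at position n or better.
        metrics[f"HITS@{n}"] = sum(1 for r in ranks if r <= n) / size
    return metrics

ranks = [1, 2, 4, 10]  # hypothetical ranks of four test triples
metrics = link_prediction_metrics(ranks)
print(metrics)
# MR = (1 + 2 + 4 + 10) / 4 = 4.25
# MRR = (1 + 0.5 + 0.25 + 0.1) / 4 = 0.4625
# HITS@1 = 0.25, HITS@3 = 0.5, HITS@10 = 1.0
```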

The code consists of six modules: data processing, model, training, testing, evaluation, and the main module.

3.1 Data processing module: dataset.py

import numpy as np
import random
import torch
import math

class Dataset:
    def __init__(self, ds_name):
        self.name = ds_name
        self.dir = "datasets/" + ds_name + "/"
        self.ent2id = {}
        self.rel2id = {}
        self.data = {spl: self.read(self.dir + spl + ".txt") for spl in ["train", "valid", "test"]}
        self.batch_index = 0

    def read(self, file_path):
        with open(file_path, "r") as f:
            lines = f.readlines()

        triples = np.zeros((len(lines), 3))

        for i, line in enumerate(lines):
            triples[i] = np.array(self.triple2ids(line.strip().split("\t")))
        return triples

    def num_ent(self):
        return len(self.ent2id)

    def num_rel(self):
        return len(self.rel2id)

    def triple2ids(self, triple):
        return [self.get_ent_id(triple[0]), self.get_rel_id(triple[1]), self.get_ent_id(triple[2])]

    def get_ent_id(self, ent):
        # Assign the next free integer id the first time an entity is seen.
        if ent not in self.ent2id:
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]

    def get_rel_id(self, rel):
        # Same scheme for relations.
        if rel not in self.rel2id:
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]

    def rand_ent_except(self, ent):
        # Sample a random entity id (randint bounds are inclusive) that differs from `ent`.
        rand_ent = random.randint(0, self.num_ent() - 1)
        while rand_ent == ent:
            rand_ent = random.randint(0, self.num_ent() - 1)
        return rand_ent

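    def next_pos_batch(self, batch_size):
        # Take the next slice of training triples. The final (possibly shorter)
        # batch resets batch_index to 0, which marks the end of the epoch.
        if self.batch_index + batch_size < len(self.data["train"]):
            batch = self.data["train"][self.batch_index: self.batch_index+batch_size]
            self.batch_index += batch_size
        else:
            batch = self.data["train"][self.batch_index:]
            self.batch_index = 0
        # Append a label column of +1 for the positive triples.
        return np.append(batch, np.ones((len(batch), 1)), axis=1).astype("int")

    def generate_neg(self, pos_batch, neg_ratio):
        # Repeat every positive triple neg_ratio times, then corrupt each copy.
        neg_batch = np.repeat(np.copy(pos_batch), neg_ratio, axis=0)
        for i in range(len(neg_batch)):
            # Replace the head or the tail with equal probability.
            if random.random() < 0.5:
                neg_batch[i][0] = self.rand_ent_except(neg_batch[i][0])
            else:
                neg_batch[i][2] = self.rand_ent_except(neg_batch[i][2])

        # Negative samples get label -1.
        neg_batch[:,-1] = -1
        return neg_batch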

    def next_batch(self, batch_size, neg_ratio, device):
        pos_batch = self.next_pos_batch(batch_size)
        neg_batch = self.generate_neg(pos_batch, neg_ratio)
        batch = np.append(pos_batch, neg_batch, axis=0)
        np.random.shuffle(batch)

        heads  = torch.tensor(batch[:,0]).long().to(device)
        rels   = torch.tensor(batch[:,1]).long().to(device)
        tails  = torch.tensor(batch[:,2]).long().to(device)
        labels = torch.tensor(batch[:,3]).float().to(device)
        return heads, rels, tails, labels

    def was_last_batch(self):
        return (self.batch_index == 0)

    def num_batch(self, batch_size):
        return int(math.ceil(float(len(self.data["train"])) / batch_size))

3.2 Model module: model.py

import torch
import torch.nn as nn
import math

class SimplE(nn.Module):
    def __init__(self, num_ent, num_rel, emb_dim, device):
        super(SimplE, self).__init__()
        self.num_ent = num_ent
        self.num_rel = num_rel
        self.emb_dim = emb_dim
        self.device = device

        self.ent_h_embs   = nn.Embedding(num_ent, emb_dim).to(device)

        self.ent_t_embs   = nn.Embedding(num_ent, emb_dim).to(device)

        self.rel_embs     = nn.Embedding(num_rel, emb_dim).to(device)

        self.rel_inv_embs = nn.Embedding(num_rel, emb_dim).to(device)

        # Uniform initialization in [-6/sqrt(emb_dim), 6/sqrt(emb_dim)].
        sqrt_size = 6.0 / math.sqrt(self.emb_dim)

        nn.init.uniform_(self.ent_h_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.ent_t_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.rel_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.rel_inv_embs.weight.data, -sqrt_size, sqrt_size)

    def l2_loss(self):
        # Half the sum of the squared L2 norms of all four embedding tables.
        return (
            torch.norm(self.ent_h_embs.weight, p=2) ** 2
            + torch.norm(self.ent_t_embs.weight, p=2) ** 2
            + torch.norm(self.rel_embs.weight, p=2) ** 2
            + torch.norm(self.rel_inv_embs.weight, p=2) ** 2
        ) / 2

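    def forward(self, heads, rels, tails):
        # Head-role and tail-role embeddings for both entities of each triple.
        hh_embs = self.ent_h_embs(heads)      # h_{e_i}
        ht_embs = self.ent_h_embs(tails)      # h_{e_j}
        th_embs = self.ent_t_embs(heads)      # t_{e_i}
        tt_embs = self.ent_t_embs(tails)      # t_{e_j}
        r_embs = self.rel_embs(rels)          # v_r
        r_inv_embs = self.rel_inv_embs(rels)  # v_{r^{-1}}

        # <h_{e_i}, v_r, t_{e_j}> and <h_{e_j}, v_{r^{-1}}, t_{e_i}>
        scores1 = torch.sum(hh_embs * r_embs * tt_embs, dim=1)
        scores2 = torch.sum(ht_embs * r_inv_embs * th_embs, dim=1)

        # Average the two scores and clamp the result to [-20, 20].
        return torch.clamp((scores1 + scores2) / 2, -20, 20)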

3.3 Training module: trainer.py

from dataset import Dataset
from model import SimplE  # the SimplE class is defined in model.py (section 3.2)
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class Trainer:
    def __init__(self, dataset, args):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.model = SimplE(dataset.num_ent(), dataset.num_rel(), args.emb_dim, self.device)

        self.dataset = dataset
        self.args = args

    def train(self):
        self.model.train()

        optimizer = torch.optim.Adagrad(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay= 0,
            initial_accumulator_value= 0.1
        )

        for epoch in range(1, self.args.ne + 1):
            last_batch = False
            total_loss = 0.0

            while not last_batch:

                h, r, t, l = self.dataset.next_batch(self.args.batch_size, neg_ratio=self.args.neg_ratio, device = self.device)
                last_batch = self.dataset.was_last_batch()
                optimizer.zero_grad()

                scores = self.model(h, r, t)

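                # Softplus loss summed over the batch, plus the L2 term divided by the
                # number of batches so that it is applied once per epoch in total.
                reg = self.args.reg_lambda * self.model.l2_loss() / self.dataset.num_batch(self.args.batch_size)
                loss = torch.sum(F.softplus(-l * scores)) + reg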
                loss.backward()
                optimizer.step()
                total_loss += loss.cpu().item()

            print("Loss in iteration " + str(epoch) + ": " + str(total_loss) + "(" + self.dataset.name + ")")

            if epoch % self.args.save_each == 0:
                self.save_model(epoch)

    def save_model(self, chkpnt):
        print("Saving the model")
        directory = "models/" + self.dataset.name + "/"
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.model, directory + str(chkpnt) + ".chkpnt")

3.4 Testing module: tester.py

import torch
from dataset import Dataset
import numpy as np
from measure import Measure
from os import listdir
from os.path import isfile, join

class Tester:
    def __init__(self, dataset, model_path, valid_or_test):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
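        # The checkpoint is the whole pickled module written by Trainer.save_model.
        # Recent PyTorch releases default torch.load to weights_only=True; in that
        # case weights_only=False may have to be passed explicitly here.
        self.model = torch.load(model_path, map_location = self.device)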
        self.model.eval()
        self.dataset = dataset
        self.valid_or_test = valid_or_test
        self.measure = Measure()
        self.all_facts_as_set_of_tuples = set(self.allFactsAsTuples())

    def get_rank(self, sim_scores):
        return (sim_scores >= sim_scores[0]).sum()

    def create_queries(self, fact, head_or_tail):
        head, rel, tail = fact
        if head_or_tail == "head":
            return [(i, rel, tail) for i in range(self.dataset.num_ent())]
        elif head_or_tail == "tail":
            return [(head, rel, i) for i in range(self.dataset.num_ent())]

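    def add_fact_and_shred(self, fact, queries, raw_or_fil):
        # The true fact goes to index 0 so that get_rank can compare against it.
        if raw_or_fil == "raw":
            # Drop the fact itself from the candidates; otherwise it is counted twice
            # by get_rank and the raw rank can never be 1.
            result = [tuple(fact)] + [q for q in queries if q != tuple(fact)]
        elif raw_or_fil == "fil":
            # Filtered setting: remove every candidate that is a known true fact.
            result = [tuple(fact)] + list(set(queries) - self.all_facts_as_set_of_tuples)

        return self.shred_facts(result)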

    def test(self):
        settings = ["raw", "fil"] if self.valid_or_test == "test" else ["fil"]

        for i, fact in enumerate(self.dataset.data[self.valid_or_test]):
            for head_or_tail in ["head", "tail"]:
                queries = self.create_queries(fact, head_or_tail)
                for raw_or_fil in settings:
                    h, r, t = self.add_fact_and_shred(fact, queries, raw_or_fil)
                    sim_scores = self.model(h, r, t).cpu().data.numpy()
                    rank = self.get_rank(sim_scores)
                    self.measure.update(rank, raw_or_fil)

        self.measure.normalize(len(self.dataset.data[self.valid_or_test]))
        self.measure.print_()
        return self.measure.mrr["fil"]

    def shred_facts(self, triples):
        heads  = [triples[i][0] for i in range(len(triples))]
        rels   = [triples[i][1] for i in range(len(triples))]
        tails  = [triples[i][2] for i in range(len(triples))]
        return torch.LongTensor(heads).to(self.device), torch.LongTensor(rels).to(self.device), torch.LongTensor(tails).to(self.device)

    def allFactsAsTuples(self):
        tuples = []
        for spl in self.dataset.data:
            for fact in self.dataset.data[spl]:
                tuples.append(tuple(fact))

        return tuples
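The difference between the raw and the filtered setting used by the tester can be seen on a toy example. The sketch below is a simplified stand-alone version in plain Python, not the Tester class itself; `rank_of_fact`, the scores, and the set of known facts are assumptions made for illustration.

```python
def rank_of_fact(fact, candidates, score, known_facts=None):
    """Rank of `fact` among `candidates` (1 = best).

    With `known_facts` given (filtered setting), candidates that are themselves
    known true triples are removed before ranking.
    """
    others = [c for c in candidates if c != fact]
    if known_facts is not None:
        others = [c for c in others if c not in known_facts]
    true_score = score(fact)
    # Ties count against the true fact, mirroring a >= comparison.
    return 1 + sum(1 for c in others if score(c) >= true_score)

# Hypothetical model scores for tail prediction on (0, 0, ?), with 4 entities.
toy_scores = {(0, 0, 0): 0.0, (0, 0, 1): 0.9, (0, 0, 2): 0.95, (0, 0, 3): 0.1}
score = toy_scores.get

fact = (0, 0, 1)
candidates = [(0, 0, e) for e in range(4)]
known_facts = {(0, 0, 1), (0, 0, 2)}  # (0, 0, 2) is also a true triple

raw_rank = rank_of_fact(fact, candidates, score)
filtered_rank = rank_of_fact(fact, candidates, score, known_facts)
print(raw_rank, filtered_rank)  # 2 1
```

In the raw setting the other true triple (0, 0, 2) outranks the test fact; in the filtered setting it is removed, so the test fact is not penalized for a correct competing answer.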

3.5 Evaluation module: measure.py

class Measure:
    def __init__(self):
        self.hit1  = {"raw": 0.0, "fil": 0.0}
        self.hit3  = {"raw": 0.0, "fil": 0.0}
        self.hit10 = {"raw": 0.0, "fil": 0.0}
        self.mrr   = {"raw": 0.0, "fil": 0.0}
        self.mr    = {"raw": 0.0, "fil": 0.0}

    def update(self, rank, raw_or_fil):
        # Count a hit when the true triple is ranked within the top 1, 3, or 10.
        if rank == 1:
            self.hit1[raw_or_fil] += 1.0
        if rank <= 3:
            self.hit3[raw_or_fil] += 1.0
        if rank <= 10:
            self.hit10[raw_or_fil] += 1.0

        self.mr[raw_or_fil]  += rank
        self.mrr[raw_or_fil] += (1.0 / rank)

    def normalize(self, num_facts):
        for raw_or_fil in ["raw", "fil"]:
            self.hit1[raw_or_fil]  /= (2 * num_facts)
            self.hit3[raw_or_fil]  /= (2 * num_facts)
            self.hit10[raw_or_fil] /= (2 * num_facts)
            self.mr[raw_or_fil]    /= (2 * num_facts)
            self.mrr[raw_or_fil]   /= (2 * num_facts)

    def print_(self):
        for raw_or_fil in ["raw", "fil"]:
            print(raw_or_fil.title() + " setting:")
            print("\tHit@1 =",  self.hit1[raw_or_fil])
            print("\tHit@3 =",  self.hit3[raw_or_fil])
            print("\tHit@10 =", self.hit10[raw_or_fil])
            print("\tMR =",     self.mr[raw_or_fil])
            print("\tMRR =",    self.mrr[raw_or_fil])
            print("")

3.6 Main module: main.py

from trainer import Trainer
from tester import Tester
from dataset import Dataset
import argparse
import time
def get_parameter():
    parser = argparse.ArgumentParser()
    parser.add_argument('-ne', default=1000, type=int, help="number of epochs")
    parser.add_argument('-lr', default=0.1, type=float, help="learning rate")
    parser.add_argument('-reg_lambda', default=0.03, type=float, help="l2 regularization parameter")
    parser.add_argument('-dataset', default="WN18", type=str, help="wordnet dataset")
    parser.add_argument('-emb_dim', default=200, type=int, help="embedding dimension")
    parser.add_argument('-neg_ratio', default=1, type=int, help="number of negative examples per positive example")
    parser.add_argument('-batch_size', default=1415, type=int, help="batch size")
    parser.add_argument('-save_each', default=50, type=int, help="save a checkpoint every k epochs")
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_parameter()
    dataset = Dataset(args.dataset)

    print("~~~~ Training ~~~~")
    trainer = Trainer(dataset, args)
    trainer.train()

    print("~~~~ Select best epoch on validation set ~~~~")
    epochs2test = [str(int(args.save_each * (i + 1))) for i in range(args.ne // args.save_each)]
    dataset = Dataset(args.dataset)

    best_mrr = -1.0
    best_epoch = "0"
    for epoch in epochs2test:
        start = time.time()
        print(epoch)
        model_path = "models/" + args.dataset + "/" + epoch + ".chkpnt"
        tester = Tester(dataset, model_path, "valid")
        mrr = tester.test()
        if mrr > best_mrr:
            best_mrr = mrr
            best_epoch = epoch
        print(time.time() - start)

    print("Best epoch: " + best_epoch)

    print("~~~~ Testing on the best epoch ~~~~")
    best_model_path = "models/" + args.dataset + "/" + best_epoch + ".chkpnt"
    tester = Tester(dataset, best_model_path, "test")
    tester.test()

Original: https://blog.csdn.net/REfusing/article/details/123314548
Author: Re:fused
Title: SimplE: SimplE Embedding for Link Prediction in Knowledge Graphs + Code
