Table of Contents
- 1 Introduction
  - 1.1 Knowledge Graphs
  - 1.2 Knowledge Graph Completion Methods
  - 1.3 Knowledge Graph Completion (KGC)
  - 1.4 Relation Properties
- 2 Model
  - 2.1 Bilinear Models
  - 2.2 Core Formula
  - 2.3 Negative Sampling
  - 2.4 Loss Function
  - 2.5 Evaluation
- 3 [Code](https://github.com/baharefatemi/SimplE)
  - 3.1 Data module dataset.py
  - 3.2 Model module SimplE.py
  - 3.3 Training module trainer.py
  - 3.4 Testing module tester.py
  - 3.5 Evaluation module measure.py
  - 3.6 Main module main.py
This post walks through the SimplE paper on knowledge graph completion, covering its idea, approach, and code in detail.
1 Introduction
1.1 Knowledge Graphs
The real world can be described by triples of the form entity-relation-entity or entity-relation-attribute, usually written (h, r, t), where h is the head entity, r the relation between the two entities, and t the tail entity. Since the real world is enormous, knowledge graphs stored this way inevitably miss many facts, which motivates the task of knowledge graph completion; SimplE is a paper that addresses exactly this task.
1.2 Knowledge Graph Completion Methods
Basic approaches to knowledge graph completion include:
- Bilinear models (DistMult)
- Neural network models (ConvKB)
- Translational models (the TransE family)
- Graph convolutional models (GCN)
1.3 Knowledge Graph Completion (KGC)
KGC is usually cast as a prediction problem: predict the missing element of a triple. This splits into three subtasks:
- Head entity prediction: (?, r, t)
- Relation prediction: (h, ?, t)
- Tail entity prediction: (h, r, ?)
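In practice, entity prediction is solved by ranking: fix the known elements of the query and score every candidate entity. A minimal sketch, where the `score` function is a hypothetical stand-in for a trained model:

```python
def score(h, r, t):
    # toy scorer standing in for a trained model: it strongly
    # prefers the one fact it "knows"
    return 1.0 if (h, r, t) == ("paris", "capitalOf", "france") else 0.0

def predict_tail(h, r, candidates):
    # score every candidate tail for the query (h, r, ?) and sort best-first
    return sorted(candidates, key=lambda t: score(h, r, t), reverse=True)

ranking = predict_tail("paris", "capitalOf", ["italy", "france", "spain"])
print(ranking[0])  # -> "france"
```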
1.4 Relation Properties
Let $\mathcal{E}$ and $\mathcal{R}$ denote the sets of entities and relations. A triple is written $(h, r, t)$, where $h \in \mathcal{E}$ is the head entity, $r \in \mathcal{R}$ is the relation, and $t \in \mathcal{E}$ is the tail entity. $\zeta$ denotes the set of true triples (e.g., (paris, capitalOf, france)) and $\zeta'$ the set of false triples (e.g., (paris, capitalOf, italy)). A knowledge graph $\mathcal{KG}$ is a subset of $\zeta$. A relation $r$ is called:
- symmetric: $(e_1, r, e_2) \in \zeta \Longleftrightarrow (e_2, r, e_1) \in \zeta$, for $e_1, e_2 \in \mathcal{E}$
- reflexive: $(e, r, e) \in \zeta$ for all entities $e \in \mathcal{E}$
- anti-symmetric: $(e_1, r, e_2) \in \zeta \Longleftrightarrow (e_2, r, e_1) \in \zeta'$
- transitive: $(e_1, r, e_2) \in \zeta \wedge (e_2, r, e_3) \in \zeta \Rightarrow (e_1, r, e_3) \in \zeta$, for $e_1, e_2, e_3 \in \mathcal{E}$
- the inverse of a relation: $(e_i, r, e_j) \in \zeta \Longleftrightarrow (e_j, r^{-1}, e_i) \in \zeta$
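These properties can be checked mechanically over a set of true triples. A toy sketch (the triple set below is made up for illustration):

```python
# A made-up set of true triples for illustration.
true_triples = {
    ("a", "marriedTo", "b"), ("b", "marriedTo", "a"),  # symmetric pair
    ("paris", "capitalOf", "france"),                  # not symmetric
}

def is_symmetric(rel, triples):
    # symmetric: (e1, r, e2) in the set implies (e2, r, e1) is too
    return all((t, r, h) in triples for (h, r, t) in triples if r == rel)

print(is_symmetric("marriedTo", true_triples))  # True
print(is_symmetric("capitalOf", true_triples))  # False
```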
2 Model
2.1 Bilinear Models
In a bilinear model, entities and relations interact multiplicatively via the triple product $\langle v, w, x\rangle \doteq \sum_{j=1}^{d} v[j] \cdot w[j] \cdot x[j]$: an element-wise (Hadamard) product of the three vectors, followed by a sum over dimensions. DistMult uses the same construction.
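In plain Python, the triple product is just an element-wise multiply and sum:

```python
def triple_product(v, w, x):
    # <v, w, x> = sum_j v[j] * w[j] * x[j]
    return sum(vj * wj * xj for vj, wj, xj in zip(v, w, x))

print(triple_product([1, 2], [3, 4], [5, 6]))  # 1*3*5 + 2*4*6 = 63
```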
2.2 Core Formula
SimplE learns two vectors $h_e, t_e \in \mathbb{R}^d$ for each entity $e$ (its embeddings when appearing as head and as tail, respectively) and two vectors $v_r, v_{r^{-1}} \in \mathbb{R}^d$ for each relation $r$ (the relation and its inverse).
The core of the model is the score function $\frac{1}{2}\left(\left\langle h_{e_i}, v_{r}, t_{e_j}\right\rangle+\left\langle h_{e_j}, v_{r^{-1}}, t_{e_i}\right\rangle\right)$ for a triple $(e_i, r, e_j)$.
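As a sketch, the score averages the forward and inverse triple products; the 2-dimensional embeddings below are made up for illustration:

```python
def triple_product(v, w, x):
    return sum(a * b * c for a, b, c in zip(v, w, x))

def simple_score(h_ei, v_r, t_ej, h_ej, v_r_inv, t_ei):
    # SimplE score: average of the forward and inverse triple products
    return 0.5 * (triple_product(h_ei, v_r, t_ej) + triple_product(h_ej, v_r_inv, t_ei))

# toy embeddings (illustrative only)
s = simple_score([1, 0], [2, 3], [1, 1],   # forward: 1*2*1 + 0*3*1 = 2
                 [0, 1], [1, 1], [1, 1])   # inverse: 0*1*1 + 1*1*1 = 1
print(s)  # (2 + 1) / 2 = 1.5
```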
2.3 Negative Sampling
For each positive triple, the model randomly corrupts either the head or the tail entity by drawing a replacement uniformly from [0, num_ent - 1] (num_ent is the total number of entities) until it differs from the original. Correct triples are labeled +1; corrupted triples, i.e. the negative samples, are labeled -1.
2.4 Loss Function
$\min_{\theta} \sum_{((h, r, t), l) \in LB} \operatorname{softplus}(-l \cdot \phi(h, r, t)) + \lambda\|\theta\|_{2}^{2}$
where $\theta$ are the model parameters (the embedding weights), $l \in \{-1, +1\}$ is the label marking a corrupted or correct triple, $\phi(h, r, t)$ is the score of triple $(h, r, t)$, and $\operatorname{softplus}(x) = \log(1 + \exp(x))$.
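The per-example loss is easy to compute directly: a high-scoring positive (label +1) or a low-scoring negative (label -1) drives softplus(-l·φ) toward zero, while the opposite is heavily penalized. A minimal sketch:

```python
import math

def softplus(x):
    # softplus(x) = log(1 + exp(x))
    return math.log(1.0 + math.exp(x))

def example_loss(label, score):
    # label is +1 for a correct triple, -1 for a corrupted one
    return softplus(-label * score)

print(example_loss(+1, 10.0))  # well-scored positive: loss near 0
print(example_loss(-1, 10.0))  # high-scored negative: loss near 10
```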
2.5 Evaluation
Standard link-prediction metrics are Hits@n, MRR, and MR (see the blog post "KGE性能指标:MRR,MR,HITS@1,HITS@3,HITS@10" for a reference).
- MRR (Mean Reciprocal Rank; "reciprocal" means the multiplicative inverse):
  $\mathrm{MRR}=\frac{1}{|S|} \sum_{i=1}^{|S|} \frac{1}{\operatorname{rank}_{i}}=\frac{1}{|S|}\left(\frac{1}{\operatorname{rank}_{1}}+\frac{1}{\operatorname{rank}_{2}}+\ldots+\frac{1}{\operatorname{rank}_{|S|}}\right)$
  where $S$ is the set of test triples, $|S|$ its size, and $\operatorname{rank}_i$ the link-prediction rank of the $i$-th triple. Higher is better.
- MR (Mean Rank):
  $\mathrm{MR}=\frac{1}{|S|} \sum_{i=1}^{|S|} \operatorname{rank}_{i}=\frac{1}{|S|}\left(\operatorname{rank}_{1}+\operatorname{rank}_{2}+\ldots+\operatorname{rank}_{|S|}\right)$
  with the same symbols as in the MRR formula. Lower is better.
- Hits@n: the fraction of test triples whose link-prediction rank is at most n:
  $\operatorname{HITS}@n=\frac{1}{|S|} \sum_{i=1}^{|S|} \mathbb{I}\left(\operatorname{rank}_{i} \leqslant n\right)$
  with the same symbols as above; $\mathbb{I}(\cdot)$ is the indicator function (1 if the condition holds, 0 otherwise). Typically n is 1, 3, or 10. Higher is better.
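Given the rank of each test triple, all three metrics reduce to a few lines; the ranks below are made up for illustration:

```python
def mrr(ranks):
    # mean of reciprocal ranks
    return sum(1.0 / r for r in ranks) / len(ranks)

def mr(ranks):
    # mean rank
    return sum(ranks) / len(ranks)

def hits_at(n, ranks):
    # fraction of ranks that are at most n
    return sum(1 for r in ranks if r <= n) / len(ranks)

ranks = [1, 3, 10]  # hypothetical link-prediction ranks
print(mrr(ranks))         # (1 + 1/3 + 1/10) / 3
print(mr(ranks))          # (1 + 3 + 10) / 3
print(hits_at(3, ranks))  # 2 of 3 ranks are <= 3
```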
3 Code
The code comprises six modules: data processing, model, training, testing, evaluation, and the main entry point.
3.1 Data module dataset.py
```python
import numpy as np
import random
import torch
import math

class Dataset:
    def __init__(self, ds_name):
        self.name = ds_name
        self.dir = "datasets/" + ds_name + "/"
        self.ent2id = {}
        self.rel2id = {}
        # Read the three splits; entity/relation ids are assigned on first sight
        self.data = {spl: self.read(self.dir + spl + ".txt") for spl in ["train", "valid", "test"]}
        self.batch_index = 0

    def read(self, file_path):
        with open(file_path, "r") as f:
            lines = f.readlines()
        triples = np.zeros((len(lines), 3))
        for i, line in enumerate(lines):
            triples[i] = np.array(self.triple2ids(line.strip().split("\t")))
        return triples

    def num_ent(self):
        return len(self.ent2id)

    def num_rel(self):
        return len(self.rel2id)

    def triple2ids(self, triple):
        return [self.get_ent_id(triple[0]), self.get_rel_id(triple[1]), self.get_ent_id(triple[2])]

    def get_ent_id(self, ent):
        if not ent in self.ent2id:
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]

    def get_rel_id(self, rel):
        if not rel in self.rel2id:
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]

    def rand_ent_except(self, ent):
        # Draw a random entity id different from ent (used for negative sampling)
        rand_ent = random.randint(0, self.num_ent() - 1)
        while rand_ent == ent:
            rand_ent = random.randint(0, self.num_ent() - 1)
        return rand_ent

    def next_pos_batch(self, batch_size):
        if self.batch_index + batch_size < len(self.data["train"]):
            batch = self.data["train"][self.batch_index: self.batch_index + batch_size]
            self.batch_index += batch_size
        else:
            batch = self.data["train"][self.batch_index:]
            self.batch_index = 0
        # Append a label column of +1 for the positive triples
        return np.append(batch, np.ones((len(batch), 1)), axis=1).astype("int")

    def generate_neg(self, pos_batch, neg_ratio):
        # Corrupt either the head or the tail of each positive triple
        neg_batch = np.repeat(np.copy(pos_batch), neg_ratio, axis=0)
        for i in range(len(neg_batch)):
            if random.random() < 0.5:
                neg_batch[i][0] = self.rand_ent_except(neg_batch[i][0])
            else:
                neg_batch[i][2] = self.rand_ent_except(neg_batch[i][2])
        neg_batch[:, -1] = -1  # negatives are labeled -1
        return neg_batch

    def next_batch(self, batch_size, neg_ratio, device):
        pos_batch = self.next_pos_batch(batch_size)
        neg_batch = self.generate_neg(pos_batch, neg_ratio)
        batch = np.append(pos_batch, neg_batch, axis=0)
        np.random.shuffle(batch)
        heads = torch.tensor(batch[:, 0]).long().to(device)
        rels = torch.tensor(batch[:, 1]).long().to(device)
        tails = torch.tensor(batch[:, 2]).long().to(device)
        labels = torch.tensor(batch[:, 3]).float().to(device)
        return heads, rels, tails, labels

    def was_last_batch(self):
        return (self.batch_index == 0)

    def num_batch(self, batch_size):
        return int(math.ceil(float(len(self.data["train"])) / batch_size))
```
3.2 Model module SimplE.py

```python
import torch
import torch.nn as nn
import math

class SimplE(nn.Module):
    def __init__(self, num_ent, num_rel, emb_dim, device):
        super(SimplE, self).__init__()
        self.num_ent = num_ent
        self.num_rel = num_rel
        self.emb_dim = emb_dim
        self.device = device
        # Two embeddings per entity (as head, as tail) and per relation (forward, inverse)
        self.ent_h_embs = nn.Embedding(num_ent, emb_dim).to(device)
        self.ent_t_embs = nn.Embedding(num_ent, emb_dim).to(device)
        self.rel_embs = nn.Embedding(num_rel, emb_dim).to(device)
        self.rel_inv_embs = nn.Embedding(num_rel, emb_dim).to(device)
        sqrt_size = 6.0 / math.sqrt(self.emb_dim)
        nn.init.uniform_(self.ent_h_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.ent_t_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.rel_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.rel_inv_embs.weight.data, -sqrt_size, sqrt_size)

    def l2_loss(self):
        return ((torch.norm(self.ent_h_embs.weight, p=2) ** 2)
                + (torch.norm(self.ent_t_embs.weight, p=2) ** 2)
                + (torch.norm(self.rel_embs.weight, p=2) ** 2)
                + (torch.norm(self.rel_inv_embs.weight, p=2) ** 2)) / 2

    def forward(self, heads, rels, tails):
        hh_embs = self.ent_h_embs(heads)  # head embeddings of the head entities
        ht_embs = self.ent_h_embs(tails)  # head embeddings of the tail entities
        th_embs = self.ent_t_embs(heads)  # tail embeddings of the head entities
        tt_embs = self.ent_t_embs(tails)  # tail embeddings of the tail entities
        r_embs = self.rel_embs(rels)
        r_inv_embs = self.rel_inv_embs(rels)
        # <h_{e_i}, v_r, t_{e_j}> and <h_{e_j}, v_{r^-1}, t_{e_i}>, averaged and clamped
        scores1 = torch.sum(hh_embs * r_embs * tt_embs, dim=1)
        scores2 = torch.sum(ht_embs * r_inv_embs * th_embs, dim=1)
        return torch.clamp((scores1 + scores2) / 2, -20, 20)
```
3.3 Training module trainer.py

```python
from dataset import Dataset
from SimplE import SimplE
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class Trainer:
    def __init__(self, dataset, args):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = SimplE(dataset.num_ent(), dataset.num_rel(), args.emb_dim, self.device)
        self.dataset = dataset
        self.args = args

    def train(self):
        self.model.train()
        optimizer = torch.optim.Adagrad(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=0,
            initial_accumulator_value=0.1
        )
        for epoch in range(1, self.args.ne + 1):
            last_batch = False
            total_loss = 0.0
            while not last_batch:
                h, r, t, l = self.dataset.next_batch(self.args.batch_size, neg_ratio=self.args.neg_ratio, device=self.device)
                last_batch = self.dataset.was_last_batch()
                optimizer.zero_grad()
                scores = self.model(h, r, t)
                # softplus(-l * score) plus L2 regularization scaled by the number of batches
                loss = torch.sum(F.softplus(-l * scores)) + (self.args.reg_lambda * self.model.l2_loss() / self.dataset.num_batch(self.args.batch_size))
                loss.backward()
                optimizer.step()
                total_loss += loss.cpu().item()
            print("Loss in iteration " + str(epoch) + ": " + str(total_loss) + "(" + self.dataset.name + ")")
            if epoch % self.args.save_each == 0:
                self.save_model(epoch)

    def save_model(self, chkpnt):
        print("Saving the model")
        directory = "models/" + self.dataset.name + "/"
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.model, directory + str(chkpnt) + ".chkpnt")
```
3.4 Testing module tester.py

```python
import torch
from dataset import Dataset
import numpy as np
from measure import Measure
from os import listdir
from os.path import isfile, join

class Tester:
    def __init__(self, dataset, model_path, valid_or_test):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = torch.load(model_path, map_location=self.device)
        self.model.eval()
        self.dataset = dataset
        self.valid_or_test = valid_or_test
        self.measure = Measure()
        self.all_facts_as_set_of_tuples = set(self.allFactsAsTuples())

    def get_rank(self, sim_scores):
        # The target triple sits at index 0; its rank is the number of
        # candidates (including itself) scoring at least as high
        return (sim_scores >= sim_scores[0]).sum()

    def create_queries(self, fact, head_or_tail):
        head, rel, tail = fact
        if head_or_tail == "head":
            return [(i, rel, tail) for i in range(self.dataset.num_ent())]
        elif head_or_tail == "tail":
            return [(head, rel, i) for i in range(self.dataset.num_ent())]

    def add_fact_and_shred(self, fact, queries, raw_or_fil):
        if raw_or_fil == "raw":
            result = [tuple(fact)] + queries
        elif raw_or_fil == "fil":
            # Filtered setting: drop candidates that are known true facts
            result = [tuple(fact)] + list(set(queries) - self.all_facts_as_set_of_tuples)
        return self.shred_facts(result)

    def test(self):
        settings = ["raw", "fil"] if self.valid_or_test == "test" else ["fil"]
        for i, fact in enumerate(self.dataset.data[self.valid_or_test]):
            for head_or_tail in ["head", "tail"]:
                queries = self.create_queries(fact, head_or_tail)
                for raw_or_fil in settings:
                    h, r, t = self.add_fact_and_shred(fact, queries, raw_or_fil)
                    sim_scores = self.model(h, r, t).cpu().data.numpy()
                    rank = self.get_rank(sim_scores)
                    self.measure.update(rank, raw_or_fil)
        self.measure.normalize(len(self.dataset.data[self.valid_or_test]))
        self.measure.print_()
        return self.measure.mrr["fil"]

    def shred_facts(self, triples):
        heads = [triples[i][0] for i in range(len(triples))]
        rels = [triples[i][1] for i in range(len(triples))]
        tails = [triples[i][2] for i in range(len(triples))]
        return torch.LongTensor(heads).to(self.device), torch.LongTensor(rels).to(self.device), torch.LongTensor(tails).to(self.device)

    def allFactsAsTuples(self):
        tuples = []
        for spl in self.dataset.data:
            for fact in self.dataset.data[spl]:
                tuples.append(tuple(fact))
        return tuples
```
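The key convention in the tester is that the target triple is placed at index 0, so its rank is the count of candidates scoring at least as high. A toy illustration of the `get_rank` logic in plain Python (the scores are made up):

```python
def get_rank(sim_scores):
    # target's score is at index 0; rank counts candidates
    # (including the target itself) scoring at least as high
    return sum(1 for s in sim_scores if s >= sim_scores[0])

print(get_rank([0.9, 0.95, 0.5, 0.2]))  # one candidate beats the target -> rank 2
print(get_rank([0.9, 0.5, 0.2]))        # target scores best -> rank 1
```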
3.5 Evaluation module measure.py

```python
class Measure:
    def __init__(self):
        self.hit1 = {"raw": 0.0, "fil": 0.0}
        self.hit3 = {"raw": 0.0, "fil": 0.0}
        self.hit10 = {"raw": 0.0, "fil": 0.0}
        self.mrr = {"raw": 0.0, "fil": 0.0}
        self.mr = {"raw": 0.0, "fil": 0.0}

    def update(self, rank, raw_or_fil):
        if rank == 1:
            self.hit1[raw_or_fil] += 1.0
        if rank <= 3:
            self.hit3[raw_or_fil] += 1.0
        if rank <= 10:
            self.hit10[raw_or_fil] += 1.0
        self.mr[raw_or_fil] += rank
        self.mrr[raw_or_fil] += (1.0 / rank)

    def normalize(self, num_facts):
        # Each fact contributes two queries (head and tail prediction),
        # hence the division by 2 * num_facts
        for raw_or_fil in ["raw", "fil"]:
            self.hit1[raw_or_fil] /= (2 * num_facts)
            self.hit3[raw_or_fil] /= (2 * num_facts)
            self.hit10[raw_or_fil] /= (2 * num_facts)
            self.mr[raw_or_fil] /= (2 * num_facts)
            self.mrr[raw_or_fil] /= (2 * num_facts)

    def print_(self):
        for raw_or_fil in ["raw", "fil"]:
            print(raw_or_fil.title() + " setting:")
            print("\tHit@1 =", self.hit1[raw_or_fil])
            print("\tHit@3 =", self.hit3[raw_or_fil])
            print("\tHit@10 =", self.hit10[raw_or_fil])
            print("\tMR =", self.mr[raw_or_fil])
            print("\tMRR =", self.mrr[raw_or_fil])
            print("")
```
3.6 Main module main.py

```python
from trainer import Trainer
from tester import Tester
from dataset import Dataset
import argparse
import time

def get_parameter():
    parser = argparse.ArgumentParser()
    parser.add_argument('-ne', default=1000, type=int, help="number of epochs")
    parser.add_argument('-lr', default=0.1, type=float, help="learning rate")
    parser.add_argument('-reg_lambda', default=0.03, type=float, help="l2 regularization parameter")
    parser.add_argument('-dataset', default="WN18", type=str, help="wordnet dataset")
    parser.add_argument('-emb_dim', default=200, type=int, help="embedding dimension")
    parser.add_argument('-neg_ratio', default=1, type=int, help="number of negative examples per positive example")
    parser.add_argument('-batch_size', default=1415, type=int, help="batch size")
    parser.add_argument('-save_each', default=50, type=int, help="validate every k epochs")
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_parameter()
    dataset = Dataset(args.dataset)

    print("~~~~ Training ~~~~")
    trainer = Trainer(dataset, args)
    trainer.train()

    print("~~~~ Select best epoch on validation set ~~~~")
    epochs2test = [str(int(args.save_each * (i + 1))) for i in range(args.ne // args.save_each)]
    dataset = Dataset(args.dataset)

    best_mrr = -1.0
    best_epoch = "0"
    for epoch in epochs2test:
        start = time.time()
        print(epoch)
        model_path = "models/" + args.dataset + "/" + epoch + ".chkpnt"
        tester = Tester(dataset, model_path, "valid")
        mrr = tester.test()
        if mrr > best_mrr:
            best_mrr = mrr
            best_epoch = epoch
        print(time.time() - start)

    print("Best epoch: " + best_epoch)

    print("~~~~ Testing on the best epoch ~~~~")
    best_model_path = "models/" + args.dataset + "/" + best_epoch + ".chkpnt"
    tester = Tester(dataset, best_model_path, "test")
    tester.test()
```
Original: https://blog.csdn.net/REfusing/article/details/123314548
Author: Re:fused
Title: SimplE:SimplE Embedding for Link Prediction in Knowledge Graphs+代码