知识图到文本的生成——伍

2021SC@SDUSC

我们继续分析dataset类,dataset类位于lastDataset.py文件中,是该算法的核心代码之一。dataset类中一共有20个类函数,我将会挑选核心的函数来分析。

首先是对数据集建立词表的build_ent_vocab函数。

  def build_ent_vocab(self,path,unkat=0):
    ents = ""
    with open(path,encoding='utf-8') as f:
      for l in f:
        ents +=  " "+l.split("\t")[1]
    itos = sorted(list(set(ents.split(" "))))
    itos[0] == ""; itos[1] == ""
    stoi = {x:i for i,x in enumerate(itos)}
    return itos,stoi

参数中的path就是数据集所在的路径,调用的时候传入。unkat参数初始值为0,意为为转换。ents是声明的字符串变量,存储遍历读取到的字符串数据集。itos是一个列表变量,每个元素都是ents中根据” “切割出的分词。比如ents=’A B’,那么itos则为[‘A’,’B’],初始化itos第一个值为unk,第二个值为pad,enumerate()函数将itos组合为索引序列,结果组合为stoi变量。返回数据对象itos和索引序列stoi。

接下来是mkGraphs函数。

  def mkGraphs(self,r,ent):
    ......

    return (adj,rel)

这个函数的作用是用adj和rel矩阵将三元组转换为entlist。具体操作非关键代码,此处不再赘述。

接下来是mkVocabs函数。

  def mkVocabs(self,args):
    args.path = args.datadir + args.data
    self.INP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.OUTP = data.Field(sequential=True, batch_first=True,init_token="", eos_token="",include_lengths=True)
    self.TGT = data.Field(sequential=True, batch_first=True,init_token="", eos_token="")
    self.NERD = data.Field(sequential=True, batch_first=True,eos_token="")
    self.ENT = data.RawField()
    self.REL = data.RawField()
    self.SORDER = data.RawField()
    self.SORDER.is_target = False
    self.REL.is_target = False
    self.ENT.is_target = False
    self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]

该段代码就是对这些参数进行操作,Field类和RawField类在之前已经详细分析过,此处不再单独分析这两个类。它设置了处理后保存的路径,设置INP和OUTP为顺序数据、先生成batch dimension的tensor、以”

    if args.eval:
      train = data.TabularDataset(path=args.datadir+args.traindata, format='tsv',fields=self.fields)
    else:
      train = data.TabularDataset(path=args.path, format='tsv',fields=self.fields)

    print('building vocab')

train变量为把data定义为以TSV格式存储的列的数据集。TabularDataset是一个类,用来定义以CSV、TSV或JSON格式存储的列的数据集。如果使用dict,键应该是JSON键或CSV/TSV列的子集,值应该是(name, field)的元组。这会允许我们从其JSON/CSV/TSV键名重命名列,还允许选择要加载的列的子集。

    self.OUTP.build_vocab(train, min_freq=args.outunk)
    generics =['','','','','']
    self.OUTP.vocab.itos.extend(generics)
    for x in generics:
      self.OUTP.vocab.stoi[x] = self.OUTP.vocab.itos.index(x)
    self.TGT.vocab = copy(self.OUTP.vocab)
    specials = "method material otherscientificterm metric task".split(" ")
    for x in specials:
      for y in range(40):
        s = ""
        self.TGT.vocab.stoi[s] = len(self.TGT.vocab.itos)+y
    self.NERD.build_vocab(train,min_freq=0)
    for x in generics:
      self.NERD.vocab.stoi[x] = self.OUTP.vocab.stoi[x]

首先对要输出的变量进行build_vocab操作,该函数为Field的类函数,之前已分析过,此处不再赘述。generics是作者(不是我,是写代码的人)在数据集中找的一个实例。接下来就是对这个数据集进行扩大、切割、存储操作,specials就是把”method material otherscientificterm metric task”这个字符串根据” “进行分割,也就是generics。

接下来看一个批处理函数fixBatch()。

  def fixBatch(self,b):
    ent,phlens = zip(*b.ent)
    ent,elens = self.adjToBatch(ent)
    ent = ent.to(self.args.device)
    adj,rel = zip(*b.rel)
    if self.args.sparse:
      b.rel = [adj,self.listTo(rel)]
    else:
      b.rel = [self.listTo(adj),self.listTo(rel)]
    if self.args.plan:
      b.sordertgt = self.listTo(self.pad_list(b.sordertgt))
    phlens = torch.cat(phlens,0).to(self.args.device)
    elens = elens.to(self.args.device)
    b.ent = (ent,phlens,elens)
    return b

参数b为传入的地址。ent,phlens = zip(b.ent)和adj,rel = zip(b.rel)为解压b,解压后仍为元组,对解压后的元组调用adjToBatch函数进行生成邻接矩阵的批处理操作,最后返回的是矩阵。最后b直接变为三元组并返回。

Original: https://blog.csdn.net/qq_50729659/article/details/121340858
Author: 槐廿拾
Title: 知识图到文本的生成——伍

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/556565/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球