ERICA 代码解读

目录

前言

pretrain

数据预处理

模型训练

MLM/RD loss

MLM/ED loss

小结

finetune

总结

前言

论文链接:https://arxiv.org/abs/2012.15022

开源链接:https://github.com/thunlp/ERICA

这是一篇预训练模型,主要创新点就是提出了两个辅助性预训练任务来帮助PLM更好地理解实体和实体间关系:

(1) 实体区分任务,给定头实体和关系,推断出文本中正确的尾实体。

(2) 关系判别任务,区分两个关系在语义上是否接近,这在长文本情景下涉及复杂的关系推理。

为了避免灾难性遗忘,作者同时还加了masked language modeling (MLM)这一传统任务,所以总loss就是:

ERICA 代码解读

ED就是实体区分任务、RD就是关系判别任务、MLM就是传统屏蔽任务

更多详细解读可以看如下,笔者不再累述,本篇主要目的是解读代码。

ERICA: 提升预训练语言模型实体与关系理解的统一框架

pretrain

数据预处理

该部分代码逻辑在./pretrain/prepare_pretrain_data

get_distant.py:数据清洗,实体抽取和关系抽取

remove_test_set.py:区分训练集和测试集

sample_data.py:tokenized化,通过这样预处理。

这里没什么要说的,笔者比较感兴趣的是实体关系抽取是怎么做的。其实很简单,这里没有什么模型啥的,最主要的就是靠下面几个文件:

all_triple.txt:定义了实体关系

all_name_to_Q.json:实体名到类型的一个json

all_Q.json:所以实体类型id的。

关于实体抽取就是匹配,依靠上述文件,只要匹配到就得到实体。关于关系抽取更简单了,只要实体类型定了那么依靠all_triple.txt就确定了关系。

其中./pretrain/data/DOC/sampled_data/下就是官方给出的一个预处理完的数据结果,可以看看

模型训练

主要逻辑是在./pretrain/code/pretrain下,主入口就是main.py,主要就是:

ERICA 代码解读

根据论文我们知道模型主要涉及到三部分loss【ED/RD/MLM】

红色框的doc_loss就是【mask loss + 关系判别即 MLM + RD】,绿色框的wiki_loss就是【mask loss + 实体区分即MLM + ED】

我们来一部分一部分看,主要是在model.py中

ERICA 代码解读

可以看到主要就是对应两个函数236行和239行即get_doc_loss和get_wiki_loss函数,需要注意的是两个函数的输入是不一样的,即batch[0]和batch[1],关于输入数据的格式可以看dataset.py:

ERICA 代码解读

主要就是730行,其实就是get_doc_batch和get_wiki_batch两个函数。好了,大概代码逻辑框架知道了,下面分开看:

MLM/RD loss

数据输入就是:get_doc_batch

模型就是:get_doc_loss

如下是get_doc_loss

ERICA 代码解读

可以看到,

以上的MLM loss就是传统的预训练模型,不是本文的创新点,下面我们来看看论文的创新点RD loss也即关系区分任务【接着看上图的get_doc_loss函数,为了方便,这里再放一次】

ERICA 代码解读

作者这里用了对比学习:正样本即具有相同远程监督标签的关系表示,负样本与此相反,关于关系的表征,就是其对应的两个实体的简单拼接,即上述代码的173行得到的hidden。

start_re_output和end_re_output可以看做是头实体和尾实体表征。

context_output就是我们上述修改transformers源码返回的sequence ouput

h_mapping和t_mapping是batch传进来的,可以通过get_doc_batch看到就是代表的实体位置,然后通过和context_output相乘就可以滤除全部头实体和尾实体的编码表征

至此用pair_hidden【hidden】和relation_label通过对比学习计算loss【NTXentLoss_doc函数】

对比学习原理这里不在累述,感兴趣的可以看笔者另外一篇博客:

https://blog.csdn.net/weixin_42001089/article/details/117930433

这里对应的公式就是:

ERICA 代码解读

该小节的函数get_doc_loss最后返回就是m_loss和r_loss即MLM loss和ED loss也即屏蔽语言模型loss和关系区分loss

MLM/ED loss

数据输入就是:get_wiki_batch

模型就是:get_wiki_loss

ERICA 代码解读

首先206行返回的就是mlm loss,前面已经讲过,这里不在累述,一模一样,重点看看ED loss

他的原理是根据头实体和关系预测尾实体

start_re_output可以看出是头实体,而query_re_output可以看做是关系,我们知道paper的关系表征是头尾实体的简单拼接,所以query_re_output是通过query_mapping得到的,可以理解为query_mapping是当前关系对应头尾实体位置,通过和context_output相乘就过滤出对应头尾实体,进而进行拼接得到关系表征,关于query_mapping是batch得到的,可以看get_wiki_batch

该小节的函数get_wiki_loss最后返回就是m_loss和r_loss即MLM loss和RD loss也即屏蔽语言模型loss和实体区分loss

小结

(1) 新加的两个辅助任务是分开进行的【过两次模型】,但二者每次都顺便带了mlm loss

(2) mlm 部分给了很多落地启发,即自己有了个什么想法,能快速使用transfomers实现,尤其二次预训练,甚至我们可以改源码。

finetune

代码在finetune,这里面每一个文件夹代表一个下游任务,没什么可讲的,主要就是用上述pretrain得到的模型去热启就行了。

总结

(1) 以后我们有什么自己的mask策略想法,想落地实现,其实就是仿效改这个函数。

(2) 只需要修改完上述函数,直接传到对应的huggingface框架下的ForMaskedLM【比如BertForMaskedLM】,就可以直接到返回的Loss,进而进行MLM语言屏蔽模型训练

(3) 遇到一些特殊需求需要改huggingface框架也不是不可以,直接下载transformers代码进行需求修改即可

欢迎关注笔者的微信公众号,更多好文章:

ERICA 代码解读
​​​​​​​

Original: https://blog.csdn.net/weixin_42001089/article/details/118002302
Author: weixin_42001089
Title: ERICA 代码解读

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/568888/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球