R2D2:基于可微分树的预训练模型

https://arxiv.org/abs/2107.00967
在一次分享中看到这篇论文,感觉有意思细读了一下
主要是讲基于可微分树的递归transformer来实现具有强解释性的层次预训练语言模型

论文主要章节涉及了三个方面

  • 模型算法,讲解借助transformer实现对句子树结构的提取
  • 算法复杂度的优化,相比于之前提出的tree-LSTM是n 3 n^3 n 3复杂度降低到了线性复杂度
  • 在以上基础上进行大语料的预训练

相关背景知识

  1. 基于CKY算法的语法分析介绍 博客
乔姆斯基范式(CNF,Chomsky Normal Form)

任何语法都可以转化成一个弱等价的CNF形式,CNF语法都是二分叉

R2D2:基于可微分树的预训练模型
; CYK算法

CYK算法(也称为Cocke–Younger–Kasami算法)是一种用来对 上下文无关文法(CFG,Context Free Grammar)进行语法分析(parsing)的算法。该算法最早由John Cocke, Daniel Younger and Tadao Kasami分别独立提出,其中John Cocke还是1987年度的图灵奖得主。CYK算法是基于动态规划思想设计的一种自底向上语法分析算法。
看过最易懂的博文
代码实现
2. Gumbel-Softmax estimation
在自底向上的计算过程中,每个格子会有多种组合方式,在各种组合方式中,选择概率最大的组合,即argmax函数。但是argmax函数是不可导的,没有办法反向传播。
通过reparameterization对logits的输出拟合为onehot,同时保证梯度可以反向传播
对离散变量再参数化
4. 基于大语料的预训练语言模型的大概套路

模型结构设计
Differentiable Tree

R2D2:基于可微分树的预训练模型
该论文定义了一个类似于CKY形式的可微二叉树解析器
句子 S={s1,s2,s3,…sn}
如上图,每一个格子T ( i , j ) = < e i , j , p i , j , p ~ i , j > \Tau(i,j)=T (i ,j )=
e i , j e_{i,j}e i ,j ​ 是向量表征
p i , j p_{i,j}p i ,j ​ 是每一个步所有组合的概率
p ~ i , j \tilde{p}{i,j}p ~​i ,j ​是在[s i s_i s i ​,s j s_j s j ​]的子树的概率
树的末端节点是T i , i \Tau
{i,i}T i ,i ​,e i , i e_{i,i}e i ,i ​以当前输入s i s_i s i ​的向量初始化,p i , j p_{i,j}p i ,j ​ 和p ~ i , j \tilde{p}_{i,j}p ~​i ,j ​初始化为1。
R2D2:基于可微分树的预训练模型

上述公式的k是指(s i s_i s i ​,s j − 1 s_{j-1}s j −1 ​)之间的某一分割点(分割点不同,会对应出不同的组合)
第一个公式
f ( . ) f(.)f (.)是我们下一节Recursive Transformer定义的函数,p i , j k p_{i,j}^k p i ,j k ​ 和p ~ i , j k \tilde{p}{i,j}^k p ~​i ,j k ​分别指一步中组合的概率和其子树的概率
第二个公式
以K为分割点的子树的概率,是当前组合的概率和左右子树概率的乘积,这个和CKY算法是一致的
第三个公式
Straight Through Gumbel-Softmax ,通过一定方式实现类似argmax函数的可微
p i , j p
{i,j}p i ,j ​ 和p ~ i , j \tilde{p}{i,j}p ~​i ,j ​是基于所有分割点得到的p i , j k p{i,j}^k p i ,j k ​ 和p ~ i , j k \tilde{p}{i,j}^k p ~​i ,j k ​的组合
output: 计算得出权重
第四个公式
通过当前组合与权重系数的乘积计算出e i , i e
{i,i}e i ,i ​
第五个公式
通过概率向量与权重系数的乘积计算出新的概率向量

; Recursive Transformer

R2D2:基于可微分树的预训练模型
这个图对应了上一节第一个公式。
中间shape的转换过程看图,不想转述了,最终输出的p i , j p_{i,j}p i ,j ​是R 1 R^1 R 1, c i , j k c_{i,j}^k c i ,j k ​是R d R^d R d
Tree Recovery

通过Straight-Through Gumbel-Softmax在每一个cell选择最佳的分割点,Tree(T 1 , n \Tau_{1,n}T 1 ,n ​), 从树的根节点自顶向下递归操作,选择的最佳分割点还原树的结构,类似于CKY算法最后的回溯过程

Complexity Optimization 复杂度优化

上述的f ( . ) f(.)f (.)是整个模型的核心计算部分,我们可以通过树的剪枝归并算法来实现对f ( . ) f(.)f (.)O(n 3 n^3 n 3)
复杂度到线性复杂度

算法

R2D2:基于可微分树的预训练模型
; 寻找最佳的合并点

R2D2:基于可微分树的预训练模型
example

R2D2:基于可微分树的预训练模型
这张图展示了长度为6的句子的处理过程。
m表示设定的剪枝的阈值 T \Tau T 是一个二维数组,用来盛放自底向上计算的所有cell。
上上述图示的三个function:
TREEINDUCTION 是前向计算的过程,调用PRUNING进行剪枝,PRUNING调用FIND寻找最佳消并点。
计算m之下的cell,如上图(b)显示。
当cell的row大于等于m时,还原所有以第m行的节点为root节点的子树,调用PRUNING进行剪枝操作,
剪枝的第一步是找到局部最佳的merge点(上图c),剪掉部分的cell(上图d),返回一个新的T \Tau T(上图e)
在FIND中,最佳分割点的候选集合需要满足两个条件
(1)在T \Tau T的第二行
(2)在以第m行的节点为root节点的子树中有被使用到
然后在候选集合中选择(x.p pl pr)最高的cell T i , j \Tau_{i,j}T i ,j ​做为最佳merge点,对应的将T i , ∗ \Tau_{i,}T i ,∗​和T ∗ , j \Tau_{,j}T ∗,j ​剪掉,得到T 3 \Tau^3 T 3
; 实验

预训练目标:

  1. 学习词汇表征,在实际实验中是对于word piece的表征,选择WikiText-2数据集,长度在128以内的句子,mask词汇,输入左子树和右子树的embedding进行词汇预测
    因为剪枝操作,存在左子树或者右子树为空,以临近的最长子树来替代
    R2D2:基于可微分树的预训练模型
  2. 无监督成分句法分析
    在 WSJ and CTB 测试集计算F1
    R2D2:基于可微分树的预训练模型

基于word-piece的word、NP等的召回

R2D2:基于可微分树的预训练模型

Original: https://blog.csdn.net/qq_27965129/article/details/120953040
Author: 不知芝芝
Title: R2D2:基于可微分树的预训练模型

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/532191/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球