重参数 (Reparameterization)

Contents

基本概念

  • 重参数 (Reparameterization) 实际上是处理如下期望形式的目标函数的一种技巧:
    重参数 (Reparameterization)
  • 重参数假设从分布p θ ( z ) p_θ(z)p θ​(z ) 中采样可以分解为 两个步骤:(1) 从无参数分布q ( ε ) q(ε)q (ε) 中采样一个ε εε;(2) 通过变换z = g θ ( ε ) z=g_θ(ε)z =g θ​(ε) 生成z z z。那么,上述期望就变成了
    重参数 (Reparameterization)

; 连续情形

  • 简单起见,我们先考虑z z z 为连续随机变量的情形:
    重参数 (Reparameterization)
  • 总的来说,连续情形的重参数还是比较简单的。从数学本质来看,重参数是一种 积分变换,即原来是关于z z z 积分,通过z = g θ ( ε ) z=g_θ(ε)z =g θ​(ε) 变换之后得到新的积分形式。一个最简单的例子就是 正态分布:对于正态分布来说,重参数就是 “从N ( z ; μ θ , σ θ 2 ) N(z;μ_θ,σ^2_θ)N (z ;μθ​,σθ2 ​) 中采样一个z z z” 变成 “从N ( ε ; 0 , 1 ) N(ε;0,1)N (ε;0 ,1 ) 中采样一个ε εε,然后计算ε × σ θ + μ θ ε×σ_θ+μ_θε×σθ​+μθ​”,所以
    重参数 (Reparameterization)

离散情形

  • 为了突出 “离散”,我们将随机变量z z z 换成y y y,即对于离散情形要面对的目标函数是
    重参数 (Reparameterization)
    重参数 (Reparameterization)
  • 看到上述期望项中的求和,第一反应可能是 “求和?那就求呗,又不是求不了”。的确, 对于离散的随机变量,其期望只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。但是,如果k k k 特别大呢?举个例子,假设y y y 是一个 100 维的向量,每个元素不是 0 就是 1,那么所有不同的y y y 的总数目就是2 100 2^{100}2 1 0 0,要对这样的2 100 2^{100}2 1 0 0 个单项进行求和, 计算量是难以接受的 (每一项都需要计算前向传播过程f ( y ) f(y)f (y ))。所以, 还是需要回到采样上去,如果能够采样若干个点就能得到期望的有效估计,并且还不损失梯度信息,那自然是最好了

; Gumbel Max

  • 为此,需要先引入 Gumbel Max。假设每个类别的概率是p 1 , p 2 , … , p k p_1,p_2,…,p_k p 1 ​,p 2 ​,…,p k ​,那么 Gumbel Max 提供了一种 依概率采样类别的方案:
    重参数 (Reparameterization)
  • 可以证明,这样的过程精确等价于依概率p 1 , p 2 , … , p k p_1,p_2,…,p_k p 1 ​,p 2 ​,…,p k ​ 采样一个类别,换句话说, 在 Gumbel Max 中,输出i i i 的概率正好是p i p_i p i ​. 不失一般性,这里我们证明输出 1 的概率是p 1 p_1 p 1 ​. 注意,输出 1 意味着log ⁡ p 1 − l o g ( − l o g ε 1 ) \log p_1−log(−logε_1)lo g p 1 ​−l o g (−l o g ε1 ​) 是最大的,这又意味着:
    log ⁡ p 1 − log ⁡ ( − log ⁡ ε 1 ) > log ⁡ p 2 − log ⁡ ( − log ⁡ ε 2 ) log ⁡ p 1 − log ⁡ ( − log ⁡ ε 1 ) > log ⁡ p 3 − log ⁡ ( − log ⁡ ε 3 ) ⋮ log ⁡ p 1 − log ⁡ ( − log ⁡ ε 1 ) > log ⁡ p k − log ⁡ ( − log ⁡ ε k ) \begin{aligned} &\log p_1 – \log(-\log \varepsilon_1) > \log p_2 – \log(-\log \varepsilon_2) \ &\log p_1 – \log(-\log \varepsilon_1) > \log p_3 – \log(-\log \varepsilon_3) \ &\qquad \vdots\ &\log p_1 – \log(-\log \varepsilon_1) > \log p_k – \log(-\log \varepsilon_k) \end{aligned}​lo g p 1 ​−lo g (−lo g ε1 ​)>lo g p 2 ​−lo g (−lo g ε2 ​)lo g p 1 ​−lo g (−lo g ε1 ​)>lo g p 3 ​−lo g (−lo g ε3 ​)⋮lo g p 1 ​−lo g (−lo g ε1 ​)>lo g p k ​−lo g (−lo g εk ​)​不失一般性,我们只分析第一个不等式,化简后得到:
    ε 2 < ε 1 p 2 / p 1 ≤ 1 \varepsilon_2 < \varepsilon_1^{p_2 / p_1}\leq 1 ε2 ​<ε1 p 2 ​/p 1 ​​≤1由于ε 2 ∼ U [ 0 , 1 ] ε_2∼U[0,1]ε2 ​∼U [0 ,1 ],所以ε 2 < ε 1 p 2 / p 1 ε_2 的概率就是 ε 1 p 2 / p 1 ε^{p_2/p_1}_1 ε1 p 2 ​/p 1 ​​,这就是固定 ε 1 ε_1 ε1 ​ 的情况下,第一个不等式成立的概率。那么,所有不等式同时成立的概率是
    ε 1 p 2 / p 1 ε 1 p 3 / p 1 … ε 1 p k / p 1 = ε 1 ( p 2 + p 3 + ⋯ + p k ) / p 1 = ε 1 ( 1 / p 1 ) − 1 \varepsilon_1^{p_2 / p_1}\varepsilon_1^{p_3 / p_1}\dots \varepsilon_1^{p_k / p_1}=\varepsilon_1^{(p_2 + p_3 + \dots + p_k) / p_1}=\varepsilon_1^{(1/p_1)-1}ε1 p 2 ​/p 1 ​​ε1 p 3 ​/p 1 ​​…ε1 p k ​/p 1 ​​=ε1 (p 2 ​+p 3 ​+⋯+p k ​)/p 1 ​​=ε1 (1 /p 1 ​)−1 ​然后对所有 ε 1 ε_1 ε1 ​ 求平均,就是
    ∫ 0 1 ε 1 ( 1 / p 1 ) − 1 d ε 1 = p 1 \int_0^1 \varepsilon_1^{(1/p_1)-1}d\varepsilon_1 = p_1 ∫0 1 ​ε1 (1 /p 1 ​)−1 ​d ε1 ​=p 1 ​

Gumbel Softmax

  • 我们希望重参数不丢失梯度信息,但是 Gumbel Max 做不到,因为arg max ⁡ \argmax a r g m a x 不可导,为此,需要做进一步的近似。首先,留意到在神经网络中,处理离散输入的基本方法是转化为 one hot 形式,包括 Embedding 层的本质也是 one hot 全连接,因此arg max ⁡ \argmax a r g m a x 实际上是one_hot ( arg max ⁡ ) \text{one_hot}(\argmax)one_hot (a r g m a x ),然后,我们寻求one_hot ( arg max ⁡ ) \text{one_hot}(\argmax)one_hot (a r g m a x ) 的光滑近似,它就是s o f t m a x softmax s o f t m a x. 由此,我们得到 Gumbel Max 的光滑近似版本——Gumbel Softmax
    重参数 (Reparameterization)
    重参数 (Reparameterization)
  • 跟连续情形一样,Gumbel Softmax 就是用在需要求E y ∼ p θ ( y ) [ f ( y ) ] \mathbb{E}{y\sim p{\theta}(y)}[f(y)]E y ∼p θ​(y )​[f (y )]、且无法直接完成对y y y 求和的场景,这时候我们算出p θ ( y ) p_θ(y)p θ​(y )(或者o i o_i o i ​),然后选定一个τ > 0 τ>0 τ>0,用 Gumbel Softmax 算出一个随机向量来y ~ \tilde y y ~​,代入计算得到f ( y ~ ) f(\tilde y)f (y ~​),它就是E y ∼ p θ ( y ) [ f ( y ) ] \mathbb{E}{y\sim p{\theta}(y)}[f(y)]E y ∼p θ​(y )​[f (y )] 的一个好的近似,且保留了梯度信息
  • 注意,Gumbel Softmax 不是类别采样的等价形式,Gumbel Max 才是。而 Gumbel Max 可以看成是 Gumbel Softmax 在τ → 0 τ→0 τ→0 时的极限。当τ ττ 比较小时,Gumbel Softmax 采样得到的样本接近 one-hot vector,也就比较接近实际的采样情况,但梯度的方差比较大;当τ ττ 比较大时,Gumbel Softmax 采样得到的样本比较平滑 (一个平滑的概率向量,向量的各个分量的值都差不多),但梯度的方差比较小。所以 *在应用 Gumbel Softmax 时,开始可以选择较大的τ ττ (比如 1),然后慢慢退火到一个接近于 0 的数(比如 0.01),这样才能得到比较好的结果

Gumbel Softmax v.s. Softmax

  • Gumbel Softmax 通过τ → 0 τ→0 τ→0 的退火来逐渐逼近 one hot,相比直接用原始的 Softmax 进行退火,区别在于 *原始 Softmax 退火只能得到最大值位置为 1 的 one hot 向量,而 Gumbel Softmax 有概率得到非最大值位置的 one hot 向量,增加了随机性,会使得基于采样的训练更充分一些

; Straight-Through Gumbel-Softmax Estimator

  • 由 Gumbel Softmax 得到的采样样本是实际采样样本的一个近似,它甚至都不在离散变量的取值范围之内,即使τ ττ 比较小,Gumbel Softmax 采样得到的样本也只是接近 one-hot vector,而非真正离散化的 one-hot vector. 但 总存在那么一些场景,我们只想采样离散值而非连续值 (e.g. RL 中从离散的动作空间中采样)
  • 假设 Gumbel Softmax 输出的采样向量为y y y,为了 利用 Gumbel Softmax 采样离散值,我们可以在前向传播时使用z = one_hot ( arg max ⁡ y ) z=\text{one_hot}(\argmax y)z =one_hot (a r g m a x y ) 得到离散的采样值,在反向传播时利用∇ θ z ≈ ∇ θ y \nabla_\theta z\approx \nabla_\theta y ∇θ​z ≈∇θ​y,对∇ θ y \nabla_\theta y ∇θ​y 进行梯度回传:
    z = y + s g ( one_hot ( arg max ⁡ y ) − y ) z=y+sg(\text{one_hot}(\argmax y)-y)z =y +s g (one_hot (a r g m a x y )−y )其中,s g sg s g 为 stop gradient 操作

背后的故事: 梯度估计 (gradient estimator)

  • 重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为 ” 梯度估计“的 大家族,而重参数只不过是这个大家族中的一员。每年的 ICLR、ICML 等顶会上搜索 gradient estimatorREINFORCE 等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。要想说清重参数的来龙去脉,也要说些梯度估计的故事

SF 估计 (Score Function Estimator)

  • 前面我们分别讲了连续型和离散型的重参数,都是在 “loss 层面” 讲述的,也就是说都是想办法把 loss 显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就 算不能显式地写出 loss 函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如 Score Function Estimator
    重参数 (Reparameterization)
  • 同时注意到, *重参数技巧要求f f f 可导,但是在诸如强化学习的场景下,f ( z ) f(z)f (z ) 对应着奖励函数,很难做到光滑可导,此时就必须使用 SF 估计

; 梯度方差

  • SF 估计看上去很美好,得到了一个连续和离散变量都适用的估计式,那 为什么还需要重参数呢?主要的原因是: SF 估计的方差太大。SF 估计是函数f ( z ) ∂ ∂ θ log ⁡ p θ ( z ) f(z) \frac{\partial}{\partial\theta} \log p_{\theta}(z)f (z )∂θ∂​lo g p θ​(z ) 在分布p θ ( z ) p_θ(z)p θ​(z ) 下的期望,我们要采样几个点来算 (理想情况下,希望只采样一个点),换句话说,我们想用下面的近似
    重参数 (Reparameterization)

降方差

  • 重参数就是一种降方差技巧,为此,我们写出重参数后的梯度表达式:
    重参数 (Reparameterization)

; References

Original: https://blog.csdn.net/weixin_42437114/article/details/125671285
Author: 连理o
Title: 重参数 (Reparameterization)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/615227/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球