Contents
- 基本概念
- 连续情形
- 离散情形
* - Gumbel Max
- Gumbel Softmax
- Straight-Through Gumbel-Softmax Estimator
- 背后的故事: 梯度估计 (gradient estimator)
* - SF 估计 (Score Function Estimator)
- 梯度方差
- 降方差
- References
基本概念
- 重参数 (Reparameterization) 实际上是处理如下期望形式的目标函数的一种技巧:
- 重参数假设从分布p θ ( z ) p_θ(z)p θ(z ) 中采样可以分解为 两个步骤:(1) 从无参数分布q ( ε ) q(ε)q (ε) 中采样一个ε εε;(2) 通过变换z = g θ ( ε ) z=g_θ(ε)z =g θ(ε) 生成z z z。那么,上述期望就变成了
; 连续情形
- 简单起见,我们先考虑z z z 为连续随机变量的情形:
- 总的来说,连续情形的重参数还是比较简单的。从数学本质来看,重参数是一种 积分变换,即原来是关于z z z 积分,通过z = g θ ( ε ) z=g_θ(ε)z =g θ(ε) 变换之后得到新的积分形式。一个最简单的例子就是 正态分布:对于正态分布来说,重参数就是 “从N ( z ; μ θ , σ θ 2 ) N(z;μ_θ,σ^2_θ)N (z ;μθ,σθ2 ) 中采样一个z z z” 变成 “从N ( ε ; 0 , 1 ) N(ε;0,1)N (ε;0 ,1 ) 中采样一个ε εε,然后计算ε × σ θ + μ θ ε×σ_θ+μ_θε×σθ+μθ”,所以
离散情形
- 为了突出 “离散”,我们将随机变量z z z 换成y y y,即对于离散情形要面对的目标函数是
- 看到上述期望项中的求和,第一反应可能是 “求和?那就求呗,又不是求不了”。的确, 对于离散的随机变量,其期望只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。但是,如果k k k 特别大呢?举个例子,假设y y y 是一个 100 维的向量,每个元素不是 0 就是 1,那么所有不同的y y y 的总数目就是2 100 2^{100}2 1 0 0,要对这样的2 100 2^{100}2 1 0 0 个单项进行求和, 计算量是难以接受的 (每一项都需要计算前向传播过程f ( y ) f(y)f (y ))。所以, 还是需要回到采样上去,如果能够采样若干个点就能得到期望的有效估计,并且还不损失梯度信息,那自然是最好了
; Gumbel Max
- 为此,需要先引入 Gumbel Max。假设每个类别的概率是p 1 , p 2 , … , p k p_1,p_2,…,p_k p 1 ,p 2 ,…,p k ,那么 Gumbel Max 提供了一种 依概率采样类别的方案:
- 可以证明,这样的过程精确等价于依概率p 1 , p 2 , … , p k p_1,p_2,…,p_k p 1 ,p 2 ,…,p k 采样一个类别,换句话说, 在 Gumbel Max 中,输出i i i 的概率正好是p i p_i p i . 不失一般性,这里我们证明输出 1 的概率是p 1 p_1 p 1 . 注意,输出 1 意味着log p 1 − l o g ( − l o g ε 1 ) \log p_1−log(−logε_1)lo g p 1 −l o g (−l o g ε1 ) 是最大的,这又意味着:
log p 1 − log ( − log ε 1 ) > log p 2 − log ( − log ε 2 ) log p 1 − log ( − log ε 1 ) > log p 3 − log ( − log ε 3 ) ⋮ log p 1 − log ( − log ε 1 ) > log p k − log ( − log ε k ) \begin{aligned} &\log p_1 – \log(-\log \varepsilon_1) > \log p_2 – \log(-\log \varepsilon_2) \ &\log p_1 – \log(-\log \varepsilon_1) > \log p_3 – \log(-\log \varepsilon_3) \ &\qquad \vdots\ &\log p_1 – \log(-\log \varepsilon_1) > \log p_k – \log(-\log \varepsilon_k) \end{aligned}lo g p 1 −lo g (−lo g ε1 )>lo g p 2 −lo g (−lo g ε2 )lo g p 1 −lo g (−lo g ε1 )>lo g p 3 −lo g (−lo g ε3 )⋮lo g p 1 −lo g (−lo g ε1 )>lo g p k −lo g (−lo g εk )不失一般性,我们只分析第一个不等式,化简后得到:
ε 2 < ε 1 p 2 / p 1 ≤ 1 \varepsilon_2 < \varepsilon_1^{p_2 / p_1}\leq 1 ε2 <ε1 p 2 /p 1 ≤1由于ε 2 ∼ U [ 0 , 1 ] ε_2∼U[0,1]ε2 ∼U [0 ,1 ],所以ε 2 < ε 1 p 2 / p 1 ε_2 的概率就是 ε 1 p 2 / p 1 ε^{p_2/p_1}_1 ε1 p 2 /p 1 ,这就是固定 ε 1 ε_1 ε1 的情况下,第一个不等式成立的概率。那么,所有不等式同时成立的概率是
ε 1 p 2 / p 1 ε 1 p 3 / p 1 … ε 1 p k / p 1 = ε 1 ( p 2 + p 3 + ⋯ + p k ) / p 1 = ε 1 ( 1 / p 1 ) − 1 \varepsilon_1^{p_2 / p_1}\varepsilon_1^{p_3 / p_1}\dots \varepsilon_1^{p_k / p_1}=\varepsilon_1^{(p_2 + p_3 + \dots + p_k) / p_1}=\varepsilon_1^{(1/p_1)-1}ε1 p 2 /p 1 ε1 p 3 /p 1 …ε1 p k /p 1 =ε1 (p 2 +p 3 +⋯+p k )/p 1 =ε1 (1 /p 1 )−1 然后对所有 ε 1 ε_1 ε1 求平均,就是
∫ 0 1 ε 1 ( 1 / p 1 ) − 1 d ε 1 = p 1 \int_0^1 \varepsilon_1^{(1/p_1)-1}d\varepsilon_1 = p_1 ∫0 1 ε1 (1 /p 1 )−1 d ε1 =p 1
Gumbel Softmax
- 我们希望重参数不丢失梯度信息,但是 Gumbel Max 做不到,因为arg max \argmax a r g m a x 不可导,为此,需要做进一步的近似。首先,留意到在神经网络中,处理离散输入的基本方法是转化为 one hot 形式,包括 Embedding 层的本质也是 one hot 全连接,因此arg max \argmax a r g m a x 实际上是one_hot ( arg max ) \text{one_hot}(\argmax)one_hot (a r g m a x ),然后,我们寻求one_hot ( arg max ) \text{one_hot}(\argmax)one_hot (a r g m a x ) 的光滑近似,它就是s o f t m a x softmax s o f t m a x. 由此,我们得到 Gumbel Max 的光滑近似版本——Gumbel Softmax:
- 跟连续情形一样,Gumbel Softmax 就是用在需要求E y ∼ p θ ( y ) [ f ( y ) ] \mathbb{E}{y\sim p{\theta}(y)}[f(y)]E y ∼p θ(y )[f (y )]、且无法直接完成对y y y 求和的场景,这时候我们算出p θ ( y ) p_θ(y)p θ(y )(或者o i o_i o i ),然后选定一个τ > 0 τ>0 τ>0,用 Gumbel Softmax 算出一个随机向量来y ~ \tilde y y ~,代入计算得到f ( y ~ ) f(\tilde y)f (y ~),它就是E y ∼ p θ ( y ) [ f ( y ) ] \mathbb{E}{y\sim p{\theta}(y)}[f(y)]E y ∼p θ(y )[f (y )] 的一个好的近似,且保留了梯度信息
- 注意,Gumbel Softmax 不是类别采样的等价形式,Gumbel Max 才是。而 Gumbel Max 可以看成是 Gumbel Softmax 在τ → 0 τ→0 τ→0 时的极限。当τ ττ 比较小时,Gumbel Softmax 采样得到的样本接近 one-hot vector,也就比较接近实际的采样情况,但梯度的方差比较大;当τ ττ 比较大时,Gumbel Softmax 采样得到的样本比较平滑 (一个平滑的概率向量,向量的各个分量的值都差不多),但梯度的方差比较小。所以 *在应用 Gumbel Softmax 时,开始可以选择较大的τ ττ (比如 1),然后慢慢退火到一个接近于 0 的数(比如 0.01),这样才能得到比较好的结果
Gumbel Softmax v.s. Softmax
- Gumbel Softmax 通过τ → 0 τ→0 τ→0 的退火来逐渐逼近 one hot,相比直接用原始的 Softmax 进行退火,区别在于 *原始 Softmax 退火只能得到最大值位置为 1 的 one hot 向量,而 Gumbel Softmax 有概率得到非最大值位置的 one hot 向量,增加了随机性,会使得基于采样的训练更充分一些
; Straight-Through Gumbel-Softmax Estimator
- 由 Gumbel Softmax 得到的采样样本是实际采样样本的一个近似,它甚至都不在离散变量的取值范围之内,即使τ ττ 比较小,Gumbel Softmax 采样得到的样本也只是接近 one-hot vector,而非真正离散化的 one-hot vector. 但 总存在那么一些场景,我们只想采样离散值而非连续值 (e.g. RL 中从离散的动作空间中采样)
- 假设 Gumbel Softmax 输出的采样向量为y y y,为了 利用 Gumbel Softmax 采样离散值,我们可以在前向传播时使用z = one_hot ( arg max y ) z=\text{one_hot}(\argmax y)z =one_hot (a r g m a x y ) 得到离散的采样值,在反向传播时利用∇ θ z ≈ ∇ θ y \nabla_\theta z\approx \nabla_\theta y ∇θz ≈∇θy,对∇ θ y \nabla_\theta y ∇θy 进行梯度回传:
z = y + s g ( one_hot ( arg max y ) − y ) z=y+sg(\text{one_hot}(\argmax y)-y)z =y +s g (one_hot (a r g m a x y )−y )其中,s g sg s g 为 stop gradient 操作
背后的故事: 梯度估计 (gradient estimator)
- 重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为 ” 梯度估计“的 大家族,而重参数只不过是这个大家族中的一员。每年的 ICLR、ICML 等顶会上搜索 gradient estimator、 REINFORCE 等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。要想说清重参数的来龙去脉,也要说些梯度估计的故事
SF 估计 (Score Function Estimator)
- 前面我们分别讲了连续型和离散型的重参数,都是在 “loss 层面” 讲述的,也就是说都是想办法把 loss 显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就 算不能显式地写出 loss 函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如 Score Function Estimator:
- 同时注意到, *重参数技巧要求f f f 可导,但是在诸如强化学习的场景下,f ( z ) f(z)f (z ) 对应着奖励函数,很难做到光滑可导,此时就必须使用 SF 估计
; 梯度方差
- SF 估计看上去很美好,得到了一个连续和离散变量都适用的估计式,那 为什么还需要重参数呢?主要的原因是: SF 估计的方差太大。SF 估计是函数f ( z ) ∂ ∂ θ log p θ ( z ) f(z) \frac{\partial}{\partial\theta} \log p_{\theta}(z)f (z )∂θ∂lo g p θ(z ) 在分布p θ ( z ) p_θ(z)p θ(z ) 下的期望,我们要采样几个点来算 (理想情况下,希望只采样一个点),换句话说,我们想用下面的近似
降方差
- 重参数就是一种降方差技巧,为此,我们写出重参数后的梯度表达式:
; References
- 苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到 Gumbel Softmax 》[Blog post]. Retrieved from https://kexue.fm/archives/6705
- Gumbel Softmax paper:Jang, Eric, et al. “Categorical Reparameterization with Gumbel-Softmax.” 5th International Conference on Learning Representations, ICLR 2017
Original: https://blog.csdn.net/weixin_42437114/article/details/125671285
Author: 连理o
Title: 重参数 (Reparameterization)
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/615227/
转载文章受原作者版权保护。转载请注明原作者出处!