算法学习之gumbel softmax

1. gumbel_softmax有什么用呢?

假设如下场景:
模型训练过程中, 网络的输出为p = [0.1, 0.7, 0.2], 三个数值分别为”向左”, “向上”, “向右”的概率。 我们的决策可能是y = argmax§, 也即选择”向上”这条决策。
但是,这样做会有两个问题:

  1. argmax()函数是不可导的。这样网络就无法通过反向传播进行学习。
  2. argmax()的选择不具有随机性。同样的输出p选择100次,每次的结果都为”向上”。而按照概率为0.7的含义,100次应该有70次左右的决策结果是选择”向上”.

而gumbel_softmax的作用就是解决上述这两个子问题.。

2.argmax(x)是什么?为什么不可导?

为了更直观,这里使用两维的vector
y = argmax(x); x = (x1, x2)

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch

class Arrow3D(FancyArrowPatch):
    def __init__(self, xs, ys, zs, *args, **kwargs):
        FancyArrowPatchcyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs
    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
        self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))
        FancyArrowPatch.draw(self, renderer)

xs = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
ys = [0.5, 0.4, 0.3, 0.2 , 0.1, 0.0]
zs = [0, 0, 0, 0, 0, 0]
fig = plt.figure()
ax = axisartist.Subplot(fig, 111)
ax = fig.add_axes((0.1,0.1,0.8,0.8), projection='3d')
ax.plot3D(xs, ys, zs, c='red', marker='o')
ys = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
xs = [0.5, 0.4, 0.3, 0.2 , 0.1, 0.0]
zs = [1, 1, 1, 1, 1, 1]
ax.plot3D(xs, ys, zs, c='blue', marker='o')
plt.xlim(0, 1)
plt.ylim(0, 1)
ax.view_init(azim=30, elev=30)
plt.show()

算法学习之gumbel softmax
多元函数可微分的充分条件是函数连续且具有偏导数. 从argmax的三维图可以看出, argmax(x), 首先在x1 = x2处不连续,因此在该点处必定是不可导的. 在红线处, 保持x1不变, 求 y相对于x2的偏微分,发现是不存在的.因为x1不变的情况下,x2也是无法有一个微小的变动. 故, argmax()函数不可微分.

3. 引入随机性:gumbel分布

为了在y=argmax§中引入随机性, 将其修改为y = argmax(log§ + G).G称之为gumbel分布, 它的数学表达式为G=-log(-log(ξ \xi ξ)))。引入该分布的作用是引入了随机性,且该随机性保证了该分布输出i的概率等于pi。下面是科学空间上的证明,比较容易理解。

算法学习之gumbel softmax

; 4. 解决不可导:gumbel_softmax

  1. 解决不可导的方法可以用gumbel_softmax来处理。也即forward阶段,使用argmax操作,暂时不用管后面反向操作;但在反向阶段则使用gumbel_softmax来做bp计算,可以通过看pytorch中相关代码块有一个很清晰的认知。
def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:
    。。。
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.softmax(dim)

    if hard:

        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:

        ret = y_soft
    return ret
  1. 这样做是没有问题的,但是前向的y_hard,与y_softmax我们还是要尽可能缩小它们之间的”误差”,因此gumbel_softmax中引入了温度t, t越小,softmax就越接近One-hot。为了训练稳定性,一般t会取一个比较大的数字,然后逐步缩小。

Original: https://blog.csdn.net/u011345885/article/details/122610352
Author: 学弟
Title: 算法学习之gumbel softmax

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/613308/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • 疫情下,工业设计公司的机遇与挑战

    始料未及疫情,弄乱许多和规划, 工业设计公司都在所难免受影响,生产受到影响,营销推广高效率降低。能不能优良解决本次困境,是公司可持续发展的重要因素。 祸兮福之所倚,”危…

    人工智能 2023年6月29日
    079
  • Pytorch入门(三) 训练 / 测试模型

    上一篇文章中讲解了神经网络模型的编写,一般情况下,我们只需要对现有的网络模型进行修改就可以了,那这篇文章就进入到最重要的部分了,也就是网络模型的训练和测试。其实对于分类和回归的模型…

    人工智能 2023年7月14日
    089
  • Matlab回归分析

    线性回归:在实际中,对于情况较复杂的实际问题(因素不易化简,作用机理不详)可直接使用数据组建模,寻找简单的因果变量之间的数量关系, 从而对未知的情形作预报。这样组建的模型为拟合模型…

    人工智能 2023年6月16日
    079
  • 强化学习代码实战

    强化学习代码实战 注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV N…

    人工智能 2023年7月22日
    059
  • 定序回归模型

    定序回归的因变量是定序变量,数据类型是顺序数据。比如不满意,一般,满意;不合格,合格,优秀等。 假设因变量是评分,先由单变量回归说起,则普通的线性回归模型为:s c o r e =…

    人工智能 2023年6月17日
    0117
  • XGB(有监督学习)和多维时序模型结合——预测风电出力

    新能源风力发电机上保存有很多实时传感器的感应数据。 解决的问题: 1,想要通过传感器数据预测未来一段时间出力功率。2,单XGB等有监督的机器学习模型,根据输入感应器数据预测出力功率…

    人工智能 2023年6月23日
    0117
  • SPSS软件的数据分析与GDP和人口老龄化的预测

    目录前言问题二模型的建立与求解1.2.1 ARIMA时间序列模型的建立与求解平稳性检验的时间序列预测模型的建立与求解​​​​​​​模型的检验​​​​​​​对于的时间序列预测模型的建…

    人工智能 2023年7月15日
    0196
  • [机器学习]Logistic回归

    目录 什么是逻辑斯蒂(Logistic)回归? 1.线性回归函数 2. 逻辑函数(Sigmoid函数) 3. Logistic回归函数 Logistic回归分类器 梯度上升算法 p…

    人工智能 2023年6月18日
    093
  • 论文浅尝 | Seq2Seq 知识图谱补全与问答

    笔记整理:李行,天津大学硕士论文题目:Sequence-to-Sequence Knowledge Graph Completion and Question Answering链…

    人工智能 2023年6月1日
    084
  • 基于ResNet50的CIFAR10分类

    本次运用了 ResNet50进行了图像分类处理(基于Pytorch) 一、数据集 CIFAR-10数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张…

    人工智能 2023年6月25日
    067
  • 边缘计算 | 在移动设备上部署深度学习模型的思路与注意点 ⛵

    💡 作者:韩信子@ShowMeAI📘 深度学习◉技能提升系列:https://www.showmeai.tech/tutorials/35📘 深度学习实战系列:https://ww…

    人工智能 2023年6月4日
    0116
  • python机器学习实现对基于TCP协议的DDOS攻击的流量监测器

    文章目录 一、Wireshark抓包工具使用以及数据包分析 * 1.数据包筛选 2.数据包搜索 3.数据包分析 二、使用python库进行流量特征提取 * 1.下载scapy库 2…

    人工智能 2023年7月2日
    0118
  • nnunet详细预处理过程

    重采样 代码部分整理 import SimpleITK import numpy as np def get_target_spacing(spacings,sizes): ”’…

    人工智能 2023年6月21日
    079
  • 问EXCEL、Python、BI到底谁才是数据分析中的佼佼者?

    俗话说的好: 有人的地方就有鄙视圈,就像学C/C++的看不起学JAVA,学JAVA看不起学PHP,学PHP看不起学VBA的。在数据分析行业也存在着这样的鄙视链:学Python看不起…

    人工智能 2023年7月16日
    0126
  • JS/html5前端合成语音(播报)

    要在前端实现语音合成,即将文字讲述出来,一开始考虑用百度语音合成的方法,后来发现html5 本身就支持语音合成。就直接用html5的咯,百度的那个还有调用次数限制,配置还麻烦HTM…

    人工智能 2023年5月27日
    0101
  • 机器学习之回归与分类

    机器学习是? 在认识世界过程,类似于从一个已知量再到未知的函数。机器学习,就是预测这个函数,且使得预测结果尽量准确。收集一大堆数据,然后用训练数据集去预测一个值,称为回归问题。例如…

    人工智能 2023年6月18日
    078
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球