训练一个专门捣乱的模型

三位韩国人在EMNLP 2021 Findings上发表了一篇论文,名为Devil’s Advocate: Novel Boosting Ensemble Method from Psychological Findings for Text Classification,其中Devil’s Advocate有一部同名电影,翻译过来叫「魔鬼代言人」,他们主要挑战的是传统模型融合的方法,例如硬投票(Hard-Voting)、软投票(Soft Voting)、Bagging等。源码在HwiyeolJo/DevilsAdvocate

在群体决策过程中,大部分人会根据既定思维进行思考,而Devil’s Advocate是指那些提出的意见与大多数人不一致的那个人,Devil’s Advocate的存在可以激发群体的头脑风暴,打破固化思维。以上内容参考维基百科恶魔的代言人

Ensembles

在具体讲解作者的方法前,先简单过一下常见的模型融合方法

Soft Voting

软投票是对不同模型的预测分数进行加权平均,例如有一个三分类问题,第一个模型对某个样本的预测概率为[0.2,0.1,0.7];第二个模型对该样本的预测概率为[0.2,0.6,0.2];第三个模型对该样本的预测概率为[0.1,0.7,0.2],假设三个模型的投票权重均为1 3 \frac{1}{3}3 1 ​,则该样本最终的预测概率为
[ 0.5 3 , 1.4 3 , 1.1 3 ] [\frac{0.5}{3}, \frac{1.4}{3}, \frac{1.1}{3}][3 0 .5 ​,3 1 .4 ​,3 1 .1 ​]
所以最终这个样本被预测为第2类。不过事实上很多时候模型有好有坏,所以我们的权重不一定是平均的,对于模型比较厉害的模型,我们会给他比较大的话语权(投票权重)

Hard Voting

硬投票可以看作是软投票的一个变种,还是以上面三个模型预测的概率分布为例。第一个模型预测样本为第2类,第二、三个模型都认为样本是第2类,根据少数服从多数原则,该样本就被认为是第2类

Bagging

Bagging方法的核心思想是「民主」。首先从训练集中有放回地随机采样一些样本,采样n次,训练出n个弱模型,利用这n个模型采用投票的方式得到分类结果,如果是回归问题则是计算模型输出的均值作为最后的结果

Boosting

Boosting的核心思想是「挑选精英」。Boosting与Bagging最本质的区别在于它对弱模型不是一致对待的,而是经过不停的考验和筛选来挑出「精英」,然后给精英更多的投票权,表现不好的模型则给较少的投票权

Proposed Method: Devil’s Advocate

Training Norm and DevAdv models

无论你是用什么方法做模型融合,至少都需要2个以上的模型。作者提出的方法至少需要3个模型,这些模型会被分成两个阵营:Normal models (Norm n \text{Norm}_n Norm n ​, n ≥ 2 n\ge 2 n ≥2)、Devil’s Advocate model (DevAdv)

首先我们使用传统的Cross Entropy Loss训练Norm n \text{Norm}n Norm n ​模型
L Train-Norm n = CE ( Softmax ( Y Norm n ) , Y true ) (1) \mathcal{L}
{\text{Train-Norm}n} = \text{CE}(\text{Softmax}(\mathbf{Y}{\text{Norm}n}), \mathbf{Y}{\text{true}})\tag{1}L Train-Norm n ​​=CE (Softmax (Y Norm n ​​),Y true ​)(1 )
其中,Y Norm n \mathbf{Y}{\text{Norm}_n}Y Norm n ​​是Norm n \text{Norm}_n Norm n ​模型的预测值,Y true \mathbf{Y}{\text{true}}Y true ​是真实标签。与训练Norm n \text{Norm}n Norm n ​模型相反的是,我们需要随机生成与真实标签不相交的错误标签来训练DevAdv模型(不相交指的是没有任何一个样本的错误标签和真实标签相同),生成的错误标签为Y false \mathbf{Y}{\text{false}}Y false ​,DevAdv模型的损失函数定义如下:
L Train-DevAdv = CE ( Softmax ( Y DevAdv ) , Y false ) (2) \mathcal{L}{\text{Train-DevAdv}} = \text{CE}(\text{Softmax}(\mathbf{Y}{\text{DevAdv}}), \mathbf{Y}{\text{false}})\tag{2}L Train-DevAdv ​=CE (Softmax (Y DevAdv ​),Y false ​)(2 )
由于DevAdv模型是用错误标签训练出来的,所以该模型充当了「魔鬼代言人」的角色,不同意其他模型的预测分布。特别地,我们可以通过检查arg ⁡ min ⁡ ( Y DevAdv ) \arg \min (\mathbf{Y}
{\text{DevAdv}})ar g min (Y DevAdv ​)是否为真实标签来评估DevAdv模型的性能

注意上面的函数是arg ⁡ min ⁡ \arg \min ar g min,不是arg ⁡ max ⁡ \arg \max ar g max,因为求arg ⁡ max ⁡ \arg \max ar g max很明显预测结果大部分是Y false \mathbf{Y}_{\text{false}}Y false ​,但如果真实类别被预测的概率为最小,即通过arg ⁡ min ⁡ \arg \min ar g min取到,我们就认为DevAdv非常会捣乱

Group Discussion: Fine-tuning

我看到这个标题的时候感觉很奇怪,这又不是预训练模型,怎么会有Fine-tuning阶段?仔细看了他们的代码之后才明白,他这个名字起的不好,不应该叫Fine-tuning,应该叫为Discussing或者Ensembles,即模型融合阶段。具体来说,之前我们已经把所有的模型都训练一遍了,接下来我们需要把DevAdv引入进来再训练一遍Norm n \text{Norm}_n Norm n ​模型。特别地,当前阶段只会更新Norm n \text{Norm}_n Norm n ​模型的参数,DevAdv模型的参数不会进行更新

给我一种感觉就像是:”DevAdv,你已经学会如何抬杠了,快去干扰Norm n \text{Norm}_n Norm n ​他们的讨论吧”

对于Norm n \text{Norm}n Norm n ​模型来说,此时的损失函数比较特殊
L Discuss-Norm 1 = CE ( Softmax ( Y Norm 1 + Sofmtax ( Y DevAdv ) ) , Y true ) + MSE ( Y Norm 1 , Y Norm 2 ) L Discuss-Norm 2 = CE ( Softmax ( Y Norm 2 + Sofmtax ( Y DevAdv ) ) , Y true ) + MSE ( Y Norm 2 , Y Norm 1 ) (3) \begin{aligned} \mathcal{L}
{\text{Discuss-Norm}1} &= \text{CE}(\text{Softmax}(\mathbf{Y}{\text{Norm}1} + \text{Sofmtax}(\mathbf{Y}{\text{DevAdv}})), \mathbf{Y}{\text{true}})\ &+\text{MSE}(\mathbf{Y}{\text{Norm}1}, \mathbf{Y}{\text{Norm}2})\\ \mathcal{L}{\text{Discuss-Norm}2} &= \text{CE}(\text{Softmax}(\mathbf{Y}{\text{Norm}2} + \text{Sofmtax}(\mathbf{Y}{\text{DevAdv}})), \mathbf{Y}{\text{true}})\ &+\text{MSE}(\mathbf{Y}{\text{Norm}2}, \mathbf{Y}{\text{Norm}_1}) \end{aligned}\tag{3}L Discuss-Norm 1 ​​L Discuss-Norm 2 ​​​=CE (Softmax (Y Norm 1 ​​+Sofmtax (Y DevAdv ​)),Y true ​)+MSE (Y Norm 1 ​​,Y Norm 2 ​​)=CE (Softmax (Y Norm 2 ​​+Sofmtax (Y DevAdv ​)),Y true ​)+MSE (Y Norm 2 ​​,Y Norm 1 ​​)​(3 )
只有DevAdv模型的输出进行了归一化,Norm n \text{Norm}_n Norm n ​模型不进行归一化,目的是为了使得Norm n \text{Norm}_n Norm n ​预测的分布值远大于归一化的DevAdv的值。在CE Loss中,DevAdv模型阻止Norm n \text{Norm}_n Norm n ​模型对真实标签进行正确拟合。但在「Discuss」过程中,即使有DevAdv模型的干扰,Norm n \text{Norm}_n Norm n ​模型最终也能正确预测真实标签,主要有以下几个原因:

  1. DevAdv模型在该阶段是不更新参数的,因此它相当于变成了一个固定的噪声生成器。Norm n \text{Norm}_n Norm n ​模型的参数是会随着损失进行调整的,所以肯定效果会慢慢变好,这是可以预见的
  2. 非常特别的一点在于MSE \text{MSE}MSE损失。Norm n \text{Norm}_n Norm n ​模型在「Discuss」的过程中会互相影响、学习其他Norm models的信息

最后,对测试集进行测试时,采用软投票的机制组合Norm n \text{Norm}n Norm n ​模型的结果。然后…然后就结束了吗?我们辛辛苦苦训练的DevAdv仅仅只是在「Discuss」阶段提供点噪声吗?来点作用啊DevAdv。仔细想一想,最开始在训练DevAdv模型的时候,我们评估它的指标是
arg ⁡ min ⁡ ( Y DevAdv ) = = Y true (4) \arg \min (\mathbf{Y}
{\text{DevAdv}}) == \mathbf{Y}{\text{true}}\tag{4}ar g min (Y DevAdv ​)==Y true ​(4 )
我们将Y DevAdv \mathbf{Y}
{\text{DevAdv}}Y DevAdv ​内的值全部取负数,并将arg ⁡ min ⁡ \arg \min ar g min改为arg ⁡ max ⁡ \arg \max ar g max,它的结果仍然没变
arg ⁡ max ⁡ ( − Y DevAdv ) = = Y true (5) \arg \max (-\mathbf{Y}{\text{DevAdv}}) == \mathbf{Y}{\text{true}}\tag{5}ar g max (−Y DevAdv ​)==Y true ​(5 )
但此时我们就可以让DevAdv一同参与到Norm n \text{Norm}n Norm n ​模型的测试过程中了,其实就相当于三个模型共同进行软投票,此时预测结果为
arg ⁡ max ⁡ ( ∑ n Y Norm n − Y DevAdv ) (6) \arg \max (\sum
{n}\mathbf{Y}{\text{Norm}_n} – \mathbf{Y}{\text{DevAdv}})\tag{6}ar g max (n ∑​Y Norm n ​​−Y DevAdv ​)(6 )

Results

这本质是一个模型融合的方法,理论上来说所有模型都是适用的。首先我们看一下消融研究

训练一个专门捣乱的模型

其中比较重要的去掉了DiscussLoss的部分,何谓DiscussLoss,其实就是Norm n \text{Norm}_n Norm n ​模型相互讨论的阶段,即MSE \text{MSE}MSE损失。去掉这部分后,除了Yelp数据集有些反常居然上升了,其他的都有不同程度的下降。同时作者证明他们的方法可以使用超过3个模型的情况,例如最后一行,他们使用了4个模型,其中有3个正常模型,一个DevAdv,效果虽然不如使用3个模型的情况(第一行),但是比常规的软投票还是要好一些,特别地,此时他们使用KL散度来替代MSE损失

接着作者分别采用TextCNN和基于Transformers的模型(论文里就没写到底用的是BERT还是RoBERTa等,如果直接用transformer不太可能,seq2seq的模型如何做分类?)做了一组实验

训练一个专门捣乱的模型

基本上作者所提出的方法都要比软投票好一些,不过我特别好奇的是硬投票,以及其他的一些模型融合方法为什么不对比下呢?

; 个人总结

首先我要吐槽的是作者的美感,原论文中的数学公式写的非常丑,基本上感觉是直接用 \text{}框住然后乱写一通,这里截个图给大家感受下

训练一个专门捣乱的模型
其次是上图中我红框框出来的部分(下面没有用红框框住的公式也一样),我觉得它这个公式写错了,漏了个Softmax \text{Softmax}Softmax,可以对比我这篇文章里的公式和他论文中的公式。最后是我觉得比较有意思的地方,因为单看「Discuss」的过程,DevAdv在里面充当的只是捣乱的角色,那为什么我不可以直接采样一个服从N ( 0 , σ 2 ) \mathcal{N}(0, \sigma^2)N (0 ,σ2 )的向量分布呢,用这个分布直接替换掉Y DevAdv \mathbf{Y}_{\text{DevAdv}}Y DevAdv ​,直到我看到了公式(6),我才明白DevAdv不仅仅是充当一个噪声生成器,实际上在最后Inference阶段,它也可以一起参与进来,而这一点是单纯采样一个向量分布所无法做到的。作者在他的文章中并没有做鲁棒性测试,实际上我觉得引入Devil’s model误导模型训练的过程是可以增加模型的鲁棒性的

Original: https://blog.csdn.net/qq_37236745/article/details/121469762
Author: 数学家是我理想
Title: 训练一个专门捣乱的模型

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/532157/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • docker 对 容器的管理

    容器的操作命令 1,查看容器: docker ps 列出所有正在运行的容器 docker ps 2,启动已停止的容器: docker start 启动一个或多个已经被停止的容器 d…

    大数据 2023年5月29日
    072
  • 我眼中的大数据(三)——MapReduce

    这次来聊聊Hadoop中使用广泛的分布式计算方案——MapReduce。 MapReduce是一种编程模型,还是一个分布式计算框架。 MapReduce作为一种编程模型功能强大,使…

    大数据 2023年6月2日
    059
  • 【云原生】阿里云 RocketMQ介绍

    目录 ​​简介​​ ​​为什么选择消息队列RocketMQ版​​ ​​1、架构先进性​​ ​​2、高性能​​ ​​3、稳定性SLA​​ ​​4、弹性低成本​​ ​​5、运维可观测​…

    大数据 2023年5月24日
    0100
  • 云开发两周年庆 — 游戏畅玩领好礼

    五一5天小假期的结束大家休息好了吗?上班了状态回整的怎么样呢?阿里云云发平台给大家带福利了呢,通过玩游戏把奖品带回家。 云开发平台两周年,0门槛部署上线4款热门游戏,游戏畅玩还有A…

    大数据 2023年5月27日
    075
  • LinearLayout和RelativeLayout的区别

    博客园 :当前访问的博文已被密码保护 请输入阅读密码: Original: https://www.cnblogs.com/hustdc/p/11947302.htmlAuthor…

    大数据 2023年5月28日
    0103
  • 「GoCN酷Go推荐」高性能中文分词库 gojieba

    gojieba 是什么? gojieba 是 Python 知名分词库结巴 jieba 的 Go 语言实现版本,底层分词算法由 C++ 实现,具备很高的性能; gojieba 解决…

    大数据 2023年5月28日
    074
  • docker官方教程 getting-started

    从docker hub拉取镜像到本地,docker是用户名,后面的是镜像名称 $ docker pull docker/getting-started 显示本地安装镜像 $ doc…

    大数据 2023年5月28日
    078
  • Redis架构之哨兵机制与集群

    Redis架构之哨兵机制与集群 哨兵机制 1、介绍: Sentinel(哨兵)是redis高可用性解决方案:由一个或多个由一个或多个Sentinel 实例 组成的Sentinel …

    大数据 2023年6月2日
    078
  • harbor仓库部署

    harbor仓库部署 harbor仓库部署 部署harbor 部署服务端 部署客户端 查看效果 Harbor简介 Harbor是由VMWare在Docker Registry的基础…

    大数据 2023年5月27日
    084
  • 384. Shuffle an Array

    Shuffle an Array 原创 wx62ea2466cca9a2022-08-03 21:21:46博主文章分类:leetcode ©著作权 文章标签 leetcode-j…

    大数据 2023年5月24日
    070
  • Building good docker images

    The docker registry is bursting at the seams. At the time of this writing, a search for &#…

    大数据 2023年5月29日
    058
  • Kafka 数据丢失问题总结

    是否真正的存在数据丢失问题,比如有很多时候可能是其他同事操作了测试环境,所以首先确保数据没有第三方干扰。 理清你的业务流程,数据流向,数据到底是在什么地方丢失的数据,在kafka …

    大数据 2023年5月28日
    062
  • (1)通过FlinkSQL将数据写入mysql demo

    (1)通过FlinkSQL将数据写入mysql demo 原创 wx5d37d5fd4aa622022-08-13 00:33:03©著作权 文章标签 Flink 大数据 Flin…

    大数据 2023年5月24日
    084
  • 如何使用DBeaver连接Hive

    1 DBeaver介绍 DBeaver是一个通用的数据库管理工具和 SQL 客户端,支持多种兼容 JDBC 的数据库。DBeaver 提供一个图形界面用来查看数据库结构、执行SQL…

    大数据 2023年6月2日
    099
  • openwrt临时封禁ip

    用的openwrt路由器,家里宽带申请了动态公网ip,为了方便把22 80端口映射到公网,发现经常被暴力破解,自己写了个临时封禁ip功能的脚本,实现5分钟内同一个ip登录密码错误1…

    大数据 2023年5月27日
    077
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球