【论文笔记】(知识蒸馏)Distilling the Knowledge in a Neural Network

摘要

模型平均可以提高算法的性能,但是计算量大且麻烦,难以部署给用户。《模型压缩》这篇论文中表明,知识可以从复杂的大型模型或由多个模型构成的集成模型中压缩并转移到一个小型模型中,本文基于这一观点做出了进一步研究:通过 知识蒸馏(knowledge distillation)显著提高了转移后的小型模型的性能,此外还提出了一种新的集成模型,它由一个或多个完整模型再加多个specialist models(区别是:完整模型无法细粒度分类,specialist模型可以)组成。

名词解释

  • 教师模型 teacher model:一个单个的复杂大型模型 或 由多个模型组成的一个集成模型,知识从教师模型转出;
  • 学生模型 student model:一个小型的、简单的、易于部署的模型,知识转入学生模型;
  • transfer set:包含从教师模型中提取的知识,是学生模型的训练集;
  • transfer 阶段:知识从教师模型转移到学生模型的阶段,即学生模型的训练阶段;
  • soft target:教师模型输出的(每个类的)概率;
  • hard target:原始数据集自带的labels;
  • temperature:蒸馏的目标函数中的超参数,用于控制softmax函数的形状;

1 Introduction

首先训练一个复杂的大型模型/教师模型,蒸馏就是将这个大型模型学到的知识通过特殊的训练转移到小型模型/学生模型上。

作者认为,学生模型应该学习教师模型的泛化能力,而非数据拟合能力;如果教师模型的泛化能力强,学生模型经过学习训练后,就能够在测试集上表现很好。

将教师模型的泛化能力转移到学生模型的一个方法是:使用教师模型产生的(每个类的)概率作为训练学生模型的”soft targets”。在transfer阶段,可以使用相同的训练集或单独的 transfer set,transfer set可以全部由unlabeled的数据构成。当教师模型是集成模型时,可以使用各模型预测的分布的平均值作为soft targets。当soft targets的熵比hard targets高时,它们提供的信息比hard targets多,且训练梯度变化也会变小,所以小模型通常使用更少的数据进行训练,且使用更高的学习率。

对于MNIST这种简单的数据,大型模型的准确率很高,更多的知识是在非常小的软目标中。比如,一个数字2的图像被预测为: (10^{-6}) 的概率是数字3,(10^{-9}) 的概率是数字7;而对于另一个数字2的图像,预测结果可能相反。这些信息很有价值,它们定义了数据上丰富的相似性结构。但对与transfer阶段的交叉熵损失影响很小,因为概率都十分接近0,对此《模型压缩》使用 logits作为训练小型模型的目标来放大这些信息,而本文采用”蒸馏”,方法是提高 softmax 的 temperature 直到大型模型生成一组合适的soft targets,然后用这些soft targets来训练小型模型。

2 Distillation

神经网络通常使用 softmax 层来生成每个类的概率,softmax 层将 logit 层的输出值 (z_i) 转换为概率 (q_i):

[q_i = \frac{\exp(z_i/T)} {\sum_j \exp(z_j/T)} \tag{1} ]

其中 (T) 为 temperature,通常设置为 1,T 的值越高,生成的概率分布会越平滑(softer)。

transfer sets的两种形式:

  1. 全部由 soft targets 组成,也就是大型模型的 softmax 输出,使用与大型模型相同的 temperature 值;
  2. 由 soft targets 和 hard targets 组成,目标函数选择交叉熵。soft targets 的temperature 值与大型模型的一样,hard targets 的 temperature 值取 1。当使用 soft 和 hard targets,需要乘以(T^2)。

2.1 Matching logits is a special case of distillation

大型模型的logits输出为 (v_i),输入进softmax层生后计算的概率值为(p_i) (也是 soft targets);小型模型的logits值 (z_i),softmax层值为(q_i)(即公式1)且temperature的值都设置为(T),此时的交叉熵为(假设有N个训练数据):

[C = -\sum_{j=1}^{N} p_j \log q_j ]

transfer 阶段,对(z_i)的梯度为:

[\begin{aligned} \frac{\partial C}{\partial z_i}&=-\sum_{j=1}^{N} p_j \frac{\partial \log q_j}{\partial z_i}\ &=-\sum_{j=1}^{N} p_j \frac{\partial \log q_j}{\partial q_j}\frac{\partial q_j}{\partial z_i}\ &=-\sum_{j=1}^{N} p_j \frac{1}{q_j}\frac{\partial q_j}{\partial z_i} \ \end{aligned} ]

分情况考虑第三项,(q_j)的分母部分可以拆分成(c+e^{z_i/t}),其中(c)与(z_i)无关,相当于常数。那么,当(i=j)时,(q_j)可以写成(q_j = 1 – \frac{c}{c+e^{z_i/T}}):

[\begin{aligned} \frac{\partial q_j}{\partial z_i} &= (-c)(-1)\frac{\frac{1}{T}e^{z_i/T}}{(c+e^{z_i/T})^2}\ &=\frac{1}{T}\frac{e^{z_i/T}}{c+e^{z_i/T}}\frac{c}{c+e^{z_i/T}}\ &=\frac{1}{T}q_i(1-q_i)\ &=\frac{1}{T}q_j(1-q_j) \end{aligned} ]

当(i\neq j)时,(q_j)可以写成(q_j = \frac{e^{z_j/T}}{c+e^{z_i/T}}):

[\begin{aligned} \frac{\partial q_j}{\partial z_i} &= (-1)e^{z_j/T}\frac{\frac{1}{T}e^{z_i/T}}{(c+e^{z_i/T})^2}\ &=-\frac{1}{T}\frac{e^{z_j/T}}{c+e^{z_i/T}}\frac{e^{z_i/T}}{c+e^{z_i/T}}\ &=-\frac{1}{T}q_j q_i \end{aligned} ]

整理,可得:

[\begin{aligned} \frac{\partial C}{\partial z_i} &= -\frac{1}{T}\left ( p_j(1-q_j)-\sum_{j=1,j\neq i}^{N}p_j q_i \right )\ &=\frac{1}{T}\left ( p_j q_j +\sum_{j=1,j\neq i}^{N}p_j q_i -p_j \right )\ &=\frac{1}{T}\left ( \sum_{j=1}^{N}p_j q_i -p_j \right )\ &=\frac{1}{T}\left ( q_i – p_j \right )\ &=\frac{1}{T}\left ( \frac{e^{z_i/T}}{\sum_j e^{z_j/T}} – \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right ) \end{aligned} \tag{2} ]

如果 temperature 的值高于 logits,即 (T \gg z_i,v_i),则可以近似:

[\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left (\frac{1+z_i/T}{N+\sum_j z_j/T} – \frac{1+v_i/T}{N+\sum_j v_j/T}\right ) \tag{3} ]

假设 logits 的均值为0,即 (\sum_j z_j = \sum_j v_j = 0),则有:

[\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}(z_i-v_i) \tag{4} ]

因此,在temperature很高的极限下和每个transfer case都为零均值时,蒸馏等同于最小化 (1/2(z_i – v_i)^2),即MSE。

当 temperature较低时,蒸馏不怎么关注极小的负的logits,这既有优点又有缺点:好处是这些logits可能非常嘈杂,坏处是logits可能会传递一些有用的信息。所以temperature往往取中间值时效果好。

3&4 实验

作者在图像识别和语音识别两个领域进行实验,不过多描述。

5 Training ensembles of specialists on very big datasets

在这节中,针对JFT数据集,训练了多个能够细粒度分类的specialist models 和通用模型,将这些模型组合成集成模型。其中,specialist models很容易过拟合,作者还给出了如何防止过拟合的方法。

5.1 The JFT dataset

JFT 是一个谷歌数据集,包含 1 亿张带有 15,000 种标签的图像。

5.2 Specialist Models

各个specialist models 会在各个类的集合上进行训练,比如全是蘑菇但不同种类的蘑菇,将它们不关心的类整合为一个dustbin class,这样它们会给出很小的softmax值。

为了减少过拟合并分担通用模型的工作,每个specialist model都使用通用模型的权重进行初始化。然后通过训练specialist models来稍微修改这些权重,其中一半样本来自其特殊子集,一半来自训练集的其余部分随机抽样。训练后,可以通过将dustbin class的 logit 增加 log(specialist class 被采样的比例) 来纠正有偏差的训练集。

5.3 Assigning classes to specialists

为了为specialists派生类别的分组,作者重点关注于常混淆的类别。作者将聚类算法应用于通用模型预测的协方差矩阵,经常被一起预测的一组类 (S^m) 将用作一个specialist model 的target (m)。表 2 采用的是on-line version of the K-means。

【论文笔记】(知识蒸馏)Distilling the Knowledge in a Neural Network

表1. 由协方差矩阵聚类算法计算的聚类类

5.4 Performing inference with ensembles of specialists

首先检查这个新的集成模型效果如何,对于一个给定的图像(x),分两步进行分类:
第 1 步:对于每个测试样本,根据通用模型找到 (n) 个最可能的类,称这组类为 (k),在实验中作者使用 (n = 1)。
第 2 步:然后取所有的满足以下的specialists (m):其可混淆类的特殊子集 (S^m) 与 (k) 有一个非空交集,并将其称为specialists 的active set (A_k)(该集合可能为空)。 然后找所有类的完整概率分布 (\mathbf{q}),(\mathbf{q})最小化以下公式:

[KL(\mathbf{p}^g,\mathbf{q})+\sum_{m\in A_k} KL(\mathbf{p}^m,\mathbf{q}) ]

(\mathbf{p}^m,\mathbf{p}^g) 表示specialists或完整模型的概率分布,(\mathbf{p}^m) 是 (m) 的所有specialist类别加上单个dustbin class的分布。

6 Soft Targets as Regularizers

作者认为使用soft targets能够防止specialists过拟合,起到了正则化的作用。

Original: https://www.cnblogs.com/setdong/p/16392677.html
Author: 李斯赛特
Title: 【论文笔记】(知识蒸馏)Distilling the Knowledge in a Neural Network

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/578620/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • NGINX压力测试

    Nginx可以作为HTTP服务器和反向代理服务器。反向代理服务器取决于后端服务器的性能,这次只针对HTTP服务器做性能测试。Nginx作为服务器对于网络的性能必然是非常依赖的,尤其…

    Linux 2023年6月14日
    097
  • requests模块

    掌握 headers参数的使用 掌握 发送带参数的请求 掌握 headers中携带cookie 掌握 cookies参数的使用 掌握 cookieJar的转换方法 掌握 超时参数t…

    Linux 2023年6月8日
    0109
  • BLACKTOAD 的模板 未完

    博客园 :当前访问的博文已被密码保护 请输入阅读密码: Original: https://www.cnblogs.com/Grharris/p/10876375.htmlAuth…

    Linux 2023年6月6日
    086
  • Dockerfile 构建镜像

    从 Dockerfile 构建镜像涉及三个步骤 创建工作目录 编写 Dockerfile 规格 使用 docker build 命令构建镜像 1. 创建工作目录 这个根据应用实际情…

    Linux 2023年6月6日
    0107
  • 06-ElasticSearch搜索结果处理

    * package com.coolman.hotel.test; import com.coolman.hotel.pojo.HotelDoc; import com.faste…

    Linux 2023年6月7日
    0102
  • 16-ArrayList和LinkedList的区别

    1.1、作用 ArrayList和LinkedList都是实现了List接口的容器类,用于存储一系列的对象引用。它们可以对元素的增删改查进行操作 对于ArrayList,它在集合的…

    Linux 2023年6月7日
    086
  • Spring MVC处理日期字符串参数自动转换成后台Date类型

    当前台提交日期字符串到后台时,以字符串形式传输,若后台接收时采用Date类型,则会报格式转换错误的异常. 方式一: 将 @DateTimeFormat(pattern = &amp…

    Linux 2023年6月14日
    093
  • c++ 跨平台线程同步对象那些事儿——基于 ace

    前言 ACE (Adaptive Communication Environment) 是早年间很火的一个 c++ 开源通讯框架,当时 c++ 的库比较少,以至于谈 c++ 网络通…

    Linux 2023年6月6日
    094
  • CentOS导入CA证书

    把CA证书放到如下目录 /etc/pki/ca-trust/source/anchors 再命令行运行 /bin/update-ca-trust 如下所示的操作步骤 删除证书只需要…

    Linux 2023年6月6日
    094
  • thinkphp3.2.3 使用redis session存储

    为了解决session 共享问题,使用redis存储session会话信息 首先我们先研究一下 thinkphp 底层是怎么调用session的 ThinkPHP/Library/…

    Linux 2023年5月28日
    084
  • 匿名远程启动jenkins的job

    安装jenkins插件Build Authorization Token Root job配置中的构建触发器,勾选触发远程构建,输入要用的令牌,如soul 通过jenkins地址调…

    Linux 2023年6月6日
    0125
  • 函数式编程

    1 概述 2 Lambda表达式 3 Stream流 // 创建stream的方法 //1 使用Collection下的 stream() 和 parallelStream() 方…

    Linux 2023年6月7日
    089
  • shell笔记

    shell脚本学习笔记 1.Shell入门简介 Shell是操作系统的最外层, Shell可以合并编程语言以控制进程和文件,以及启动和控制其它程序。shell通过提示您输入,向操作…

    Linux 2023年6月7日
    067
  • Netty-如何写一个Http服务器

    前言 动机 最近在学习Netty框架,发现Netty是支持Http协议的。加上以前看过Spring-MVC的源码,就想着二者能不能结合一下,整一个简易的web框架(PS:其实不是整…

    Linux 2023年6月7日
    095
  • archLinux 配置用户

    archlinux 启动之后只有默认的root用户,首先介绍下系统启动到登录需要的步骤 1.系统通过systemd 以pid为1初始化系统,启动系统用户和系统必要的服务,(这一步目…

    Linux 2023年6月13日
    079
  • zabbix自定义监控mysql主从状态和延迟

    zabbix自定义监控mysql主从状态和延迟 zabbix自定义监控mysql主从状态和延迟 zabbix自定义监控mysql主从状态 zabbix自定义监控mysql主从延迟 …

    Linux 2023年6月13日
    0117
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球