Deep Learning: Label Smoothing Regularization

1. What Label Smoothing Does: Preventing Overfitting

In multi-class classification, the cross-entropy loss is usually computed against one-hot labels. The plain cross-entropy loss only takes into account the loss at the position of the correct label and ignores the positions of the incorrect labels. As a result, the model may fit the training set very well, but because nothing constrains its behaviour at the wrong-label positions, it tends to make overconfident, incorrect predictions at test time, which is what we usually call overfitting.
Label smoothing can mitigate this overfitting to some extent.

2. The Standard Cross-Entropy Loss

Step 1: softmax over the classes
$$P_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$
where $P_i$ is the probability that the current sample belongs to class $i$, $z_i$ is the sample's logit for class $i$, and $n$ is the total number of classes.
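As a quick check of Step 1, the softmax can be evaluated directly in PyTorch (a minimal sketch; the logit values below are made up purely for illustration):

import torch

# made-up logits for one sample over n = 5 classes
z = torch.tensor([1.0, 2.0, 0.5, 3.0, 1.5])

# softmax turns the logits into a probability distribution over the classes
p = torch.softmax(z, dim=0)
print(p)        # ≈ tensor([0.0748, 0.2034, 0.0454, 0.5530, 0.1234])
print(p.sum())  # tensor(1.) - the probabilities sum to one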
Step 2: cross-entropy loss:
$$crossLoss = -\frac{1}{M}\sum_{m=1}^{M}\sum_{i=1}^{n} y_i \log p_i$$
where $M$ is the total number of samples.
Example:
Suppose a batch of samples with $n = 5$ classes, and one sample's one-hot label is $[0, 0, 0, 1, 0]$. Suppose the probability vector $p$ obtained by applying softmax to the model's logits (e.g., the output of a fully connected layer) is:
$$p = [0.1, 0.1, 0.1, 0.36, 0.34]$$
Substituting into the formula above (using the natural logarithm, as deep-learning frameworks do), the loss for this single sample is:
$$loss = -(0 \cdot \log 0.1 + 0 \cdot \log 0.1 + 0 \cdot \log 0.1 + 1 \cdot \log 0.36 + 0 \cdot \log 0.34) = -\log 0.36 \approx 1.02$$
This standard cross-entropy loss only considers the loss at the correct label's position and ignores the incorrect positions. Next, let's look at how the cross-entropy loss with label smoothing is computed.
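The single-sample loss above can be reproduced in PyTorch (a minimal sketch; the logits here are simply constructed as $\log p$ so that softmax recovers exactly the probability vector from the example):

import torch
import torch.nn.functional as F

p = torch.tensor([0.1, 0.1, 0.1, 0.36, 0.34])
y = torch.tensor([0., 0., 0., 1., 0.])   # one-hot label

# direct evaluation of -sum(y * log p)
print(-(y * torch.log(p)).sum())         # tensor(1.0217) = -log(0.36)

# the same value via F.cross_entropy: logits = log(p) softmax back to exactly p
logits = torch.log(p).unsqueeze(0)       # shape [1, 5]
target = torch.tensor([3])               # index of the correct class
print(F.cross_entropy(logits, target))   # tensor(1.0217)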

3. Cross-Entropy Loss with Label Smoothing

Using the same example: a batch of samples with $n = 5$ classes, one sample's one-hot label is $[0, 0, 0, 1, 0]$, and the probability vector $p$ obtained by applying softmax to the model's logits (e.g., the output of a fully connected layer) is:
$$p = [0.1, 0.1, 0.1, 0.36, 0.34]$$
Let the label-smoothing factor be $\epsilon = 0.1$. The smoothing is computed as follows:
$$y_1 = (1-\epsilon) \cdot [0, 0, 0, 1, 0] = [0, 0, 0, 0.9, 0]$$
$$y_2 = \epsilon \cdot [1, 1, 1, 1, 1] / 5 = [0.1, 0.1, 0.1, 0.1, 0.1] / 5 = [0.02, 0.02, 0.02, 0.02, 0.02]$$
$$y = y_1 + y_2 = [0.02, 0.02, 0.02, 0.92, 0.02]$$
$y$ is the new, smoothed label. The standard cross-entropy calculation is then applied to it, e.g. (again with the natural logarithm):
$$loss = -y \cdot \log p = -[0.02, 0.02, 0.02, 0.92, 0.02] \cdot \log([0.1, 0.1, 0.1, 0.36, 0.34]) \approx 1.10$$
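Under the same assumptions (natural logarithm, $\epsilon = 0.1$), the smoothed label and loss can be checked with a few lines of PyTorch (a minimal sketch):

import torch

eps = 0.1
n = 5
p = torch.tensor([0.1, 0.1, 0.1, 0.36, 0.34])
y_onehot = torch.tensor([0., 0., 0., 1., 0.])

# smooth the label: keep (1 - eps) on the true class and spread eps uniformly over all n classes
y_smooth = (1 - eps) * y_onehot + eps / n
print(y_smooth)                          # tensor([0.0200, 0.0200, 0.0200, 0.9200, 0.0200])

# cross entropy against the smoothed label
print(-(y_smooth * torch.log(p)).sum())  # ≈ tensor(1.0997), slightly larger than the 1.02 above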

4. Comparing Label Smoothing with the Standard Cross-Entropy Loss

As the example above shows, the loss with label smoothing is larger than the standard cross-entropy loss. In other words, to drive the smoothed loss down to the level of the standard loss, the model has to learn better, which pushes it toward the correct classes while discouraging overconfident predictions.

5. Where Label Smoothing Applies

Label smoothing can be applied wherever the cross-entropy loss is used.
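In fact, PyTorch 1.10 and later expose this directly through the label_smoothing argument of nn.CrossEntropyLoss (and F.cross_entropy), so in many cases no hand-written loss is needed; a minimal sketch:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(4, 5)            # a batch of 4 samples over 5 classes (random demo values)
targets = torch.tensor([3, 1, 0, 4])  # ground-truth class indices
print(criterion(logits, targets))     # label-smoothed cross-entropy loss (scalar)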

6. PyTorch Implementation and Usage

import torch
import torch.nn as nn
import torch.nn.functional as F

class CELossWithLabelSmoothing(nn.Module):
    ''' Cross entropy loss with optional label smoothing '''
    def __init__(self, label_smooth=0.1, class_num=3755):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num

    def forward(self, pred, target):
        '''
        Args:
            pred: raw logits from the model, shape [N, class_num]
            target: ground-truth class indices, shape [N]
        '''
        if self.label_smooth is not None:
            # log-probabilities for every class
            logprobs = F.log_softmax(pred, dim=1)
            # turn the class indices into one-hot vectors
            target = F.one_hot(target, self.class_num).float()
            # smooth the labels: each wrong class gets label_smooth / (class_num - 1),
            # the true class keeps 1 - label_smooth
            target = torch.clamp(target,
                                 min=self.label_smooth / (self.class_num - 1),
                                 max=1.0 - self.label_smooth)
            # per-sample cross entropy against the smoothed labels
            loss = -torch.sum(target * logprobs, dim=1)
        else:
            # standard cross entropy: log-sum-exp of the logits minus the logit of the true class
            loss = torch.logsumexp(pred, dim=1) - pred.gather(1, target.unsqueeze(-1)).squeeze(-1)

        return loss.mean()

if __name__ == '__main__':
    loss2 = CELossWithLabelSmoothing(label_smooth=0.2, class_num=3)
    x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
    y = torch.tensor([1, 2])
    print(loss2(x, y))
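Note that this implementation gives every wrong class $\epsilon / (K-1)$ and the true class $1 - \epsilon$, whereas Section 3 (and PyTorch's built-in nn.CrossEntropyLoss(label_smoothing=...)) spreads $\epsilon / K$ over all classes and leaves $1 - \epsilon + \epsilon / K$ on the true class. Both are common variants, so the two losses will differ slightly in value.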

Original: https://blog.csdn.net/qq_41915623/article/details/124852409
Author: 陈壮实的搬砖生活
Title: Deep Learning: Label Smoothing Regularization
