「解析」FocalLoss 解决数据不平衡问题

「解析」FocalLoss 解决数据不平衡问题

FocalLoss 的出现,主要是为了解决 anchor-based (one-stage) 目标检测网络的分类问题。后面实例分割也常使用。

注意
这里是 目标检测网络的分类问题,而不是单纯的分类问题,这两者是不一样的。
区别在于,对于分配问题,一个图片一定是属于某一确定的类的;而检测任务中的分类,是有大量的anchor无目标的(可以称为负样本)。

分类任务

正常的 K类分类任务 的标签,是用一个K长度的向量作为标签,用one-hot(或者+smooth,这里先不考虑)来进行编码,最终的标签是一个形如[1,…, 0, …, 0]这样的。那么如果想要将背景分离出,自然可以想到增加一个1维,如果目标检测任务有K类,这里只要用K+1维来表示分类,其中1维代表无目标即可。对于分类任务而言,最后一般使用 softmax 来归一,使得所有类别的输出加和为1。

「解析」FocalLoss 解决数据不平衡问题

但是在检测任务中,对于无目标的anchor,我们并不希望最终结果加和为1,而是所有的概率输出都是0。 那么可以这样,我们将一个多分类任务看做多个二分类任务(sigmoid),针对每一个类别,我输出一个概率,如果接近0则代表非该类别,如果接近1,则代表这个anchor是该类别。

所以网络输出不需要用softmax来归一,而是对K长度向量的每一个分量进行sigmoid激活,让其输出值代表二分类的概率。对于无目标的anchor,gt中所有的分量都是0,代表属于每一类的概率是0,即标注为背景。

至此,FocalLoss解决的问题不是多分类问题,而是 多个二分类问题

; 公式解析

首先看公式:只有 标签y = 1 y=1 y =1时,公式/交叉熵才有意义,p t p_t p t ​ 即为标签为1时对应的预测值/模型分类正确的概率
p t = ( 1 − p r e d _ s i g m o i d ) ∗ t a r g e t + p r e d _ s i g m o i d ∗ ( 1 − t a r g e t ) p_t = (1 – pred_sigmoid) * target + pred_sigmoid * (1 – target)p t ​=(1 −p re d _s i g m o i d )∗t a r g e t +p re d _s i g m o i d ∗(1 −t a r g e t )

C E ( p t ) = − α t log ⁡ ( p t ) F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) F L ( p ) = { − α ( 1 − p ) γ log ⁡ ( p ) , i f y = 1 − ( 1 − α ) p γ log ⁡ ( 1 − p ) , i f y = 0 CE(p_t)=-\alpha_t \log(p_t) \ \quad \ FL(p_t)=-\alpha_t(1-p_t)^\gamma \log(p_t) \ \quad \ FL(p) = \begin{cases} \quad -\alpha(1-p)^\gamma \log(p) &, if \quad y=1 &\ -(1-\alpha)p^\gamma \log(1-p)&,if \quad y=0 \end{cases}CE (p t ​)=−αt ​lo g (p t ​)F L (p t ​)=−αt ​(1 −p t ​)γlo g (p t ​)F L (p )={−α(1 −p )γlo g (p )−(1 −α)p γlo g (1 −p )​,i f y =1 ,i f y =0 ​

  1. 参数p[公式3]:当 p->0时(概率很低/很难区分是那个类别),调制因子 (1-p)接近1,损失不被影响,当 p->1时,(1-p)接近0,从而减小易分样本对总 loss的贡献
  2. 参数γ \gamma γ:当γ = 0 \gamma=0 γ=0 时,Focal loss就是传统的交叉熵,
    当γ \gamma γ 增加时, 调节系数( 1 − p t ) (1-p_t)(1 −p t ​) 也会增加。
    当γ \gamma γ 为定值时,比如γ = 2 \gamma=2 γ=2 1⃣️对于easy example(p>0.5) p=0.9 的loss要比标准的交叉熵小 100倍,当 p=0.968时,要小1000+倍;2⃣️对于 hard example(p
  3. α \alpha α 调节正负样本不平衡系数,γ \gamma γ 控制难易样本不平衡

代码复现

在官方给的代码中,并没有 target = F.one_hot(target, num_clas) 这行代码,这是因为


import torch
from torch.nn import functional as F

def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1,
                        gamma: float = 2, reduction: str = "none") -> torch.Tensor:

    inputs  = inputs.float()
    targets = targets.float()
    p       = torch.sigmoid(inputs)
    target  = F.one_hot(target, num_clas+1)

    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t     = p * targets + (1 - p) * (1 - targets)
    loss    = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss

sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss)

此外,torchvision 中也支持 focal loss

完整代码

官方完整代码:https://github.com/facebookresearch/

参考

  • https://zhuanlan.zhihu.com/p/391186824

Original: https://blog.csdn.net/ViatorSun/article/details/124861342
Author: ViatorSun
Title: 「解析」FocalLoss 解决数据不平衡问题

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/681957/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球