Code Implementation of the Focal Loss Function

  • The focal loss formula (the cross-entropy loss it builds on is expression (3) below; a short numeric illustration of the modulating factor follows this list):
    $$FL(p_{t}) = -(1-p_{t})^{\gamma}\log{p_{t}}\tag{1}$$
  • where:
    $$p_{t}=\begin{cases} p & \text{if } y = 1 \\ 1-p & \text{otherwise} \end{cases}\tag{1.1}$$
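
To make (1) concrete, here is a quick numeric illustration (the p_t values are made up for illustration) of how the modulating factor (1 - p_t)^γ suppresses easy, well-classified examples:

import torch

p_t = torch.tensor([0.9, 0.5, 0.1])   # an easy, a borderline, and a hard example
for gamma in (0, 2, 5):
    # gamma=0 recovers plain cross-entropy; larger gamma focuses on hard examples
    print(f"gamma={gamma}:", ((1 - p_t) ** gamma).tolist())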

  • BCE (binary classification):
    $$L = -\sum^{N}_{i=1}\left(y_{i}\log{\hat{y}_{i}} + (1-y_{i})\log{(1-\hat{y}_{i})}\right)\tag{2}$$

  • CE (multi-class; how it relates to BCE in the binary case is covered in the linked article above):
    $$L = -\sum^{N}_{i=1} y_{i}\log{\hat{y}_{i}}\tag{3}$$
  • For the concrete PyTorch implementation, see: [CrossEntropyLoss — PyTorch 1.12 documentation]
  • The expressions for softmax, log_softmax, and nllloss:
  • (nllloss is covered in detail in a separate write-up.)
    $$\sigma(z)_{j} = \frac{e^{z_{j}}}{\sum_{k=1}^{n}e^{z_{k}}}\tag{softmax}$$

$$logsoftmax = \ln{\sigma(z)_{j}}$$

$$nllloss = -\frac{1}{N}\sum_{k=1}^{N}y_{k}\,(logsoftmax)$$
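
These three compose into PyTorch's cross-entropy; here is a quick check with toy tensors that nll_loss applied to log_softmax reproduces F.cross_entropy:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)               # toy batch: 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 0])

manual = F.nll_loss(F.log_softmax(logits, dim=1), target)       # log_softmax + nllloss
print(torch.allclose(manual, F.cross_entropy(logits, target)))  # True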

  • A PyTorch implementation of focal loss (one I find fairly concise):
import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, weight=None, gamma=0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        # reduction='none' keeps the per-sample CE values, so the focal
        # modulation (1 - p_t)^gamma is applied per sample before the final
        # mean, matching equations (7)-(9) below.
        self.ce = nn.CrossEntropyLoss(weight=weight, reduction='none')

    def forward(self, input, target):
        logp = self.ce(input, target)        # -log(p_t) per sample, eq. (7)
        p = torch.exp(-logp)                 # p_t, eq. (8)
        loss = (1 - p) ** self.gamma * logp  # eq. (9) without the alpha term
        return loss.mean()                   # mini-batch mean
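
A quick sanity check of the class above (toy tensors, for verification only): with gamma=0 the modulating factor is 1, so the result should match plain CrossEntropyLoss.

torch.manual_seed(0)
logits = torch.randn(8, 5)                 # toy batch: 8 samples, 5 classes
target = torch.randint(0, 5, (8,))
print(torch.allclose(FocalLoss(gamma=0)(logits, target),
                     nn.CrossEntropyLoss()(logits, target)))  # True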

The principle behind this implementation is as follows:

PyTorch's cross-entropy loss expression, analogous to (3):
$$loss(x, class) = -\log{\frac{e^{x_{class}}}{\sum_{j}e^{x_{j}}}} = -x_{class} + \log{\sum_{j}e^{x_{j}}}\tag{3.1}$$
Combining it with the α-balancing weight:
$$loss(x, class) = \alpha_{class}\left(-x_{class} + \log{\sum_{j}e^{x_{j}}}\right)\tag{4}$$
The focal loss expression:
$$loss(x, class) = \left(1 - \frac{e^{x_{class}}}{\sum_{j}e^{x_{j}}}\right)^{\gamma}\left(-\log{\frac{e^{x_{class}}}{\sum_{j}e^{x_{j}}}}\right) = \left(1 - \frac{e^{x_{class}}}{\sum_{j}e^{x_{j}}}\right)^{\gamma}\left(-x_{class} + \log{\sum_{j}e^{x_{j}}}\right) = -(1-p_{t})^{\gamma}\log{p_{t}}\tag{5}$$
Focal loss with the α-balancing parameter:
$$loss(x, class) = -\alpha_{t}(1-p_{t})^{\gamma}\log{p_{t}}\tag{6}$$
To turn CrossEntropyLoss into focal loss, note that per sample (i.e. with reduction='none') CrossEntropyLoss returns exactly
$$-\log{p_{t}} = \mathrm{nn.CrossEntropyLoss}(input, target)\tag{7}$$
so that:
$$p_{t} = \mathrm{torch.exp}(-\mathrm{nn.CrossEntropyLoss}(input, target))\tag{8}$$
The final focal loss is therefore:
$$focal\ loss = -\alpha_{t}(1-p_{t})^{\gamma}\log{p_{t}}\tag{9}$$
Since training is mini-batch based, the last step takes the mean over the batch.
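
A minimal functional sketch that puts (7)-(9) together; treating alpha as a 1-D tensor of per-class weights indexed by the target is one common convention, not something the original spells out:

import torch
import torch.nn.functional as F

def focal_loss(input, target, alpha=None, gamma=2.0):
    logp = F.cross_entropy(input, target, reduction='none')  # eq. (7): per-sample -log(p_t)
    p = torch.exp(-logp)                                     # eq. (8): p_t
    loss = (1 - p) ** gamma * logp                           # eq. (9) without alpha
    if alpha is not None:                                    # alpha: per-class weight tensor
        loss = alpha[target] * loss                          # alpha_t for each sample's true class
    return loss.mean()                                       # mini-batch mean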

For implementations based on CE and BCE, the following code is a useful reference (it also partially answers how to tune γ and α):

Implementation based on binary cross-entropy


import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits  # True if inputs are raw scores rather than sigmoid probabilities
        self.reduce = reduce

    def forward(self, inputs, targets):
        # reduction='none' keeps the per-element BCE so the focal
        # weighting is applied element-wise.
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # p_t as in (1.1): p if y=1, otherwise 1-p
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
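
A minimal usage sketch with dummy tensors (alpha=0.25 and gamma=2 are the values the focal loss paper reports working best):

torch.manual_seed(0)
criterion = FocalLoss(alpha=0.25, gamma=2, logits=True)  # raw scores in; sigmoid applied internally
inputs = torch.randn(4, 1)                     # dummy logits
targets = torch.randint(0, 2, (4, 1)).float()  # dummy binary labels
print(criterion(inputs, targets))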

Other references

For the difference between binary_cross_entropy_with_logits and binary_cross_entropy (in short, the with_logits variant takes raw scores and applies the sigmoid internally, which is more numerically stable), see:

For some variants of the binary focal loss formula, see:

Implementing focal loss in pure PyTorch:

Aids for understanding the implementation:

Original: https://blog.csdn.net/Lian_Ge_Blog/article/details/126247720
Author: Lian_Ge_Blog
Title: 关于Focal loss损失函数的代码实现 (Code Implementation of the Focal Loss Function)
