图像分割损失函数OhemCELoss

OhemCELoss函数( Online hard example mining cross-entropy loss 的缩写)

分割任务中的OhemCELoss函数:其实就是分类任务的交叉熵函数—>每个像素计算分类交叉熵—->根据loss选取难样本,一步一步扩展得到。

在语义分割网络中常用的损失函数,这里大概记录几个需要留意的点:

1)计算交叉熵损失时,是以 一个像素点为计算单位,计算出 每个像素点的交叉熵分类损失。

2)ohem难样本挖掘时,根据给定的阈值选取前n_min个像素点的loss值。

使用pytorch框架OhemCELoss函数的代码实现

class OhemCELoss(nn.Module):
"""
    Online hard example mining cross-entropy loss:在线难样本挖掘
    if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
    如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
    那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
    否则,计算前 n_min 个损失:loss = loss[:self.n_min]
"""
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()     # 将输入的概率 转换为loss值
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')   #交叉熵

    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)     # 排序
        if loss[self.n_min] > self.thresh:       # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

详细的内容直接参见下面的博客就好,不必重复码字造车。

Original: https://blog.csdn.net/chen1234520nnn/article/details/122812038
Author: 超级无敌陈大佬的跟班
Title: 图像分割损失函数OhemCELoss

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710234/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球