OhemCELoss函数( Online hard example mining cross-entropy loss 的缩写)
分割任务中的OhemCELoss函数:其实就是分类任务的交叉熵函数—>每个像素计算分类交叉熵—->根据loss选取难样本,一步一步扩展得到。
在语义分割网络中常用的损失函数,这里大概记录几个需要留意的点:
1)计算交叉熵损失时,是以 一个像素点为计算单位,计算出 每个像素点的交叉熵分类损失。
2)ohem难样本挖掘时,根据给定的阈值选取前n_min个像素点的loss值。
使用pytorch框架OhemCELoss函数的代码实现
class OhemCELoss(nn.Module):
"""
Online hard example mining cross-entropy loss:在线难样本挖掘
if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
否则,计算前 n_min 个损失:loss = loss[:self.n_min]
"""
def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
super(OhemCELoss, self).__init__()
self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() # 将输入的概率 转换为loss值
self.n_min = n_min
self.ignore_lb = ignore_lb
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') #交叉熵
def forward(self, logits, labels):
N, C, H, W = logits.size()
loss = self.criteria(logits, labels).view(-1)
loss, _ = torch.sort(loss, descending=True) # 排序
if loss[self.n_min] > self.thresh: # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
loss = loss[loss>self.thresh]
else:
loss = loss[:self.n_min]
return torch.mean(loss)
详细的内容直接参见下面的博客就好,不必重复码字造车。
Original: https://blog.csdn.net/chen1234520nnn/article/details/122812038
Author: 超级无敌陈大佬的跟班
Title: 图像分割损失函数OhemCELoss
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710234/
转载文章受原作者版权保护。转载请注明原作者出处!