【Pytorch基础】torch.nn.CrossEntropyLoss损失函数介绍

交叉熵主要是用来判定实际的输出与期望的输出的接近程度,为什么这么说呢,举个例子:在做分类的训练的时候,如果一个样本属于第K类,那么这个类别所对应的输出节点的输出值应该为1,而其他节点的输出都为0,即[0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最期望的输出结果。也就是说用它来衡量网络的输出与标签的差异,利用这种差异经过反向传播去更新网络参数。 参考文献【1】

3.1 举个栗子

交叉熵损失,是分类任务中最常用的一个损失函数。在Pytorch中是基于下面的公式实现的。
Loss ⁡ ( x ^ , x ) = − ∑ i = 1 n x log ⁡ ( x ^ ) \operatorname{Loss}(\hat{x}, x)=-\sum_{i=1}^{n} x \log (\hat{x})Loss (x ^,x )=−i =1 ∑n ​x lo g (x ^)
其中x x x是真实标签, x ^ \hat{x}x ^ 是预测的类分布(通常是使用softmax将模 型输出转换为概率分布)。
取单个样本举例, 假设x 1 = [ 0 , 1 , 0 ] x_1=[0, 1, 0]x 1 ​=[0 ,1 ,0 ], 模型预测样本x 1 x_1 x 1 ​的概率为x 1 ^ = [ 0.1 , 0.5 , 0.4 ] \hat{x_1}=[0.1, 0.5, 0.4]x 1 ​^​=0.1 ,0.5 ,0.4 。则样本的损失计算如下所示:

Loss ⁡ ( x 1 ^ , x 1 ) = − 0 × log ⁡ ( 0.1 ) − 1 × log ⁡ ( 0.5 ) − 0 × log ⁡ ( 0.4 ) = log ⁡ ( 0.5 ) \operatorname{Loss}(\hat{x_1}, x_1)=-0 \times \log (0.1)-1 \times \log (0.5)-0 \times \log (0.4)=\log (0.5)Loss (x 1 ​^​,x 1 ​)=−0 ×lo g (0.1 )−1 ×lo g (0.5 )−0 ×lo g (0.4 )=lo g (0.5 )

更详细的多分类交叉熵损失函数的例子可以参考文献【4】

3.2 Pytorch实现

实际使用中需要注意几点:

输入的形式大概如下所示

import torch
target = [1, 3, 2]

input_ = [[0.13, -0.18, 0.87],
         [0.25, -0.04, 0.32],
         [0.24, -0.54, 0.53]]

loss_item = torch.nn.CrossEntropyLoss()
loss = loss_item(input, target)

CrossEntropyLoss函数里面的实现,如下所示:

def forward(self, input, target):
    return F.cross_entropy(input, target, weight=self.weight,
                           ignore_index=self.ignore_index, reduction=self.reduction)

是调用的torch.nn.functional(俗称F)中的cross_entropy()函数。

此处需要区分一下:torch.nn.Module 和 torch.nn.functional(俗称F)中损失函数的区别。Module的损失函数例如CrossEntropyLoss、NLLLoss等是封装之后的损失函数类,是一个类,因此其中的变量可以自动维护。经常是对F中的函数的封装。而F中的损失函数只是单纯的函数。
下面看一下F.cross_entropy函数

3.3 F.cross_entropy

  • input:预测值,(batch,dim),这里dim就是要分类的总类别数
  • target:真实值,(batch),这里为啥是1维的?因为真实值并不是用one-hot形式表示,而是直接传类别id。
  • weight:指定权重,(dim),可选参数,可以给每个类指定一个权重。通常在训练数据中不同类别的样本数量差别较大时,可以使用权重来平衡。
  • ignore_index:指定忽略一个真实值,(int),也就是手动忽略一个真实值。
  • reduction:在[none, mean, sum]中选,string型。none表示不降维,返回和target相同形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。

其中参数weight、ignore_index、reduction要在实例化CrossEntropyLoss对象时指定,例如:

loss = torch.nn.CrossEntropyLoss(reduction='none')

F中的cross_entropy的实现

return nll_loss(log_softmax(input, dim=1), target, weight, None, ignore_index, None, reduction)

可以看到就是先调用log_softmax,再调用nll_loss。log_softmax就是先softmax再取log。

Original: https://blog.csdn.net/zfhsfdhdfajhsr/article/details/124689632
Author: 一穷二白到年薪百万
Title: 【Pytorch基础】torch.nn.CrossEntropyLoss损失函数介绍

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/627307/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球