交叉熵主要是用来判定实际的输出与期望的输出的接近程度,为什么这么说呢,举个例子:在做分类的训练的时候,如果一个样本属于第K类,那么这个类别所对应的输出节点的输出值应该为1,而其他节点的输出都为0,即[0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最期望的输出结果。也就是说用它来衡量网络的输出与标签的差异,利用这种差异经过反向传播去更新网络参数。 参考文献【1】
3.1 举个栗子
交叉熵损失,是分类任务中最常用的一个损失函数。在Pytorch中是基于下面的公式实现的。
Loss ( x ^ , x ) = − ∑ i = 1 n x log ( x ^ ) \operatorname{Loss}(\hat{x}, x)=-\sum_{i=1}^{n} x \log (\hat{x})Loss (x ^,x )=−i =1 ∑n x lo g (x ^)
其中x x x是真实标签, x ^ \hat{x}x ^ 是预测的类分布(通常是使用softmax将模 型输出转换为概率分布)。
取单个样本举例, 假设x 1 = [ 0 , 1 , 0 ] x_1=[0, 1, 0]x 1 =[0 ,1 ,0 ], 模型预测样本x 1 x_1 x 1 的概率为x 1 ^ = [ 0.1 , 0.5 , 0.4 ] \hat{x_1}=[0.1, 0.5, 0.4]x 1 ^=0.1 ,0.5 ,0.4 。则样本的损失计算如下所示:
Loss ( x 1 ^ , x 1 ) = − 0 × log ( 0.1 ) − 1 × log ( 0.5 ) − 0 × log ( 0.4 ) = log ( 0.5 ) \operatorname{Loss}(\hat{x_1}, x_1)=-0 \times \log (0.1)-1 \times \log (0.5)-0 \times \log (0.4)=\log (0.5)Loss (x 1 ^,x 1 )=−0 ×lo g (0.1 )−1 ×lo g (0.5 )−0 ×lo g (0.4 )=lo g (0.5 )
更详细的多分类交叉熵损失函数的例子可以参考文献【4】
3.2 Pytorch实现
实际使用中需要注意几点:
输入的形式大概如下所示:
import torch
target = [1, 3, 2]
input_ = [[0.13, -0.18, 0.87],
[0.25, -0.04, 0.32],
[0.24, -0.54, 0.53]]
loss_item = torch.nn.CrossEntropyLoss()
loss = loss_item(input, target)
CrossEntropyLoss函数里面的实现,如下所示:
def forward(self, input, target):
return F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction)
是调用的torch.nn.functional(俗称F)中的cross_entropy()函数。
此处需要区分一下:torch.nn.Module 和 torch.nn.functional(俗称F)中损失函数的区别。Module的损失函数例如CrossEntropyLoss、NLLLoss等是封装之后的损失函数类,是一个类,因此其中的变量可以自动维护。经常是对F中的函数的封装。而F中的损失函数只是单纯的函数。
下面看一下F.cross_entropy函数
3.3 F.cross_entropy
- input:预测值,(batch,dim),这里dim就是要分类的总类别数
- target:真实值,(batch),这里为啥是1维的?因为真实值并不是用one-hot形式表示,而是直接传类别id。
- weight:指定权重,(dim),可选参数,可以给每个类指定一个权重。通常在训练数据中不同类别的样本数量差别较大时,可以使用权重来平衡。
- ignore_index:指定忽略一个真实值,(int),也就是手动忽略一个真实值。
- reduction:在[none, mean, sum]中选,string型。none表示不降维,返回和target相同形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
其中参数weight、ignore_index、reduction要在实例化CrossEntropyLoss对象时指定,例如:
loss = torch.nn.CrossEntropyLoss(reduction='none')
F中的cross_entropy的实现
return nll_loss(log_softmax(input, dim=1), target, weight, None, ignore_index, None, reduction)
可以看到就是先调用log_softmax,再调用nll_loss。log_softmax就是先softmax再取log。
Original: https://blog.csdn.net/zfhsfdhdfajhsr/article/details/124689632
Author: 一穷二白到年薪百万
Title: 【Pytorch基础】torch.nn.CrossEntropyLoss损失函数介绍
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/627307/
转载文章受原作者版权保护。转载请注明原作者出处!