torch.nn.CrossEntropyLoss用法(原理, nlp, cv例子)

早上想花一个小时参照网上其他教程,修改模型结构,写一个手写识别数字的出来,结果卡在了这个上面,loss一直降不下来,然后我就去查看了一下CrossEntropyLoss的用法,毕竟分类问题一般都用这个。

原理层面

引入一个库:

import torch

假如是一个四分类任务,batch为2(只是为了显示简单,举个例子罢了)

logists = torch.randn(2, 4, requires_grad=True)
print(logists)

事实上,根据这个模型,第一个样本预测的类别是1,第二个样本预测的类别是2。

[En]

In fact, according to this model, the category predicted by the first sample is 1 and the category predicted by the second sample is 2.

这里我们假设模型足够好,都预测对了,那么其实target就是ground_truth。

target = logists.argmax(dim=-1)

定义损失函数:

crition = torch.nn.CrossEntropyLoss()

先来看个target_1d版的loss:

crition(logists, target)

再来看个target one-hot版的:
注意: 该版本在我的macbook python3.7.8, torch1.10.2的版本上没有问题, 但是在我的windows python3.7.6 torch1.9.1就出问题了!!! 因此稳妥起见还是直接用target比较好
先把target转为one

t_onehot = torch.nn.functional.one_hot(target, num_classes=4)

如何是one_hot, 要求target也是浮点类型的,所以t_onehot再调用float()转为浮点类型。

crition(logists, t_onehot.float())

最后发现两种方法其实算出来的loss都是0.5601

另外插一嘴,crossEntropyLoss也可以通过nll_loss实现(如果你去看torch.nn.crossEntropyLoss的源码就会发现官方就是使用torch.nn.functional.nll_loss实现的,只不过模型输出的logists值要先经过log_softmax

nlp

这里来个序列标注的例子。模型输出是 (batch_size, seq_len, hidden_dim)。
在这里,我将向您展示两种方法,这两种方法会更简洁,但不如方法2那么易读,具体取决于您的个人喜好。

[En]

Here I will show you two methods, which will be more concise, but not as readable as method 2, depending on your personal preference.


import torch

cel = torch.nn.CrossEntropyLoss()

batch_size, seq_len, hidden_dim = 4, 28, 128

x = torch.randn(batch_size, seq_len, hidden_dim)

gt = torch.ones(batch_size, seq_len).long()

print('method 1:')
print(cel(x.permute(0, 2, 1), gt))
print()

print('method 2:')
print(cel(x.view(-1, hidden_dim), gt.view(-1)))

可以看到两种方法一样。

图像分割例子,模型输出是 (batch_size, channel, height, width), 有多少个类别就有多少个channel, 通常医疗上的语义分割是2分类,因此输出channel为2。
在这里,我将向您展示两种方法,这两种方法会更简洁,但不如方法2那么易读,具体取决于您的个人喜好。

[En]

Here I will show you two methods, which will be more concise, but not as readable as method 2, depending on your personal preference.


import torch

cel = torch.nn.CrossEntropyLoss()

b, ic, oc, h, w = 4, 3, 2, 28, 28

x = torch.randn(b, oc, h, w)

gt = torch.ones(b, 1, h, w).long()

print('method 1:')
print(cel(x.view(b, oc, -1), gt.squeeze().view(b, -1)))
print()

print('method 2:')
print(cel(x.permute(0, 2, 3, 1).reshape(-1, oc), gt.squeeze().reshape(-1)))

Original: https://blog.csdn.net/weixin_43850253/article/details/122510794
Author: Andy Dennis
Title: torch.nn.CrossEntropyLoss用法(原理, nlp, cv例子)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/527495/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球