推荐答案
在 PyTorch 中,torch.nn.CrossEntropyLoss
是一个常用的损失函数,通常用于多分类任务。它结合了 LogSoftmax
和 NLLLoss
的功能,因此在使用时不需要在模型的最后一层手动添加 Softmax
层。
使用步骤
- 定义模型:确保模型的输出是未经过
Softmax
处理的原始 logits。 - 定义损失函数:使用
torch.nn.CrossEntropyLoss()
创建损失函数。 - 计算损失:将模型的输出和真实标签传递给损失函数,计算损失值。
示例代码
-- -------------------- ---- ------- ------ ----- ------ -------- -- -- - ------------ ----- ----------------------- --- --------------- ------------------ ---------------- ------- - ------------- -- - ---------------- --- ------------- --- ------ ---------- - ------ ----- - ------------- - ------ --------- - --------------------- - ----------------- ------ - -------------- --- - --------------- ------ - ---------------- -- --- - --------- - ---- ------- - ------------- - ---- ---- - ------------------ ------- ------------- ---------------
本题详细解读
1. torch.nn.CrossEntropyLoss
的作用
torch.nn.CrossEntropyLoss
是 PyTorch 中用于多分类任务的损失函数。它将 LogSoftmax
和 NLLLoss
结合在一起,因此在计算损失时不需要手动对模型的输出进行 Softmax
处理。
2. 输入要求
- 模型的输出:模型的输出应该是未经过
Softmax
处理的原始 logits,形状为(N, C)
,其中N
是批量大小,C
是类别数。 - 真实标签:真实标签的形状为
(N,)
,其中每个元素是类别的索引,范围在[0, C-1]
之间。
3. 计算过程
CrossEntropyLoss
的计算过程如下:
- 对模型的输出进行
LogSoftmax
操作,得到每个类别的对数概率。 - 使用
NLLLoss
计算负对数似然损失,即根据真实标签选择对应的对数概率并取负值。
4. 注意事项
- 不需要手动
Softmax
:由于CrossEntropyLoss
内部已经包含了LogSoftmax
,因此在模型的最后一层不需要手动添加Softmax
层。 - 标签格式:真实标签应该是类别的索引,而不是 one-hot 编码形式。
5. 示例代码解析
在示例代码中,我们定义了一个简单的线性模型 SimpleModel
,它有一个全连接层,输出 5 个类别的 logits。然后我们使用 CrossEntropyLoss
计算模型输出与真实标签之间的损失。
inputs
是一个形状为(3, 10)
的张量,表示 3 个样本,每个样本有 10 个特征。labels
是一个形状为(3,)
的张量,表示这 3 个样本的真实类别索引。outputs
是模型的输出,形状为(3, 5)
,表示每个样本属于 5 个类别的 logits。loss
是计算得到的损失值。
通过这个示例,我们可以看到如何使用 torch.nn.CrossEntropyLoss
来计算多分类任务中的损失。