PyTorch 中如何使用 torch.nn.CrossEntropyLoss?

推荐答案

在 PyTorch 中,torch.nn.CrossEntropyLoss 是一个常用的损失函数,通常用于多分类任务。它结合了 LogSoftmaxNLLLoss 的功能,因此在使用时不需要在模型的最后一层手动添加 Softmax 层。

使用步骤

  1. 定义模型:确保模型的输出是未经过 Softmax 处理的原始 logits。
  2. 定义损失函数:使用 torch.nn.CrossEntropyLoss() 创建损失函数。
  3. 计算损失:将模型的输出和真实标签传递给损失函数,计算损失值。

示例代码

-- -------------------- ---- -------
------ -----
------ -------- -- --

- ------------
----- -----------------------
    --- ---------------
        ------------------ ----------------
        ------- - ------------- --  - ----------------

    --- ------------- ---
        ------ ----------

- ------
----- - -------------

- ------
--------- - ---------------------

- -----------------
------ - -------------- ---  - ---------------
------ - ---------------- -- ---  - ---------

- ----
------- - -------------

- ----
---- - ------------------ -------

------------- ---------------

本题详细解读

1. torch.nn.CrossEntropyLoss 的作用

torch.nn.CrossEntropyLoss 是 PyTorch 中用于多分类任务的损失函数。它将 LogSoftmaxNLLLoss 结合在一起,因此在计算损失时不需要手动对模型的输出进行 Softmax 处理。

2. 输入要求

  • 模型的输出:模型的输出应该是未经过 Softmax 处理的原始 logits,形状为 (N, C),其中 N 是批量大小,C 是类别数。
  • 真实标签:真实标签的形状为 (N,),其中每个元素是类别的索引,范围在 [0, C-1] 之间。

3. 计算过程

CrossEntropyLoss 的计算过程如下:

  1. 对模型的输出进行 LogSoftmax 操作,得到每个类别的对数概率。
  2. 使用 NLLLoss 计算负对数似然损失,即根据真实标签选择对应的对数概率并取负值。

4. 注意事项

  • 不需要手动 Softmax:由于 CrossEntropyLoss 内部已经包含了 LogSoftmax,因此在模型的最后一层不需要手动添加 Softmax 层。
  • 标签格式:真实标签应该是类别的索引,而不是 one-hot 编码形式。

5. 示例代码解析

在示例代码中,我们定义了一个简单的线性模型 SimpleModel,它有一个全连接层,输出 5 个类别的 logits。然后我们使用 CrossEntropyLoss 计算模型输出与真实标签之间的损失。

  • inputs 是一个形状为 (3, 10) 的张量,表示 3 个样本,每个样本有 10 个特征。
  • labels 是一个形状为 (3,) 的张量,表示这 3 个样本的真实类别索引。
  • outputs 是模型的输出,形状为 (3, 5),表示每个样本属于 5 个类别的 logits。
  • loss 是计算得到的损失值。

通过这个示例,我们可以看到如何使用 torch.nn.CrossEntropyLoss 来计算多分类任务中的损失。

纠错
反馈