交叉熵、KL散度、负对数似然三个损失函数等价

本文最后更新于:2025年3月26日 下午

前言

正好在阅读CosyVoice源码时对修改模型前向传播时的损失函数产生了疑问,因此系统地整理一下这三种损失函数的等价关系。结论就是这三个损失函数在大部分情况下都是等价的,而交叉熵损失的计算最简单应用最广泛,在其他情况下魔改损失函数时则会用到KL散度和NLL等其他损失。

从实验开始

这里用LLM训练使用的常见的序列数据做实验。

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn
import torch.nn.functional as F
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

# 模型输出 logits 和目标索引
logits = torch.rand(3, 20 , 5) # batch_size=3,seq_len=20,vocab_size=5
targets = torch.randint(0, 5, (3, 20)) # 每个样本的目标类别
# 标签平滑参数
label_smoothing = 0
vocab_size = logits.size(-1)

首先定义一组测试数据,这里不考虑batch中的有效长度,即mask。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# **方法1: 使用 CrossEntropyLoss**
ce_criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
ce_loss = ce_criterion(logits.transpose(1,2), targets)

# **方法2: 使用 KLDivLoss 实现带标签平滑的损失**
log_probs = F.log_softmax(logits, dim=-1) # 转换为 log_softmax
# 构造目标分布(One-hot + Label Smoothing)
target_probs = torch.zeros_like(log_probs).fill_(label_smoothing / (vocab_size-1)) # 将所有零的位置填充上标签平滑后的值
target_probs = target_probs.scatter_(-1,targets.unsqueeze(-1),1 - label_smoothing)
# 使用 KLDivLoss 计算损失
kl_criterion = nn.KLDivLoss(reduction='none')
kl_loss = kl_criterion(log_probs, target_probs)
entropy = -(target_probs * torch.log(target_probs + 1e-12)).sum(dim=-1).mean() # 计算信息熵,在不使用标签平滑时需要有1e-12避免计算log(0)

# 方法3:使用cosyvoice实现的标签平滑损失
lsl_criterion = LabelSmoothingLoss(vocab_size, -1, label_smoothing,True)
lsl_loss = lsl_criterion(log_probs,targets)

# 方法4: 使用负对数似然损失
nll_criterion = nn.NLLLoss()
nll_loss = nll_criterion(log_probs.transpose(1, 2), targets)

依次使用四种方法计算损失,其中只有交叉熵损失函数的输入为logits,其他方法都需要输入log_softmax(logits)

1
2
3
4
5
6
7
8
9
10
11
12
print(f"CrossEntropyLoss: {ce_loss.item()}")
print(f"KLDivLoss: {kl_loss.sum().item()/(logits.size(0)*logits.size(1))}") # reduction='none'时要除以seq_len和batch数平均,等同于直接reduction='batchmean'
print(f"adjust-KL: ",kl_loss.sum().item()/(logits.size(0)*logits.size(1))+entropy.item() )
print(f"LSLoss:",lsl_loss.item())
print(f'NLL Loss: {nll_loss.item()}')

# output:
CrossEntropyLoss: 1.629240870475769
KLDivLoss: 1.6292406717936199
adjust-KL: 1.6292406717936199
LSLoss: 1.62924063205719
NLL Loss: 1.629240870475769

可以看到在没有标签平滑的时候,五个结果都可以认为是相同的。但当label_smoothing = 0.1时:

1
2
3
4
5
CrossEntropyLoss: 1.644450306892395
KLDivLoss: 1.1806944529215495
adjust-KL: 1.6444068769613902
LSLoss: 1.1806944608688354
NLL Loss: 1.64462411403656

交叉熵损失、adjust-KL、NLL 损失的结果相同,而KL散度损失和CosyVoice实现的标签平滑损失相同,而和其他结果不同了(额外说明CosyVoice的损失函数实现和这里的KL散度损失是相同的)

结论解释

KL散度又叫相对熵,相对熵=交叉熵-信息熵,则有交叉熵=相对熵+信息熵,对应adjust-KL:

1
kl_loss.sum().item()/(logits.size(0)*logits.size(1))+entropy.item()

也是因为多计算了一个相对熵并求和才得到与交叉熵损失函数相同的结果,而在非标签平滑的(独热分布)情况下,信息熵为0,也就是交叉熵=相对熵完全等价的关系。

负对数似然NLL Loss函数只能输入独热标签计算,因此在目标分布为独热分布时计算公式和交叉熵相同,在标签平滑的情况下,交叉熵损失结果有所有变化,但负对数似然损失函数无法应用标签平滑,损失结果不会变化。


交叉熵、KL散度、负对数似然三个损失函数等价
https://ash-one.github.io/2025/01/06/jiao-cha-shang-kl-san-du-fu-dui-shu-si-ran-san-ge-sun-shi-han-shu-deng-jie/
作者
灰一
发布于
2025年1月6日
许可协议