本文最后更新于:2025年3月26日 下午
前言
正好在阅读CosyVoice源码时对修改模型前向传播时的损失函数产生了疑问,因此系统地整理一下这三种损失函数的等价关系。结论就是这三个损失函数在大部分情况下都是等价的,而交叉熵损失的计算最简单应用最广泛,在其他情况下魔改损失函数时则会用到KL散度和NLL等其他损失。
从实验开始
这里用LLM训练使用的常见的序列数据做实验。
1 2 3 4 5 6 7 8 9 10 11
| import torch import torch.nn as nn import torch.nn.functional as F from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
logits = torch.rand(3, 20 , 5) targets = torch.randint(0, 5, (3, 20))
label_smoothing = 0 vocab_size = logits.size(-1)
|
首先定义一组测试数据,这里不考虑batch中的有效长度,即mask。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| ce_criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing) ce_loss = ce_criterion(logits.transpose(1,2), targets)
log_probs = F.log_softmax(logits, dim=-1)
target_probs = torch.zeros_like(log_probs).fill_(label_smoothing / (vocab_size-1)) target_probs = target_probs.scatter_(-1,targets.unsqueeze(-1),1 - label_smoothing)
kl_criterion = nn.KLDivLoss(reduction='none') kl_loss = kl_criterion(log_probs, target_probs) entropy = -(target_probs * torch.log(target_probs + 1e-12)).sum(dim=-1).mean()
lsl_criterion = LabelSmoothingLoss(vocab_size, -1, label_smoothing,True) lsl_loss = lsl_criterion(log_probs,targets)
nll_criterion = nn.NLLLoss() nll_loss = nll_criterion(log_probs.transpose(1, 2), targets)
|
依次使用四种方法计算损失,其中只有交叉熵损失函数的输入为logits,其他方法都需要输入log_softmax(logits)
1 2 3 4 5 6 7 8 9 10 11 12
| print(f"CrossEntropyLoss: {ce_loss.item()}") print(f"KLDivLoss: {kl_loss.sum().item()/(logits.size(0)*logits.size(1))}") print(f"adjust-KL: ",kl_loss.sum().item()/(logits.size(0)*logits.size(1))+entropy.item() ) print(f"LSLoss:",lsl_loss.item()) print(f'NLL Loss: {nll_loss.item()}')
CrossEntropyLoss: 1.629240870475769 KLDivLoss: 1.6292406717936199 adjust-KL: 1.6292406717936199 LSLoss: 1.62924063205719 NLL Loss: 1.629240870475769
|
可以看到在没有标签平滑的时候,五个结果都可以认为是相同的。但当label_smoothing = 0.1
时:
1 2 3 4 5
| CrossEntropyLoss: 1.644450306892395 KLDivLoss: 1.1806944529215495 adjust-KL: 1.6444068769613902 LSLoss: 1.1806944608688354 NLL Loss: 1.64462411403656
|
交叉熵损失、adjust-KL、NLL
损失的结果相同,而KL散度损失和CosyVoice实现的标签平滑损失相同,而和其他结果不同了(额外说明CosyVoice的损失函数实现和这里的KL散度损失是相同的)
结论解释
KL散度又叫相对熵,相对熵=交叉熵-信息熵,则有交叉熵=相对熵+信息熵,对应adjust-KL:
1
| kl_loss.sum().item()/(logits.size(0)*logits.size(1))+entropy.item()
|
也是因为多计算了一个相对熵并求和才得到与交叉熵损失函数相同的结果,而在非标签平滑的(独热分布)情况下,信息熵为0,也就是交叉熵=相对熵完全等价的关系。
负对数似然NLL
Loss函数只能输入独热标签计算,因此在目标分布为独热分布时计算公式和交叉熵相同,在标签平滑的情况下,交叉熵损失结果有所有变化,但负对数似然损失函数无法应用标签平滑,损失结果不会变化。