Source: src/olm/train/losses/kllloss.py:1
Classes
KLLoss(kl_coeff=1.0, from_logits=True, **kwargs)
Bases: olm.train.losses.base.LossBase
Source: src/olm/train/losses/kllloss.py:7
Forward KL penalty between a policy distribution and a reference distribution.
Supports two input modes (B = batch size, T = sequence length, V = vocabulary size):
- From logits: logits : [B, T, V] (policy logits), ref : [B, T, V] (reference logits)
- From log-probs: logp : [B, T, V] (policy log-probs), ref_logp : [B, T, V] (reference log-probs)
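The two modes carry the same information. Log-probs are the log-softmax of the corresponding logits over the vocabulary axis, so the logits mode presumably reduces to the log-probs mode internally. The minimal sketch below illustrates that conversion for a single token position. It uses plain Python lists instead of tensors so it runs without a framework, and the helper name log_softmax is hypothetical, not taken from olm.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Convert one token's logits (length V) into log-probabilities.

    Subtracting the max before exponentiating keeps exp() from overflowing;
    the result is unchanged because softmax is shift-invariant.
    """
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


# Policy and reference logits for a single position with V = 3.
policy_logits = [2.0, 1.0, 0.1]
ref_logits = [1.5, 1.5, 0.0]

logp = log_softmax(policy_logits)
ref_logp = log_softmax(ref_logits)

# The log-probs define proper distributions: their exponentials sum to 1.
print(round(sum(math.exp(x) for x in logp), 6))  # 1.0
```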
Optional: loss_mask : [B, T] (bool or 0/1) marking which tokens contribute to the loss
Output: kl_coeff times the reduced per-token KL; a scalar loss unless reduction="none"
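Here is a minimal sketch of the masking and reduction step. It assumes that a mask value of 1 (or True) keeps a token, that reduction accepts "mean", "sum", and "none", and that "mean" averages over kept tokens only. The source does not pin down any of these details. reduce_kl is a hypothetical helper operating on [B, T] nested lists, not part of olm's API.

```python
from typing import List, Optional, Union


def reduce_kl(
    kl_per_token: List[List[float]],
    loss_mask: Optional[List[List[float]]] = None,
    kl_coeff: float = 1.0,
    reduction: str = "mean",
) -> Union[float, List[List[float]]]:
    """Apply an optional [B, T] mask, reduce, and scale by kl_coeff (sketch)."""
    if loss_mask is None:
        # No mask: every token contributes.
        loss_mask = [[1.0] * len(row) for row in kl_per_token]
    # Zero out masked positions; float() maps bools to 0.0 / 1.0.
    masked = [
        [kl * float(m) for kl, m in zip(kl_row, m_row)]
        for kl_row, m_row in zip(kl_per_token, loss_mask)
    ]
    if reduction == "none":
        # Keep the [B, T] shape, still scaled by kl_coeff.
        return [[kl_coeff * x for x in row] for row in masked]
    total = sum(sum(row) for row in masked)
    if reduction == "sum":
        return kl_coeff * total
    if reduction == "mean":
        # Average over kept tokens only; max() guards against an all-zero mask.
        count = sum(float(m) for row in loss_mask for m in row)
        return kl_coeff * total / max(count, 1.0)
    raise ValueError(f"unknown reduction: {reduction!r}")


kl = [[0.2, 0.4, 0.6], [0.1, 0.3, 0.0]]  # [B=2, T=3]
mask = [[1, 1, 0], [1, 0, 0]]           # keeps 3 of 6 tokens
print(round(reduce_kl(kl, mask, kl_coeff=0.5), 4))  # 0.1167
```

Averaging over kept tokens, rather than over all B x T positions, keeps the penalty's scale independent of how much padding a batch contains.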
Notes:
- This computes the full-distribution KL per token (sums over vocab V).
- This is the standard KL regularizer for keeping a fine-tuned policy close to the reference.
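The per-token quantity described in the notes can be sketched as follows. The sketch assumes the KL is taken as KL(policy || reference), meaning the expectation is under the policy distribution. That direction is an assumption, since the source only says "forward KL". full_kl_per_token is a hypothetical helper on [B, T, V] nested lists of log-probs, not olm's implementation.

```python
import math
from typing import List


def full_kl_per_token(
    logp: List[List[List[float]]],      # [B, T, V] policy log-probs
    ref_logp: List[List[List[float]]],  # [B, T, V] reference log-probs
) -> List[List[float]]:
    """Return [B, T] values of KL(policy || reference), each summed over V."""
    return [
        [
            # sum_v p(v) * (log p(v) - log q(v)), with p(v) = exp(log p(v))
            sum(math.exp(lp) * (lp - lq) for lp, lq in zip(tok_p, tok_q))
            for tok_p, tok_q in zip(seq_p, seq_q)
        ]
        for seq_p, seq_q in zip(logp, ref_logp)
    ]


# One sequence (B = 1) of two tokens (T = 2) over a two-word vocabulary (V = 2).
logp = [[[math.log(0.5), math.log(0.5)], [math.log(0.9), math.log(0.1)]]]
ref_logp = [[[math.log(0.5), math.log(0.5)], [math.log(0.5), math.log(0.5)]]]

kl = full_kl_per_token(logp, ref_logp)
print([round(x, 4) for x in kl[0]])  # [0.0, 0.3681]
```

Summing over all V entries gives the exact KL at each position but requires full [B, T, V] outputs from both models. Some RL pipelines instead evaluate only log p(a) - log q(a) at the sampled token a, which is a cheaper single-sample estimate.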