OLM API Reference

`olm.train.losses.zloss`

Source: src/olm/train/losses/zloss.py:1

Classes

ZLoss(z_coeff=0.0001, **kwargs)

Bases: olm.train.losses.base.LossBase

Source: src/olm/train/losses/zloss.py:4

Z-loss (logZ^2 penalty), commonly used as an auxiliary regularizer that discourages the softmax log-normalizer logZ from drifting far from zero.

For each token:

    logZ = logsumexp(logits, dim=-1)
    z_loss = logZ ** 2
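The per-token formula above can be worked through by hand with a small, dependency-free sketch. This is not the OLM implementation: the helper names `logsumexp` and `z_loss_per_token` are hypothetical, the logits are plain Python lists rather than tensors, and the input values are made up for illustration.

```python
import math

def logsumexp(row):
    """Numerically stable log(sum(exp(x))) over one token's logits."""
    m = max(row)  # subtract the max so exp() cannot overflow
    return m + math.log(sum(math.exp(x - m) for x in row))

def z_loss_per_token(logits):
    """Return logZ ** 2 for each token; `logits` holds one row per token."""
    return [logsumexp(row) ** 2 for row in logits]

# Two tokens over a 3-way vocabulary (illustrative values only).
logits = [[0.0, 0.0, 0.0], [2.0, 1.0, -1.0]]
per_token = z_loss_per_token(logits)
print(per_token)
```

For the first token all logits are zero, so logZ = log(3) and the penalty is log(3) ** 2. The second token has larger logits, so its logZ and its penalty are larger.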

Notes:

  • This loss does NOT include cross-entropy (CE). It is usually added to CE, weighted by z_coeff: total = ce_loss + z_coeff * z_loss
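The combination described in the note can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: the function `total_loss` is hypothetical, the mean over unmasked tokens is an assumed reduction (the reference above does not state how `forward` reduces or applies `mask`), and the default `z_coeff` mirrors the 0.0001 shown in the constructor signature.

```python
import math

def logsumexp(row):
    """Numerically stable log(sum(exp(x))) over one token's logits."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def total_loss(logits, targets, z_coeff=1e-4, mask=None):
    """Return (total, ce_loss, z_loss) averaged over unmasked tokens.

    The masked-mean reduction is an assumption made for this sketch.
    """
    if mask is None:
        mask = [1] * len(logits)
    ce_sum = 0.0
    z_sum = 0.0
    count = 0
    for row, y, keep in zip(logits, targets, mask):
        if not keep:
            continue  # skip masked-out tokens (e.g. padding)
        log_z = logsumexp(row)
        ce_sum += log_z - row[y]  # cross-entropy: -log softmax(row)[y]
        z_sum += log_z ** 2       # z-loss term for this token
        count += 1
    count = max(count, 1)  # avoid division by zero if everything is masked
    ce_loss = ce_sum / count
    z_loss = z_sum / count
    return ce_loss + z_coeff * z_loss, ce_loss, z_loss

# Two tokens, the second one masked out (illustrative values only).
logits = [[0.0, 0.0, 0.0], [2.0, 1.0, -1.0]]
total, ce_loss, z_loss = total_loss(logits, targets=[0, 0], mask=[1, 0])
print(total, ce_loss, z_loss)
```

With a small coefficient such as 0.0001, the z-loss term contributes only a minor correction to the total, which is why it acts as an auxiliary regularizer rather than the main training signal.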

Methods

forward(self, logits, y, mask=None)

Source: src/olm/train/losses/zloss.py:24