olm.train.callbacks

Callbacks for the Trainer class.
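All of the classes below extend `TrainerCallback` and implement hooks that the Trainer invokes around each optimization step. A minimal sketch of that protocol, assuming (this page does not show the base class) that `TrainerCallback` exposes no-op `on_step_begin`/`on_step_end` hooks; `LossHistoryCallback` is a hypothetical example, not part of `olm`:

```python
class TrainerCallback:
    """Sketch of the hook interface; hook names are taken from this page."""

    def on_step_begin(self, trainer, step: int) -> None:
        pass  # called before each optimization step

    def on_step_end(self, trainer, step: int, loss: float) -> None:
        pass  # called after each optimization step


class LossHistoryCallback(TrainerCallback):
    """Hypothetical callback: record the loss every `log_every` steps."""

    def __init__(self, log_every: int = 10):
        self.log_every = log_every
        self.history = []

    def on_step_end(self, trainer, step: int, loss: float) -> None:
        if step % self.log_every == 0:
            self.history.append((step, loss))
```

In real use the Trainer calls these hooks itself; subclasses only override the hooks they need.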

class olm.train.callbacks.CheckpointCallback(checkpoint_dir: str = 'checkpoints', save_every: int = 1000, keep_last_n: int = 5, save_best: bool = True)

Bases: TrainerCallback

Callback to save model checkpoints at specified intervals.

  • Parameters:
      • checkpoint_dir – Directory to save checkpoints.
      • save_every – Save checkpoint every N steps.
      • keep_last_n – Keep only the last N checkpoints.
      • save_best – Whether to save the best model based on validation loss.

on_step_end(trainer, step: int, loss: float) → None

Save a checkpoint after the optimization step if the save_every interval has been reached.
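The keep_last_n rotation can be sketched as below. The `step_<N>.pt` file naming and flat directory layout are assumptions for illustration, not taken from this page:

```python
import os

def prune_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None:
    """Delete all but the newest `keep_last_n` step checkpoints.

    Assumes checkpoints are named like 'step_1000.pt' (hypothetical naming).
    """
    ckpts = sorted(
        (f for f in os.listdir(checkpoint_dir)
         if f.startswith("step_") and f.endswith(".pt")),
        key=lambda f: int(f[len("step_"):-len(".pt")]),  # sort by step number
    )
    for stale in ckpts[:-keep_last_n]:
        os.remove(os.path.join(checkpoint_dir, stale))
```

Sorting numerically (rather than lexically) matters here: lexical order would place `step_900.pt` after `step_10000.pt`.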

class olm.train.callbacks.EarlyStoppingCallback(patience: int = 5, min_delta: float = 0.0)

Bases: TrainerCallback

Callback to stop training early if validation loss doesn’t improve.

  • Parameters:
      • patience – Number of validation checks to wait for improvement.
      • min_delta – Minimum change in validation loss to qualify as improvement.

on_step_end(trainer, step: int, loss: float) → None

Check for early stopping after each step.
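The patience/min_delta bookkeeping likely looks something like the following sketch; the exact state the callback keeps is not documented on this page:

```python
class EarlyStopper:
    """Stop after `patience` validation checks without sufficient improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.checks_without_improvement = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            # improved by more than min_delta: record it and reset the counter
            self.best_loss = val_loss
            self.checks_without_improvement = 0
        else:
            self.checks_without_improvement += 1
        return self.checks_without_improvement >= self.patience
```

Note that an improvement smaller than min_delta still counts against patience, which is the usual semantics for these two parameters.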

class olm.train.callbacks.LRMonitorCallback(log_every: int = 100)

Bases: TrainerCallback

Callback to monitor and log learning rate.

  • Parameters: log_every – Log learning rate every N steps.

on_step_end(trainer, step: int, loss: float) → None

Log the learning rate after the optimization step if the log_every interval has been reached.
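Reading the current learning rate typically means inspecting the optimizer's parameter groups. This sketch assumes a PyTorch-style `param_groups` list of dicts; the page does not say how the callback actually reads it, and `maybe_log_lr` is a hypothetical helper:

```python
def maybe_log_lr(step: int, log_every: int, param_groups, log=print):
    """Log the first parameter group's learning rate when the interval is hit.

    Returns the logged learning rate, or None when the step is skipped.
    """
    if step % log_every == 0:
        lr = param_groups[0]["lr"]
        log(f"step {step}: lr={lr:.2e}")
        return lr
    return None
```

With a warmup-then-decay schedule this produces a cheap record of the schedule actually applied, which is useful when debugging divergence.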

class olm.train.callbacks.MetricsLoggerCallback(log_dir: str = 'logs', log_every: int = 10)

Bases: TrainerCallback

Callback to log metrics to a JSONL file.

  • Parameters:
      • log_dir – Directory to save logs.
      • log_every – Log metrics every N steps.

on_step_end(trainer, step: int, loss: float) → None

Log metrics after the optimization step if the log_every interval has been reached.
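JSONL means one JSON object per line, which keeps the log append-only and easy to stream or tail. A sketch of the write path; the field names here are illustrative, not taken from this page:

```python
import json

def log_metrics_line(fh, step: int, loss: float, **extra) -> None:
    """Append one JSON object per call -- the JSONL format."""
    record = {"step": step, "loss": loss, **extra}
    fh.write(json.dumps(record) + "\n")
```

Because each line is independently parseable, a partially written final line (e.g. after a crash) corrupts at most one record.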

class olm.train.callbacks.ThroughputCallback(log_every: int = 100, context_length: int = 1024, batch_size: int = 8)

Bases: TrainerCallback

Callback to monitor training throughput (tokens/sec, samples/sec).

  • Parameters:
      • log_every – Log throughput every N steps.
      • context_length – Length of each sequence.
      • batch_size – Total batch size (including gradient accumulation).

on_step_begin(trainer, step: int) → None

Record start time of the step.

on_step_end(trainer, step: int, loss: float) → None

Calculate and log throughput.
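Given the parameters above, the arithmetic per optimization step is: tokens processed = batch_size × context_length, divided by the wall-clock time between on_step_begin and on_step_end. A minimal sketch:

```python
def throughput(batch_size: int, context_length: int, elapsed_s: float):
    """Return (tokens/sec, samples/sec) for one optimization step.

    `batch_size` is the total batch size, including gradient accumulation,
    as the parameter description above specifies.
    """
    tokens = batch_size * context_length
    return tokens / elapsed_s, batch_size / elapsed_s
```

For example, with the defaults (batch_size=8, context_length=1024), a 2-second step works out to 4096 tokens/sec and 4 samples/sec.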

class olm.train.callbacks.ValidationCallback(val_dataloader, eval_every: int = 500, device: str = 'cuda', use_amp: bool = True)

Bases: TrainerCallback

Callback to perform validation at specified intervals.

  • Parameters:
      • val_dataloader – Validation dataloader.
      • eval_every – Validate every N steps.
      • device – Device to run validation on.
      • use_amp – Whether to use automatic mixed precision.

on_step_end(trainer, step: int, loss: float) → None

Run validation after the optimization step if the eval_every interval has been reached.
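The validation pass itself reduces to a mean loss over the validation dataloader, run without gradient tracking (and under autocast when use_amp is set). A framework-agnostic sketch, where `compute_loss` is a stand-in for the model forward pass:

```python
def run_validation(compute_loss, val_dataloader) -> float:
    """Average `compute_loss(batch)` over every validation batch.

    In a real PyTorch loop this would sit inside `model.eval()` and
    `torch.no_grad()`, wrapped in `torch.autocast` when use_amp is True.
    """
    total, batches = 0.0, 0
    for batch in val_dataloader:
        total += compute_loss(batch)
        batches += 1
    return total / batches
```

The resulting mean loss is what CheckpointCallback's save_best and EarlyStoppingCallback's patience logic compare against.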

Modules

  • checkpoint_cb – Checkpoint callback for saving model checkpoints during training.
  • early_stopping_cb – Early stopping callback to prevent overfitting.
  • lr_monitor_cb – Learning rate monitoring callback.
  • metrics_logger_cb – Metrics logging callback for tracking training metrics.
  • throughput_cb – Throughput monitoring callback.
  • validation_cb – Validation callback for running validation during training.