olm.train

Training infrastructure for OLM.

class olm.train.AdamW(*args: Any, **kwargs: Any)

Bases: AdamW

AdamW optimizer with decoupled weight decay regularization.

This is a wrapper around PyTorch’s built-in AdamW, which implements “Decoupled Weight Decay Regularization” (Loshchilov & Hutter, 2017). Unlike Adam with L2 regularization, weight decay is applied directly to the parameters rather than being added to the gradient.

This implementation is commonly used for training large language models and transformers, offering better generalization than standard Adam.

Note: This class inherits from PyTorch’s AdamW which ultimately inherits from torch.optim.Optimizer, maintaining compatibility with our OptimizerBase interface.

  • Parameters:
  • params – iterable of parameters to optimize or dicts defining parameter groups
  • lr – learning rate (default: 1e-3)
  • betas – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
  • eps – term added to the denominator to improve numerical stability (default: 1e-8)
  • weight_decay – weight decay coefficient (default: 0.01)
  • amsgrad – whether to use the AMSGrad variant (default: False)
  • maximize – maximize the params based on the objective, instead of minimizing (default: False)
  • fused – whether to use the fused implementation (default: None, auto-detect)

Example

>>> model = nn.Linear(10, 5)
>>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
>>> optimizer.zero_grad()
>>> loss = model(input).sum()
>>> loss.backward()
>>> optimizer.step()

class olm.train.CheckpointCallback(checkpoint_dir: str = 'checkpoints', save_every: int = 1000, keep_last_n: int = 5, save_best: bool = True)

Bases: TrainerCallback

Callback to save model checkpoints at specified intervals.

  • Parameters:
  • checkpoint_dir – Directory to save checkpoints.
  • save_every – Save checkpoint every N steps.
  • keep_last_n – Keep only the last N checkpoints.
  • save_best – Whether to save the best model based on validation loss.

on_step_end(trainer, step: int, loss: float) → None

Save checkpoint after each optimization step if needed.
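The keep_last_n rotation can be sketched in plain Python. This is illustrative only: the checkpoint_<step>.pt naming convention below is an assumption, not necessarily the callback’s actual on-disk format.

```python
from pathlib import Path


def rotate_checkpoints(checkpoint_dir: str, keep_last_n: int) -> list[str]:
    """Delete all but the newest keep_last_n step checkpoints.

    Assumes files are named checkpoint_<step>.pt (an illustrative
    convention, not necessarily what CheckpointCallback writes).
    """
    ckpts = sorted(
        Path(checkpoint_dir).glob("checkpoint_*.pt"),
        key=lambda p: int(p.stem.split("_")[1]),  # sort by step number
    )
    for old in ckpts[:-keep_last_n]:
        old.unlink()  # drop the oldest checkpoints
    return [p.name for p in ckpts[-keep_last_n:]]
```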

class olm.train.CosineAnnealingLR(*args: Any, **kwargs: Any)

Bases: SchedulerBase

Cosine annealing learning rate scheduler.

Decreases the learning rate following a cosine curve from the initial learning rate to eta_min over T_max steps.

  • Parameters:
  • optimizer – Wrapped optimizer.
  • T_max – Maximum number of iterations (steps).
  • eta_min – Minimum learning rate (default: 0).
  • last_epoch – The index of last epoch (default: -1).

Example

>>> from olm.train.schedulers import CosineAnnealingLR
>>> scheduler = CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)
>>> for epoch in range(epochs):
...     train(...)
...     scheduler.step()

get_lr()

Compute learning rate using cosine annealing.
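The schedule is the standard cosine annealing formula; a minimal closed-form sketch, assuming the step counter runs from 0 to T_max:

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, T_max: int,
                        eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at step 0, eta_min at T_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / T_max)) / 2
```

At the midpoint (step = T_max / 2) the rate is exactly halfway between base_lr and eta_min.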

class olm.train.EarlyStoppingCallback(patience: int = 5, min_delta: float = 0.0)

Bases: TrainerCallback

Callback to stop training early if validation loss doesn’t improve.

  • Parameters:
  • patience – Number of validation checks to wait for improvement.
  • min_delta – Minimum change in validation loss to qualify as improvement.

on_step_end(trainer, step: int, loss: float) → None

Check for early stopping after each step.
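The patience/min_delta bookkeeping can be sketched in plain Python (illustrative, not the callback’s actual implementation):

```python
class EarlyStopper:
    """Sketch of patience/min_delta logic: stop after `patience`
    consecutive validation checks without sufficient improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def update(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # real improvement: reset the counter
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```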

class olm.train.LRMonitorCallback(log_every: int = 100)

Bases: TrainerCallback

Callback to monitor and log learning rate.

  • Parameters: log_every – Log learning rate every N steps.

on_step_end(trainer, step: int, loss: float) → None

Log learning rate after each optimization step if needed.

class olm.train.LinearDecayLR(*args: Any, **kwargs: Any)

Bases: SchedulerBase

Simple linear decay scheduler that decays to zero.

This is a simplified version that always decays to 0 from the initial LR.

  • Parameters:
  • optimizer – Wrapped optimizer.
  • total_steps – Total number of steps to decay over.
  • last_epoch – The index of last epoch (default: -1).

Example

>>> from olm.train.schedulers import LinearDecayLR
>>> scheduler = LinearDecayLR(optimizer, total_steps=1000)
>>> for step in range(total_steps):
...     train(...)
...     scheduler.step()

get_lr()

Compute learning rate using linear decay.

class olm.train.LinearLR(*args: Any, **kwargs: Any)

Bases: SchedulerBase

Linear learning rate scheduler.

Linearly decreases (or increases) the learning rate from the initial learning rate to end_lr over total_steps.

  • Parameters:
  • optimizer – Wrapped optimizer.
  • total_steps – Total number of steps for the schedule.
  • end_lr – Target learning rate at the end (default: 0).
  • start_factor – Initial learning rate multiplier (default: 1.0).
  • last_epoch – The index of last epoch (default: -1).

Example

>>> from olm.train.schedulers import LinearLR
>>> # Decay from initial LR to 0
>>> scheduler = LinearLR(optimizer, total_steps=1000, end_lr=0)
>>> for step in range(total_steps):
...     train(...)
...     scheduler.step()

get_lr()

Compute learning rate using linear interpolation.
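The interpolation itself is simple enough to state in closed form; a sketch, assuming the step counter is clamped at total_steps:

```python
def linear_lr(step: int, base_lr: float, total_steps: int,
              end_lr: float = 0.0, start_factor: float = 1.0) -> float:
    """Linear interpolation from base_lr * start_factor to end_lr."""
    t = min(step, total_steps) / total_steps  # progress in [0, 1]
    start = base_lr * start_factor
    return start + (end_lr - start) * t
```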

class olm.train.Lion(*args: Any, **kwargs: Any)

Bases: OptimizerBase

Lion optimizer (EvoLved Sign Momentum).

Implements the Lion algorithm from “Symbolic Discovery of Optimization Algorithms” (Chen et al., 2023). Lion uses only the sign of the gradient for updates, making it more memory-efficient than Adam while often achieving better performance.

Key differences from Adam:

  • Uses sign of interpolated gradient for updates (memory efficient)
  • Single momentum buffer instead of two (m and v in Adam)
  • Typically requires smaller learning rates (1/3 to 1/10 of AdamW)
  • Larger weight decay (3-10x that of AdamW)

  • Parameters:
  • params – iterable of parameters to optimize or dicts defining parameter groups
  • lr – learning rate (default: 1e-4, typically 3-10x smaller than AdamW)
  • betas – coefficients used for computing running averages (default: (0.9, 0.99))
  • weight_decay – weight decay coefficient (default: 0.0)
  • use_triton – whether to use Triton kernel for faster computation (default: False)

Example

>>> model = nn.Linear(10, 5)
>>> optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=0.1)
>>> optimizer.zero_grad()
>>> loss = model(input).sum()
>>> loss.backward()
>>> optimizer.step()
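The update rule behind the differences listed above can be sketched for a single scalar parameter. Pure Python and illustrative only: the real optimizer operates on tensors, and copysign treats 0.0 as positive where torch.sign would return 0.

```python
import math


def lion_update(p: float, grad: float, m: float, lr: float = 1e-4,
                beta1: float = 0.9, beta2: float = 0.99,
                weight_decay: float = 0.0) -> tuple[float, float]:
    """One Lion step for a scalar parameter; returns (new_param, new_momentum)."""
    # Update direction is only the SIGN of the beta1-interpolated gradient.
    direction = math.copysign(1.0, beta1 * m + (1 - beta1) * grad)
    p = p - lr * (direction + weight_decay * p)  # decoupled weight decay
    m = beta2 * m + (1 - beta2) * grad           # single momentum buffer
    return p, m
```

Because only the sign is used, the step size is the same regardless of gradient magnitude, which is why Lion typically wants a smaller learning rate than AdamW.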

step(closure: Callable[[], float] | None = None) → float | None

Performs a single optimization step.

  • Parameters: closure – A closure that reevaluates the model and returns the loss.
  • Returns: Optional loss value if closure is provided.

zero_grad(set_to_none: bool = True)

Sets gradients of all optimized tensors to zero.

  • Parameters: set_to_none – instead of setting to zero, set the grads to None. This is more memory efficient and can slightly improve performance.

class olm.train.MetricsLoggerCallback(log_dir: str = 'logs', log_every: int = 10)

Bases: TrainerCallback

Callback to log metrics to a JSONL file.

  • Parameters:
  • log_dir – Directory to save logs.
  • log_every – Log metrics every N steps.

on_step_end(trainer, step: int, loss: float) → None

Log metrics after each optimization step if needed.
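JSONL means one JSON object per line, appended as training runs. A sketch of that convention (the field names here are illustrative, not necessarily what the callback writes):

```python
import json


def append_metrics(log_path: str, step: int, loss: float, **extra) -> None:
    """Append one JSON record per line (the JSONL convention)."""
    record = {"step": step, "loss": loss, **extra}
    with open(log_path, "a") as f:
        f.write(json.dumps(record) + "\n")
```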

class olm.train.OptimizerBase(*args: Any, **kwargs: Any)

Bases: Optimizer, ABC

Abstract base class for all optimizers in the OLM framework.

Provides a consistent interface for optimizer implementations, including standard methods for parameter updates, gradient zeroing, and state management. All custom optimizers should inherit from this class.

This base class extends PyTorch’s Optimizer class and adds additional functionality specific to the OLM framework.

Subclasses must implement the step() method to define the optimization logic.
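The subclass contract is essentially an overridden step() that walks param_groups. A structural sketch with plain Python floats standing in for tensors (SGDLike is a hypothetical name for illustration; a real subclass would inherit from OptimizerBase and update torch parameters):

```python
class SGDLike:
    """Minimal shape of an OptimizerBase subclass: hyperparameters and
    parameters live in param_groups, and step() applies one update."""

    def __init__(self, params: list[float], grads: list[float], lr: float = 0.01):
        self.param_groups = [{"params": params, "grads": grads, "lr": lr}]

    def step(self) -> None:
        for group in self.param_groups:
            lr = group["lr"]
            # Plain gradient descent: p <- p - lr * g
            group["params"][:] = [
                p - lr * g for p, g in zip(group["params"], group["grads"])
            ]
```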

extra_repr() → str

String representation of the optimizer for debugging.

Override this in subclasses to provide useful information.

load_state_dict(state_dict: Dict[str, Any])

Loads the optimizer state.

  • Parameters: state_dict – optimizer state. Should be an object returned from a call to state_dict().

state_dict() → Dict[str, Any]

Returns the state of the optimizer as a dict.

It contains two entries:

  • state: dict holding current optimization state. Its content differs between optimizer classes.
  • param_groups: list containing all parameter groups where each parameter group is a dict.

  • Returns: Dictionary containing optimizer state

abstractmethod step(closure: Callable[[], float] | None = None) → float | None

Performs a single optimization step.

  • Parameters: closure – A closure that reevaluates the model and returns the loss. Some optimization algorithms (e.g., L-BFGS) require multiple evaluations of the loss function.
  • Returns: Optional loss value if closure is provided.

zero_grad(set_to_none: bool = True)

Sets gradients of all optimized tensors to zero or None.

  • Parameters: set_to_none – Instead of setting to zero, set the grads to None. This is more memory efficient and can slightly improve performance. Default: True

class olm.train.SchedulerBase(*args: Any, **kwargs: Any)

Bases: _LRScheduler, ABC

Base class for all OLM learning rate schedulers.

This class extends PyTorch’s _LRScheduler and provides a consistent interface for implementing custom learning rate schedules. All OLM schedulers should inherit from this class to maintain uniformity.

Subclasses must implement:

  • get_lr(): Compute the learning rate for the current step
  • _get_closed_form_lr() (optional): Closed-form solution for efficiency

  • Parameters:
  • optimizer – Wrapped PyTorch optimizer.
  • last_epoch – The index of the last epoch (default: -1).
  • verbose – If True, prints a message to stdout for each update (default: False).

Example

>>> class MyScheduler(SchedulerBase):
...     def __init__(self, optimizer, param, last_epoch=-1):
...         self.param = param
...         super().__init__(optimizer, last_epoch)
...
...     def get_lr(self):
...         # Custom logic here
...         return [base_lr * self.param for base_lr in self.base_lrs]

get_last_lr() → List[float]

Return last computed learning rate by current scheduler.

  • Returns: List of last computed learning rates.

abstractmethod get_lr() → List[float]

Compute learning rate for each parameter group.

This method must be implemented by subclasses to define the learning rate schedule logic.

  • Returns: List of learning rates, one per parameter group.

load_state_dict(state_dict)

Load the scheduler state from a checkpoint.

  • Parameters: state_dict – Scheduler state returned by state_dict().

state_dict()

Returns the state of the scheduler as a dict.

Contains all non-callable attributes that are specific to the scheduler and required for checkpointing.

class olm.train.ThroughputCallback(log_every: int = 100, context_length: int = 1024, batch_size: int = 8)

Bases: TrainerCallback

Callback to monitor training throughput (tokens/sec, samples/sec).

  • Parameters:
  • log_every – Log throughput every N steps.
  • context_length – Length of each sequence.
  • batch_size – Total batch size (including gradient accumulation).

on_step_begin(trainer, step: int) → None

Record start time of the step.

on_step_end(trainer, step: int, loss: float) → None

Calculate and log throughput.
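The throughput arithmetic over a logging window is straightforward; a sketch of what tokens/sec and samples/sec mean here (reported field names may differ):

```python
def throughput(batch_size: int, context_length: int, steps: int,
               elapsed_seconds: float) -> tuple[float, float]:
    """Return (tokens/sec, samples/sec) over a window of `steps` steps."""
    samples = batch_size * steps             # sequences processed in the window
    tokens = samples * context_length        # tokens processed in the window
    return tokens / elapsed_seconds, samples / elapsed_seconds
```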

class olm.train.Trainer(model: torch.nn.Module, optimizer: torch.optim.Optimizer | Type[torch.optim.Optimizer], dataloader: DataLoader, device: str, context_length: int, grad_accum_steps: int = 1, use_amp: bool = True, loss: Type[LossBase] = …, callbacks: List[TrainerCallback] | None = None, scheduler: Any | None = None, grad_clip_norm: float | None = None, warmup_steps: int | None = None, total_steps: int | None = None, min_lr: float = 0.0, learning_rate: float = 0.0003, weight_decay: float = 0.0, use_warmup_cosine: bool = True)

Bases: object

Manages the training loop for Open Language Model (OLM) architectures.

This trainer handles the core training logic, including:

  • Automatic Mixed Precision (AMP) scaling
  • Gradient accumulation
  • Device management (moving data/models to GPU)
  • Optimization steps
  • Callbacks for validation, checkpointing, and custom logic
  • Learning rate scheduling support
  • Gradient clipping
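Gradient accumulation is the least obvious of these: micro-batch losses are scaled by 1/grad_accum_steps so that one optimizer step is equivalent to training on a batch grad_accum_steps times larger. A framework-free sketch of that bookkeeping (illustrative; the real loop also handles AMP scaling and clipping):

```python
def accumulate_and_step(micro_losses: list[float], grad_accum_steps: int) -> list[float]:
    """Return the per-update mean losses: each optimizer step sees the
    average of grad_accum_steps micro-batch losses."""
    step_losses = []
    for i in range(0, len(micro_losses), grad_accum_steps):
        chunk = micro_losses[i:i + grad_accum_steps]
        # Averaging is the same as summing losses each scaled by 1/grad_accum_steps.
        step_losses.append(sum(chunk) / len(chunk))
    return step_losses
```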

model

The model to train.

  • Type: Pipeline

optimizer

The optimizer to use.

  • Type: torch.optim.Optimizer

dataloader

The data provider.

device

The device to train on (e.g., ‘cuda’, ‘cpu’).

  • Type: str

context_length

The maximum sequence length for training.

  • Type: int

grad_accum_steps

Number of steps to accumulate gradients before updating.

  • Type: int

use_amp

Whether to use Automatic Mixed Precision.

  • Type: bool

scaler

Gradient scaler for AMP.

  • Type: GradScaler

loss

The loss function instance.

callbacks

List of callbacks to execute during training.

scheduler

Learning rate scheduler to step after each optimization step.

  • Type: Optional

global_step

Current global step count.

  • Type: int

current_epoch

Current epoch number.

  • Type: int

add_callback(callback: TrainerCallback) → None

Add a callback to the trainer.

remove_callback(callback: TrainerCallback) → None

Remove a callback from the trainer.

train(epochs: int, log_interval: int = 10, max_steps: int | None = None, steps_per_epoch: int | None = None) → list[float]

Executes the training loop for a specified number of epochs.

  • Parameters:
  • epochs (int) – The number of complete passes through the dataset.
  • log_interval (int) – How often to print the loss. Defaults to 10.
  • max_steps (int , optional) – Maximum number of steps to train for.
  • steps_per_epoch (int , optional) – Maximum number of steps per epoch. Defaults to None (unlimited).
  • Returns: A list of recorded loss values.
  • Return type: list[float]

class olm.train.TrainerCallback

Bases: object

Base class for trainer callbacks.

on_batch_begin(trainer: Trainer, batch_idx: int) → None

Called at the beginning of each batch.

on_batch_end(trainer: Trainer, batch_idx: int, loss: float) → None

Called at the end of each batch.

on_epoch_begin(trainer: Trainer, epoch: int) → None

Called at the beginning of each epoch.

on_epoch_end(trainer: Trainer, epoch: int) → None

Called at the end of each epoch.

on_step_begin(trainer: Trainer, step: int) → None

Called at the beginning of each optimization step (after gradient accumulation).

on_step_end(trainer: Trainer, step: int, loss: float) → None

Called at the end of each optimization step.

on_train_begin(trainer: Trainer) → None

Called at the beginning of training.

on_train_end(trainer: Trainer) → None

Called at the end of training.
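A custom callback only overrides the hooks it needs; the rest inherit no-op defaults. A self-contained sketch against a stand-in base with the same hook names (LossHistory is a hypothetical example; a real callback would subclass olm.train.TrainerCallback):

```python
class _CallbackBase:
    """Stand-in for TrainerCallback: every hook is a no-op."""
    def on_step_end(self, trainer, step: int, loss: float) -> None: ...
    def on_train_end(self, trainer) -> None: ...


class LossHistory(_CallbackBase):
    """Record (step, loss) at every optimization step."""
    def __init__(self):
        self.history: list[tuple[int, float]] = []

    def on_step_end(self, trainer, step: int, loss: float) -> None:
        self.history.append((step, loss))
```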

class olm.train.ValidationCallback(val_dataloader, eval_every: int = 500, device: str = 'cuda', use_amp: bool = True)

Bases: TrainerCallback

Callback to perform validation at specified intervals.

  • Parameters:
  • val_dataloader – Validation dataloader.
  • eval_every – Validate every N steps.
  • device – Device to run validation on.
  • use_amp – Whether to use automatic mixed precision.

on_step_end(trainer, step: int, loss: float) → None

Run validation after each optimization step if needed.

class olm.train.WarmupCosineScheduler(*args: Any, **kwargs: Any)

Bases: SchedulerBase

Combined warmup and cosine annealing scheduler.

Linearly warms up the learning rate from 0 to base_lr over warmup_steps, then applies cosine annealing decay to min_lr over the remaining steps.

  • Parameters:
  • optimizer – Wrapped optimizer.
  • warmup_steps – Number of warmup steps.
  • total_steps – Total number of training steps.
  • min_lr – Minimum learning rate after decay (default: 0).
  • last_epoch – The index of last epoch (default: -1).

Example

>>> from olm.train.schedulers import WarmupCosineScheduler
>>> scheduler = WarmupCosineScheduler(
...     optimizer,
...     warmup_steps=1000,
...     total_steps=10000,
...     min_lr=1e-6
... )
>>> for step in range(total_steps):
...     train(...)
...     scheduler.step()

get_lr()

Compute learning rate with warmup and cosine decay.
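The two phases described above can be stated in closed form; a sketch, assuming warmup counts from step 0 and decay spans the remaining total_steps - warmup_steps:

```python
import math


def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup from 0 to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # warmup phase
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2
```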

class olm.train.WarmupLR(*args: Any, **kwargs: Any)

Bases: SchedulerBase

Learning rate warmup scheduler.

Linearly increases the learning rate from 0 to the base learning rate over warmup_steps.

  • Parameters:
  • optimizer – Wrapped optimizer.
  • warmup_steps – Number of warmup steps.
  • start_lr – Initial learning rate (default: 0).
  • last_epoch – The index of last epoch (default: -1).

Example

>>> from olm.train.schedulers import WarmupLR
>>> scheduler = WarmupLR(optimizer, warmup_steps=1000)
>>> for step in range(warmup_steps):
...     train(...)
...     scheduler.step()

get_lr()

Compute learning rate during warmup.

Modules

callbacks Callbacks for the Trainer class.
optim
regularization
schedulers Learning rate schedulers for OLM training.
trainer