Source: src/olm/train/trainer/trainer.py:1
Classes
Trainer(model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer | Type[torch.optim.optimizer.Optimizer], dataloader: olm.data.datasets.data_loader.DataLoader, device: str, context_length: int, grad_accum_steps: int = 1, use_amp: bool = True, loss: Type[olm.train.losses.base.LossBase] = olm.train.losses.cross_entropy.CrossEntropyLoss, callbacks: List[olm.train.trainer.trainer.TrainerCallback] | None = None, scheduler: Any | None = None, grad_clip_norm: float | None = None, warmup_steps: int | None = None, total_steps: int | None = None, min_lr: float = 0.0, learning_rate: float = 0.0003, weight_decay: float = 0.0, use_warmup_cosine: bool = True)
Source: src/olm/train/trainer/trainer.py:53
Manages the training loop for Open Language Model (OLM) architectures.
This trainer handles the core training logic including:
- Automatic Mixed Precision (AMP) scaling
- Gradient accumulation
- Device management (moving data/models to GPU)
- Optimization steps
- Callbacks for validation, checkpointing, and custom logic
- Learning rate scheduling support
- Gradient clipping
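The constructor exposes `learning_rate`, `warmup_steps`, `total_steps`, `min_lr`, and `use_warmup_cosine`, but this page does not state the exact schedule they produce. A common interpretation is a linear warmup to `learning_rate` followed by a cosine decay down to `min_lr`. The sketch below assumes that shape. `warmup_cosine_lr` is a hypothetical helper, so treat it as an illustration of the technique and not as OLM's implementation.

```python
import math


def warmup_cosine_lr(step: int, learning_rate: float = 3e-4, warmup_steps: int = 100,
                     total_steps: int = 1000, min_lr: float = 0.0) -> float:
    """Learning rate at a 0-indexed optimizer step: linear warmup, then cosine decay.

    Hypothetical helper; the Trainer's real schedule may differ in details.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp: learning_rate / warmup_steps at step 0, learning_rate at the end.
        return learning_rate * (step + 1) / warmup_steps
    if step >= total_steps:
        # Past the planned horizon the rate stays at its floor.
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (learning_rate - min_lr) * cosine
```

With `warmup_steps=2` and `total_steps=6`, the rate climbs over the first two steps, peaks at `learning_rate`, falls to half of it midway through the decay phase (step 4), and reaches `min_lr` at step 6.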
Training contract:
The dataloader must yield (input_ids, labels). Both tensors are
moved to device and truncated to context_length. The model must
return logits shaped [batch, seq_len, vocab_size] so the configured
loss can compare logits against labels shaped [batch, seq_len].
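The contract above fixes only the shapes. The pure-Python sketch below shows one way a batch could satisfy it. It assumes next-token labels (inputs shifted left by one), which is a common convention but is not stated on this page, and it uses nested lists in place of tensors. `make_batch` and `expected_logits_shape` are hypothetical helpers.

```python
from typing import List, Tuple


def make_batch(sequences: List[List[int]],
               context_length: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Build (input_ids, labels) from raw token sequences (hypothetical helper)."""
    input_ids: List[List[int]] = []
    labels: List[List[int]] = []
    for seq in sequences:
        # Assumed convention: labels are the inputs shifted left by one token.
        inputs, targets = seq[:-1], seq[1:]
        # Truncate both to context_length, mirroring the Trainer's truncation.
        input_ids.append(inputs[:context_length])
        labels.append(targets[:context_length])
    return input_ids, labels


def expected_logits_shape(labels: List[List[int]], vocab_size: int) -> Tuple[int, int, int]:
    """Logits shape [batch, seq_len, vocab_size] required for labels of shape [batch, seq_len]."""
    return (len(labels), len(labels[0]), vocab_size)
```

For example, `make_batch([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], context_length=3)` returns `([[1, 2, 3], [6, 7, 8]], [[2, 3, 4], [7, 8, 9]])`. Both tensors would then have shape [2, 3], so the model must emit logits of shape [2, 3, vocab_size].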
Attributes
model (Pipeline): The model to train.
optimizer (torch.optim.Optimizer): The optimizer to use.
dataloader (olm.data.datasets.data_loader.DataLoader): The data provider.
device (str): The device to train on (e.g., 'cuda', 'cpu').
context_length (int): The maximum sequence length for training.
grad_accum_steps (int): Number of batches whose gradients are accumulated before each optimizer update.
use_amp (bool): Whether to use Automatic Mixed Precision.
scaler (GradScaler): Gradient scaler for AMP.
loss (LossBase): The loss function instance.
callbacks (List[TrainerCallback]): Callbacks to execute during training.
scheduler (Optional): Learning rate scheduler, stepped after each optimization step.
global_step (int): Current global step count.
current_epoch (int): Current epoch number.
total_tokens_processed (int): Total number of tokens processed during training.
step_start_time (float): Timestamp at which the current step started.
training_start_time (float): Timestamp at which training began.
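`grad_accum_steps` controls how many batches contribute to one optimizer update. A common practice, assumed here and not confirmed for OLM, is to divide each micro-batch loss by `grad_accum_steps`. The summed gradients then equal the gradient of the mean loss over all micro-batches. The pure-Python sketch below checks that identity on a one-parameter model `y = w * x`. `grad_mse` and `accumulated_grad` are hypothetical helpers.

```python
from typing import List, Tuple


def grad_mse(w: float, batch: List[Tuple[float, float]]) -> float:
    """Gradient d/dw of mean((w * x - y) ** 2) over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, data: List[Tuple[float, float]], grad_accum_steps: int) -> float:
    """Accumulate micro-batch gradients, each scaled by 1 / grad_accum_steps."""
    # Assumes the data splits into equal-sized micro-batches; otherwise the
    # scaled sum is no longer exactly the full-batch mean gradient.
    micro = len(data) // grad_accum_steps
    total = 0.0
    for i in range(grad_accum_steps):
        chunk = data[i * micro:(i + 1) * micro]
        # Dividing the micro-batch loss by k divides its gradient by k.
        total += grad_mse(w, chunk) / grad_accum_steps
    return total
```

With four samples and `grad_accum_steps=2`, `accumulated_grad` matches `grad_mse` on the full batch. This is why accumulation lets a small per-step batch emulate a larger one.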
Methods
add_callback(self, callback: olm.train.trainer.trainer.TrainerCallback) -> None
Source: src/olm/train/trainer/trainer.py:213
Add a callback to the trainer.
remove_callback(self, callback: olm.train.trainer.trainer.TrainerCallback) -> None
Source: src/olm/train/trainer/trainer.py:217
Remove a callback from the trainer.
train(self, epochs: int, log_interval: int = 10, max_steps: int = None, steps_per_epoch: int = None) -> list[float]
Source: src/olm/train/trainer/trainer.py:267
Executes the training loop for a specified number of epochs.
Parameters
epochs (int): The number of complete passes through the dataset.
log_interval (int): How often to print the loss. Defaults to 10.
max_steps (int, optional): Maximum number of steps to train for. Defaults to None.
steps_per_epoch (int, optional): Maximum number of steps per epoch. Defaults to None (unlimited).
Returns
list[float]: A list of recorded loss values.
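This page does not spell out how `epochs`, `steps_per_epoch`, and `max_steps` interact. The sketch below encodes one plausible reading. It follows the callback documentation in using "step" to mean an optimization step, so each step consumes `grad_accum_steps` batches. Under this reading, `steps_per_epoch` caps steps within an epoch and `max_steps` caps the total across epochs. `planned_optimizer_steps` is a hypothetical helper, and the real loop may differ, for example in how it treats a trailing partial accumulation window.

```python
from typing import Optional


def planned_optimizer_steps(batches_per_epoch: int, epochs: int, grad_accum_steps: int = 1,
                            steps_per_epoch: Optional[int] = None,
                            max_steps: Optional[int] = None) -> int:
    """Number of optimizer steps under the assumptions stated above (hypothetical)."""
    # Assumption: leftover batches at the end of an epoch do not form a step.
    per_epoch = batches_per_epoch // grad_accum_steps
    if steps_per_epoch is not None:
        # Assumption: steps_per_epoch caps optimizer steps, not raw batches.
        per_epoch = min(per_epoch, steps_per_epoch)
    total = per_epoch * epochs
    if max_steps is not None:
        # Assumption: max_steps is a global cap across all epochs.
        total = min(total, max_steps)
    return total
```

Under these assumptions, 100 batches per epoch with `grad_accum_steps=4` yields 25 steps per epoch. Three epochs therefore plan 75 steps unless `steps_per_epoch` or `max_steps` cuts that short.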
TrainerCallback()
Source: src/olm/train/trainer/trainer.py:17
Base class for trainer callbacks.
Methods
on_batch_begin(self, trainer: 'Trainer', batch_idx: int) -> None
Source: src/olm/train/trainer/trainer.py:36
Called at the beginning of each batch.
on_batch_end(self, trainer: 'Trainer', batch_idx: int, loss: float) -> None
Source: src/olm/train/trainer/trainer.py:40
Called at the end of each batch.
on_epoch_begin(self, trainer: 'Trainer', epoch: int) -> None
Source: src/olm/train/trainer/trainer.py:28
Called at the beginning of each epoch.
on_epoch_end(self, trainer: 'Trainer', epoch: int) -> None
Source: src/olm/train/trainer/trainer.py:32
Called at the end of each epoch.
on_step_begin(self, trainer: 'Trainer', step: int) -> None
Source: src/olm/train/trainer/trainer.py:44
Called at the beginning of each optimization step (after gradient accumulation).
on_step_end(self, trainer: 'Trainer', step: int, loss: float) -> None
Source: src/olm/train/trainer/trainer.py:48
Called at the end of each optimization step.
on_train_begin(self, trainer: 'Trainer') -> None
Source: src/olm/train/trainer/trainer.py:20
Called at the beginning of training.
on_train_end(self, trainer: 'Trainer') -> None
Source: src/olm/train/trainer/trainer.py:24
Called at the end of training.
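The hook descriptions imply a nesting: training, then epochs, then batches, with a step firing after each gradient-accumulation window. The sketch below makes that order visible with a recording callback. It does not import OLM. `TrainerCallback` here is a local stand-in that mirrors the documented hook names, and `simulate_loop` is a hypothetical driver written from the docstrings, not the real `Trainer.train`. Step numbering in particular is an assumption. With the real package you would subclass `olm.train.trainer.trainer.TrainerCallback` and register the instance with `add_callback`.

```python
from typing import Any, List


class TrainerCallback:
    """Local stand-in mirroring the documented hook names (no-op defaults)."""

    def on_train_begin(self, trainer: Any) -> None: pass
    def on_train_end(self, trainer: Any) -> None: pass
    def on_epoch_begin(self, trainer: Any, epoch: int) -> None: pass
    def on_epoch_end(self, trainer: Any, epoch: int) -> None: pass
    def on_batch_begin(self, trainer: Any, batch_idx: int) -> None: pass
    def on_batch_end(self, trainer: Any, batch_idx: int, loss: float) -> None: pass
    def on_step_begin(self, trainer: Any, step: int) -> None: pass
    def on_step_end(self, trainer: Any, step: int, loss: float) -> None: pass


class EventRecorder(TrainerCallback):
    """Records each hook call as a string so the firing order can be inspected."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_train_begin(self, trainer: Any) -> None:
        self.events.append("train_begin")

    def on_train_end(self, trainer: Any) -> None:
        self.events.append("train_end")

    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, trainer: Any, epoch: int) -> None:
        self.events.append(f"epoch_end:{epoch}")

    def on_batch_begin(self, trainer: Any, batch_idx: int) -> None:
        self.events.append(f"batch_begin:{batch_idx}")

    def on_batch_end(self, trainer: Any, batch_idx: int, loss: float) -> None:
        self.events.append(f"batch_end:{batch_idx}")

    def on_step_begin(self, trainer: Any, step: int) -> None:
        self.events.append(f"step_begin:{step}")

    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
        self.events.append(f"step_end:{step}")


def simulate_loop(cb: TrainerCallback, epochs: int, batches_per_epoch: int,
                  grad_accum_steps: int) -> None:
    """Hypothetical driver firing hooks in the order the docstrings suggest."""
    trainer = None  # a real Trainer instance would be passed to the hooks
    step = 0
    cb.on_train_begin(trainer)
    for epoch in range(epochs):
        cb.on_epoch_begin(trainer, epoch)
        for batch_idx in range(batches_per_epoch):
            cb.on_batch_begin(trainer, batch_idx)
            loss = 1.0  # placeholder; a real loop computes this from the model
            cb.on_batch_end(trainer, batch_idx, loss)
            if (batch_idx + 1) % grad_accum_steps == 0:
                # The optimization step fires only after the accumulation window closes.
                cb.on_step_begin(trainer, step)
                cb.on_step_end(trainer, step, loss)
                step += 1
        cb.on_epoch_end(trainer, epoch)
    cb.on_train_end(trainer)
```

With one epoch of two batches and `grad_accum_steps=2`, the recorder sees both batch begin/end pairs before a single `step_begin:0` / `step_end:0` pair. This is why per-batch hooks fire more often than per-step hooks whenever accumulation is enabled.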