olm.train.trainer.trainer¶
Classes¶
| Trainer(model, optimizer, dataloader, ...) | Manages the training loop for Open Language Model (OLM) architectures. |
|---|---|
| TrainerCallback() | Base class for trainer callbacks. |
class olm.train.trainer.trainer.CrossEntropyLoss(*args: Any, **kwargs: Any)¶
Bases: LossBase
forward(logits: torch.Tensor, y: torch.Tensor) → torch.Tensor¶
Apply loss to logits and y.
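A minimal usage sketch; the (batch, seq, vocab) logits layout and any internal reshaping are assumptions, not confirmed by this page:
>>> import torch
>>> from olm.train.trainer.trainer import CrossEntropyLoss
>>> loss_fn = CrossEntropyLoss()
>>> logits = torch.randn(2, 8, 32000)     # (batch, seq, vocab) — assumed layout
>>> y = torch.randint(0, 32000, (2, 8))   # target token IDs
>>> loss = loss_fn(logits, y)             # scalar loss tensor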
class olm.train.trainer.trainer.DataLoader(*args: Any, **kwargs: Any)¶
Bases: DataLoader
Wrapper around PyTorch’s DataLoader with sensible defaults for LLM training.
This class extends torch.utils.data.DataLoader with:
- Better defaults for language model training
- Automatic worker configuration
- Pin memory optimization for GPU training
- Persistent workers for efficiency
- Parameters:
- dataset – Dataset to load from (can be map-style or iterable).
- batch_size – Number of samples per batch (default: 8).
- shuffle – Whether to shuffle data at every epoch (default: False for iterable datasets).
- num_workers – Number of worker processes for data loading (default: 0).
- pin_memory – If True, tensors are copied to CUDA pinned memory (default: True).
- drop_last – Drop the last incomplete batch if dataset size is not divisible by batch_size.
- persistent_workers – Keep workers alive between epochs for faster startup (default: True if num_workers > 0).
- prefetch_factor – Number of batches to prefetch per worker (default: 2).
- collate_fn – Function to merge samples into batches.
- **kwargs – Additional arguments passed to torch.utils.data.DataLoader.
Example¶
>>> from olm.data.datasets import DataLoader
>>> loader = DataLoader(
... dataset=my_dataset,
... batch_size=16,
... num_workers=4,
... pin_memory=True,
... )
>>> for batch in loader:
... # Training loop
... pass
class olm.train.trainer.trainer.LossBase(*args: Any, **kwargs: Any)¶
Bases: Module, ABC
Base class for all loss modules.
abstractmethod forward(logits: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor¶
Apply loss to logits and y.
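A sketch of a custom subclass, assuming the standard pattern of overriding forward; the class name and label_smoothing value are illustrative, not part of OLM:
>>> import torch
>>> import torch.nn.functional as F
>>> from olm.train.trainer.trainer import LossBase
>>> class SmoothedCrossEntropyLoss(LossBase):
...     """Hypothetical loss: token-level cross-entropy with label smoothing."""
...     def forward(self, logits: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
...         # Flatten (batch, seq, vocab) logits and (batch, seq) targets
...         return F.cross_entropy(
...             logits.view(-1, logits.size(-1)), y.view(-1), label_smoothing=0.1
...         )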
class olm.train.trainer.trainer.TokenizerBase¶
Bases: ABC
Abstract base class for all tokenizers in OLM.
Defines the interface for converting between text strings and integer token IDs. Subclasses must implement encode and decode methods.
abstractmethod decode(tokens: torch.Tensor) → str¶
Converts a sequence of token IDs back into a text string.
- Parameters: tokens (torch.Tensor) – A 1D tensor or list of token IDs.
- Returns: The decoded text string.
- Return type: str
abstractmethod encode(text: str) → torch.Tensor¶
Converts a text string into a sequence of token IDs.
- Parameters: text (str) – The input text to tokenize.
- Returns: A 1D tensor containing the token IDs.
- Return type: torch.Tensor
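A minimal concrete implementation sketch. The byte-level scheme here is hypothetical; OLM's real tokenizers may differ:
>>> import torch
>>> from olm.train.trainer.trainer import TokenizerBase
>>> class ByteTokenizer(TokenizerBase):
...     """Hypothetical tokenizer: one token ID per UTF-8 byte."""
...     def encode(self, text: str) -> torch.Tensor:
...         return torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
...     def decode(self, tokens: torch.Tensor) -> str:
...         return bytes(tokens.tolist()).decode("utf-8")
>>> ByteTokenizer().decode(ByteTokenizer().encode("hello"))
'hello'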
class olm.train.trainer.trainer.Trainer(model: torch.nn.Module, optimizer: torch.optim.Optimizer | Type[torch.optim.Optimizer], dataloader: olm.data.datasets.data_loader.DataLoader, device: str, context_length: int, grad_accum_steps: int = 1, use_amp: bool = True, loss: Type[olm.train.losses.base.LossBase] = CrossEntropyLoss, callbacks: List[TrainerCallback] | None = None, scheduler: Any | None = None, grad_clip_norm: float | None = None, warmup_steps: int | None = None, total_steps: int | None = None, min_lr: float = 0.0, learning_rate: float = 0.0003, weight_decay: float = 0.0, use_warmup_cosine: bool = True)¶
Bases: object
Manages the training loop for Open Language Model (OLM) architectures.
This trainer handles the core training logic, including:
- Automatic Mixed Precision (AMP) scaling
- Gradient accumulation
- Device management (moving data/models to GPU)
- Optimization steps
- Callbacks for validation, checkpointing, and custom logic
- Learning rate scheduling support
- Gradient clipping
model¶
The model to train.
- Type: torch.nn.Module
optimizer¶
The optimizer to use.
- Type: torch.optim.Optimizer
dataloader¶
The data provider.
- Type: DataLoader
device¶
The device to train on (e.g., ‘cuda’, ‘cpu’).
- Type: str
context_length¶
The maximum sequence length for training.
- Type: int
grad_accum_steps¶
Number of steps to accumulate gradients before updating.
- Type: int
use_amp¶
Whether to use Automatic Mixed Precision.
- Type: bool
scaler¶
Gradient scaler for AMP.
- Type: GradScaler
loss¶
The loss function instance.
- Type: LossBase
callbacks¶
List of callbacks to execute during training.
- Type: List[TrainerCallback]
scheduler¶
Learning rate scheduler to step after each optimization step.
- Type: Any | None
global_step¶
Current global step count.
- Type: int
current_epoch¶
Current epoch number.
- Type: int
add_callback(callback: TrainerCallback) → None¶
Add a callback to the trainer.
remove_callback(callback: TrainerCallback) → None¶
Remove a callback from the trainer.
train(epochs: int, log_interval: int = 10, max_steps: int | None = None, steps_per_epoch: int | None = None) → list[float]¶
Executes the training loop for a specified number of epochs.
- Parameters:
- epochs (int) – The number of complete passes through the dataset.
- log_interval (int) – How often to print the loss. Defaults to 10.
- max_steps (int, optional) – Maximum number of optimization steps to train for. Defaults to None (unlimited).
- steps_per_epoch (int, optional) – Maximum number of steps per epoch. Defaults to None (unlimited).
- Returns: A list of recorded loss values.
- Return type: list[float]
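A minimal end-to-end sketch, assuming my_model (a torch.nn.Module) and my_dataset already exist. Passing the optimizer as a class is allowed by the Optimizer | Type[Optimizer] annotation above; how the trainer instantiates it internally is an assumption:
>>> import torch
>>> from olm.train.trainer.trainer import Trainer
>>> from olm.data.datasets import DataLoader
>>> trainer = Trainer(
...     model=my_model,
...     optimizer=torch.optim.AdamW,   # optimizer class; see learning_rate/weight_decay
...     dataloader=DataLoader(dataset=my_dataset, batch_size=8),
...     device="cuda",
...     context_length=512,
...     grad_accum_steps=4,
...     learning_rate=3e-4,
... )
>>> losses = trainer.train(epochs=1, log_interval=50)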
class olm.train.trainer.trainer.TrainerCallback¶
Bases: object
Base class for trainer callbacks.
on_batch_begin(trainer: Trainer, batch_idx: int) → None¶
Called at the beginning of each batch.
on_batch_end(trainer: Trainer, batch_idx: int, loss: float) → None¶
Called at the end of each batch.
on_epoch_begin(trainer: Trainer, epoch: int) → None¶
Called at the beginning of each epoch.
on_epoch_end(trainer: Trainer, epoch: int) → None¶
Called at the end of each epoch.
on_step_begin(trainer: Trainer, step: int) → None¶
Called at the beginning of each optimization step (after gradient accumulation).
on_step_end(trainer: Trainer, step: int, loss: float) → None¶
Called at the end of each optimization step.
on_train_begin(trainer: Trainer) → None¶
Called at the beginning of training.
on_train_end(trainer: Trainer) → None¶
Called at the end of training.
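A sketch of a custom callback that overrides only the hook it needs; the class name and behavior are illustrative:
>>> from olm.train.trainer.trainer import Trainer, TrainerCallback
>>> class LossHistory(TrainerCallback):
...     """Hypothetical callback: records the loss after every optimization step."""
...     def __init__(self):
...         self.history = []
...     def on_step_end(self, trainer: Trainer, step: int, loss: float) -> None:
...         self.history.append((step, loss))
>>> trainer.add_callback(LossHistory())   # trainer from the example above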
class olm.train.trainer.trainer.WarmupCosineScheduler(*args: Any, **kwargs: Any)¶
Bases: SchedulerBase
Combined warmup and cosine annealing scheduler.
Linearly warms up the learning rate from 0 to base_lr over warmup_steps, then applies cosine annealing decay to min_lr over the remaining steps.
- Parameters:
- optimizer – Wrapped optimizer.
- warmup_steps – Number of warmup steps.
- total_steps – Total number of training steps.
- min_lr – Minimum learning rate after decay (default: 0).
- last_epoch – The index of last epoch (default: -1).
Example¶
>>> from olm.train.schedulers import WarmupCosineScheduler
>>> scheduler = WarmupCosineScheduler(
... optimizer,
... warmup_steps=1000,
... total_steps=10000,
... min_lr=1e-6
... )
>>> for step in range(total_steps):
... train(...)
... scheduler.step()
get_lr()¶
Compute learning rate with warmup and cosine decay.
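The rule described above can be written out as follows. This is a sketch of the math only, assuming step counting starts at 0; boundary handling in the real implementation may differ:
>>> import math
>>> def sketch_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
...     if step < warmup_steps:                  # linear warmup from 0 to base_lr
...         return base_lr * step / warmup_steps
...     # cosine decay from base_lr down to min_lr over the remaining steps
...     progress = (step - warmup_steps) / (total_steps - warmup_steps)
...     return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))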