OLM API Reference

`olm.train.schedulers.base`

Source: src/olm/train/schedulers/base.py:1

Base learning rate scheduler for OLM.

Classes

SchedulerBase(optimizer, last_epoch: int = -1, verbose: bool = False)

Bases: _LRScheduler, ABC

Source: src/olm/train/schedulers/base.py:8

Base class for all OLM learning rate schedulers.

This class extends PyTorch's _LRScheduler and provides a consistent interface for implementing custom learning rate schedules. All OLM schedulers should inherit from this class to maintain uniformity.

Subclasses must implement get_lr() and may optionally implement _get_closed_form_lr():

  • get_lr(): Compute the learning rate for the current step.
  • _get_closed_form_lr() (optional): Closed-form computation of the learning rate, for efficiency.
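The difference between a step-by-step get_lr() and an optional closed-form variant can be shown without PyTorch. The following is a minimal pure-Python sketch that assumes an exponential decay with a hypothetical factor `gamma`. The chainable form derives each rate from the previous one, replaying every step. The closed form computes `base_lr * gamma**step` directly from the step index, so it can jump to any step in constant time. Neither function is part of the OLM API.

```python
from typing import List


def chainable_lrs(base_lrs: List[float], gamma: float, steps: int) -> List[float]:
    """Apply one multiplicative decay per step, starting from base_lrs."""
    lrs = list(base_lrs)
    for _ in range(steps):
        lrs = [lr * gamma for lr in lrs]
    return lrs


def closed_form_lrs(base_lrs: List[float], gamma: float, step: int) -> List[float]:
    """Compute base_lr * gamma**step directly, independent of earlier steps."""
    return [base_lr * gamma ** step for base_lr in base_lrs]


base_lrs = [0.1, 0.01]  # one entry per parameter group
step_lrs = chainable_lrs(base_lrs, gamma=0.9, steps=10)
direct_lrs = closed_form_lrs(base_lrs, gamma=0.9, step=10)
# Both hold base_lr * 0.9**10 for each group, up to float rounding.
```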

Parameters

  • optimizer: Wrapped PyTorch optimizer.
  • last_epoch: The index of the last epoch (default: -1).
  • verbose: If True, prints a message to stdout for each update (default: False).

Example

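```python
from olm.train.schedulers.base import SchedulerBase


class MyScheduler(SchedulerBase):
    def __init__(self, optimizer, param, last_epoch=-1):
        # Set custom attributes first: PyTorch's _LRScheduler constructor
        # takes an initial step, which calls get_lr().
        self.param = param
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Scale each parameter group's base learning rate by self.param.
        return [base_lr * self.param for base_lr in self.base_lrs]
```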
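This simplified, PyTorch-free model shows how the base class drives a subclass such as MyScheduler. It approximates PyTorch's _LRScheduler and is not its actual code. On a fresh start (last_epoch=-1), the constructor records each group's starting rate as its base rate and takes an initial step to epoch 0. Each step() then does three things:

  • increments last_epoch
  • calls get_lr()
  • writes the result into every parameter group and caches it for get_last_lr()

ToyOptimizer, ToySchedulerBase, and HalvingScheduler are hypothetical names used only for illustration.

```python
from typing import Dict, List


class ToyOptimizer:
    """Hypothetical stand-in for a torch optimizer; only param_groups is modeled."""

    def __init__(self, lrs: List[float]) -> None:
        self.param_groups: List[Dict[str, float]] = [{"lr": lr} for lr in lrs]


class ToySchedulerBase:
    """Simplified model of the _LRScheduler step cycle, not PyTorch's actual code."""

    def __init__(self, optimizer: ToyOptimizer, last_epoch: int = -1) -> None:
        self.optimizer = optimizer
        if last_epoch == -1:
            # Fresh start: record each group's starting rate as its base rate.
            for group in optimizer.param_groups:
                group.setdefault("initial_lr", group["lr"])
        self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
        # Initial step: on a fresh start this moves last_epoch from -1 to 0.
        self.step()

    def get_lr(self) -> List[float]:
        raise NotImplementedError

    def step(self) -> None:
        self.last_epoch += 1
        lrs = self.get_lr()
        # Write the new rates into the optimizer and cache them for get_last_lr().
        for group, lr in zip(self.optimizer.param_groups, lrs):
            group["lr"] = lr
        self._last_lr = list(lrs)

    def get_last_lr(self) -> List[float]:
        return self._last_lr


class HalvingScheduler(ToySchedulerBase):
    """Hypothetical subclass: halve every base rate once per `period` steps."""

    def __init__(self, optimizer: ToyOptimizer, period: int, last_epoch: int = -1) -> None:
        self.period = period  # must exist before super().__init__() calls get_lr()
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        factor = 0.5 ** (self.last_epoch // self.period)
        return [base_lr * factor for base_lr in self.base_lrs]


opt = ToyOptimizer([0.1, 0.01])
sched = HalvingScheduler(opt, period=2)
history = [sched.get_last_lr()]
for _ in range(4):
    # In real training, optimizer.step() would run before sched.step().
    sched.step()
    history.append(sched.get_last_lr())
```

Because the initial step happens inside the constructor, any attribute that get_lr() reads must be set before super().__init__() is called. The original example follows this rule as well.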

Methods

get_last_lr(self) -> List[float]

Source: src/olm/train/schedulers/base.py:64

Return the learning rates most recently computed by this scheduler.

Returns

List of the most recently computed learning rates, one per parameter group.

get_lr(self) -> List[float]

Source: src/olm/train/schedulers/base.py:39

Compute learning rate for each parameter group.

This method must be implemented by subclasses to define the learning rate schedule logic.

Returns

List of learning rates, one per parameter group.
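A typical get_lr() reads the step index and the base rates and returns one value per parameter group. The following minimal sketch writes that computation as a free function. It assumes a hypothetical linear warmup over `warmup_steps` steps followed by a constant rate. The function name and the warmup rule are illustrative and are not part of OLM.

```python
from typing import List


def warmup_get_lr(base_lrs: List[float], last_epoch: int, warmup_steps: int) -> List[float]:
    """Linear warmup from base_lr / warmup_steps up to base_lr, then constant.

    Follows the get_lr() contract: one learning rate per parameter group.
    """
    if warmup_steps <= 0:
        return list(base_lrs)
    # last_epoch is 0 after the constructor's initial step, so step 0 gets 1/warmup_steps.
    scale = min(1.0, (last_epoch + 1) / warmup_steps)
    return [base_lr * scale for base_lr in base_lrs]
```

Inside a SchedulerBase subclass, get_lr() could return `warmup_get_lr(self.base_lrs, self.last_epoch, self.warmup_steps)`, where `warmup_steps` is a hypothetical attribute set in the subclass constructor.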

load_state_dict(self, state_dict)

Source: src/olm/train/schedulers/base.py:86

Load the scheduler state from a checkpoint.

Parameters

  • state_dict: Scheduler state returned by state_dict().

state_dict(self)

Source: src/olm/train/schedulers/base.py:73

Return the state of the scheduler as a dict.

The dict contains every non-callable attribute that is specific to the scheduler and needed for checkpointing.
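One reading of this description is that state_dict() filters the scheduler's attributes. It keeps plain values and drops the optimizer and any callables. Lambdas, for example, cannot be pickled by the standard pickle module. The following is a minimal pure-Python sketch of that filtering and of the matching load_state_dict() round trip. It assumes the state lives in the instance's `__dict__`. ToyScheduler and `lr_lambda` are hypothetical names, and this is not OLM's actual implementation.

```python
from typing import Any, Callable, Dict


class ToyScheduler:
    """Hypothetical scheduler holding plain state, a callable, and an optimizer reference."""

    def __init__(self, optimizer: Any, lr_lambda: Callable[[int], float]) -> None:
        self.optimizer = optimizer
        self.lr_lambda = lr_lambda
        self.base_lrs = [0.1]
        self.last_epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        # Keep every attribute except the optimizer and callables.
        return {
            key: value
            for key, value in self.__dict__.items()
            if key != "optimizer" and not callable(value)
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore the saved attributes onto this instance.
        self.__dict__.update(state_dict)


sched = ToyScheduler(optimizer=object(), lr_lambda=lambda step: 0.95 ** step)
sched.last_epoch = 7
checkpoint = sched.state_dict()  # {'base_lrs': [0.1], 'last_epoch': 7}

# A fresh scheduler (normally built around the restored optimizer) picks up the saved state.
restored = ToyScheduler(optimizer=object(), lr_lambda=lambda step: 0.95 ** step)
restored.load_state_dict(checkpoint)
```

In a real checkpoint, the returned dict would typically be saved alongside the model and optimizer state, for example with torch.save. It would then be passed to load_state_dict() on a newly constructed scheduler.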