olm.nn.embeddings.positional.alibi
Classes

ALiBiPositionalBias(*args, **kwargs): Attention with Linear Biases (ALiBi) as described in "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (arXiv 2108.12409).
class olm.nn.embeddings.positional.alibi.ALiBiPositionalBias(*args: Any, **kwargs: Any)
Bases: PositionalEmbeddingBase
Attention with Linear Biases (ALiBi) as described in “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation” (arXiv 2108.12409).
Instead of adding positional information to embeddings, ALiBi adds a bias to attention scores that is proportional to the distance between query and key positions. This allows the model to extrapolate to longer sequences than seen during training.
The bias is computed as bias[i, j] = -m * |i - j|, where m is a head-specific slope.
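The slope and bias computation can be sketched in plain Python (the real implementation presumably builds a torch tensor; the helper names `alibi_slopes` and `alibi_bias` below are illustrative, not part of this API). For a power-of-two number of heads, the ALiBi paper uses slopes forming the geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8):

```python
import math

def alibi_slopes(num_heads: int) -> list[float]:
    # For num_heads a power of two, the ALiBi paper uses the geometric
    # sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8) as head slopes.
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (h + 1) for h in range(num_heads)]

def alibi_bias(num_heads: int, seq_len_q: int, seq_len_k: int) -> list:
    # bias[h][i][j] = -m_h * |i - j|; added to attention scores pre-softmax.
    slopes = alibi_slopes(num_heads)
    return [
        [[-m * abs(i - j) for j in range(seq_len_k)] for i in range(seq_len_q)]
        for m in slopes
    ]
```

Note that the bias is zero on the diagonal (query attends to its own position) and grows linearly with distance, which is what lets the model generalize to unseen sequence lengths.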
forward(seq_len_q: int, seq_len_k: int, device: torch.device | None = None) → torch.Tensor
Get ALiBi bias for the given query and key sequence lengths.
- Parameters:
- seq_len_q – length of query sequence
- seq_len_k – length of key sequence (usually same as seq_len_q)
- device – device to place the bias tensor on
- Returns: Bias tensor of shape (1, num_heads, seq_len_q, seq_len_k). This should be added to attention scores before softmax.
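How the returned bias is consumed can be sketched for a single head in plain Python (the real module works on torch tensors; `attend_with_alibi` is an illustrative helper, not part of this API):

```python
import math

def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attend_with_alibi(scores: list[list[float]], slope: float) -> list[list[float]]:
    # scores: raw (seq_len_q, seq_len_k) QK^T scores for one head.
    # Add the ALiBi bias -slope * |i - j| to the scores, then softmax
    # over keys, matching the "before softmax" note above.
    seq_q, seq_k = len(scores), len(scores[0])
    biased = [
        [scores[i][j] - slope * abs(i - j) for j in range(seq_k)]
        for i in range(seq_q)
    ]
    return [softmax(row) for row in biased]
```

With a positive slope, nearby positions receive higher attention weight than distant ones even when the raw scores are identical.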
class olm.nn.embeddings.positional.alibi.PositionalEmbeddingBase(*args: Any, **kwargs: Any)
Bases: Module, ABC
Abstract base class for all positional embedding implementations.
Positional embeddings add information about token positions in a sequence to help the model understand order and relative positions. Different positional embedding strategies have different properties:
- Learned (Absolute): Simple, effective, but limited to max_seq_len
- Sinusoidal: Deterministic, can extrapolate to longer sequences
- RoPE: Applied to Q/K directly, enables relative position modeling
- ALiBi: Adds bias to attention scores, excellent extrapolation
All positional embedding implementations should inherit from this base class and implement the forward method.
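The subclassing pattern can be sketched as follows. This is a simplified stand-in: the real base class also inherits torch.nn.Module, and the `SinusoidalSketch` subclass here is hypothetical, showing the "add to embeddings" strategy from the list above:

```python
import math
from abc import ABC, abstractmethod

class PositionalEmbeddingBase(ABC):
    """Simplified stand-in; the real class also inherits torch.nn.Module."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Apply positional information; semantics vary by subclass."""

    def extra_repr(self) -> str:
        # Subclasses override this to expose useful config in repr().
        return ""

class SinusoidalSketch(PositionalEmbeddingBase):
    """Hypothetical subclass returning sinusoidal position values."""

    def __init__(self, dim: int):
        self.dim = dim

    def forward(self, seq_len: int) -> list[list[float]]:
        # pe[pos][2i]   = sin(pos / 10000^(2i/dim))
        # pe[pos][2i+1] = cos(pos / 10000^(2i/dim))
        pe = []
        for pos in range(seq_len):
            row = []
            for i in range(0, self.dim, 2):
                angle = pos / (10000 ** (i / self.dim))
                row.append(math.sin(angle))
                row.append(math.cos(angle))
            pe.append(row[: self.dim])
        return pe

    def extra_repr(self) -> str:
        return f"dim={self.dim}"
```

Because `forward` is abstract, the base class itself cannot be instantiated; only concrete strategies like the sketch above can.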
extra_repr() → str
String representation of the module for debugging.
Override this in subclasses to provide useful information.
abstractmethod forward(*args, **kwargs) → torch.Tensor
Apply positional information to input tensor(s).
The signature and behavior of this method vary by implementation:
- Some add to embeddings (Absolute, Sinusoidal)
- Some rotate representations (RoPE)
- Some return a bias to add to attention scores (ALiBi)
- Returns: Transformed tensor(s) with positional information applied