olm.nn.attention.gqa

Classes

GroupedQueryAttention(*args, **kwargs) Grouped Query Attention (GQA) with Rotary Positional Embeddings.

class olm.nn.attention.gqa.GroupedQueryAttention(*args: Any, **kwargs: Any)

Bases: Module

Grouped Query Attention (GQA) with Rotary Positional Embeddings.

GQA is an attention variant in which the number of Key/Value heads is smaller than the number of Query heads, so several Query heads share each Key/Value head. This reduces memory bandwidth usage during inference (a smaller KV cache) while maintaining quality close to Multi-Head Attention (MHA).

If num_kv_heads == num_heads, this is equivalent to MHA. If num_kv_heads == 1, this is equivalent to Multi-Query Attention (MQA).
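The MHA/GQA/MQA relationship comes down to how Key/Value heads are shared across Query heads. A minimal numpy sketch (the helper name `repeat_kv` is illustrative, not part of olm):

```python
import numpy as np

def repeat_kv(kv, num_heads):
    """Expand KV heads so each one serves num_heads // num_kv_heads query heads.

    kv: [batch, num_kv_heads, seq_len, head_dim]
    returns: [batch, num_heads, seq_len, head_dim]
    """
    _, num_kv_heads, _, _ = kv.shape
    group = num_heads // num_kv_heads  # num_kv_heads must divide num_heads
    return np.repeat(kv, group, axis=1)

batch, seq_len, head_dim, num_heads = 2, 4, 8, 8
for num_kv_heads in (8, 2, 1):  # MHA, GQA, MQA respectively
    k = np.random.randn(batch, num_kv_heads, seq_len, head_dim)
    k_full = repeat_kv(k, num_heads)
    assert k_full.shape == (batch, num_heads, seq_len, head_dim)
```

With `num_kv_heads == num_heads` the repeat is a no-op (MHA); with `num_kv_heads == 1` a single KV head is broadcast to every query head (MQA).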

  • Parameters:
  • embed_dim (int) – Total dimension of the model.
  • num_heads (int) – Number of Query heads.
  • num_kv_heads (int) – Number of Key/Value heads. Must divide num_heads.
  • max_seq_len (int) – Maximum sequence length for RoPE.
  • dropout (float , optional) – Dropout probability. Defaults to 0.0.
  • rope_theta (float , optional) – Base frequency for RoPE. Defaults to 10000.0.
  • use_bias (bool , optional) – Whether to use bias in linear projections. Defaults to False.
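Because only `num_kv_heads` Key/Value heads are cached per layer, the KV-cache shrinks by a factor of `num_heads / num_kv_heads` relative to MHA. A back-of-the-envelope check (model sizes here are hypothetical, not olm defaults):

```python
def kv_cache_bytes(batch, max_seq_len, num_kv_heads, head_dim,
                   num_layers, bytes_per_elem=2):
    # Factor of 2 for keys and values; fp16 (2 bytes) assumed by default.
    return (2 * batch * max_seq_len * num_kv_heads * head_dim
            * num_layers * bytes_per_elem)

# Hypothetical 32-layer model with 32 query heads and head_dim=128:
mha_cache = kv_cache_bytes(1, 4096, 32, 128, 32)  # num_kv_heads == num_heads
gqa_cache = kv_cache_bytes(1, 4096, 8, 128, 32)   # num_kv_heads == 8
print(mha_cache // gqa_cache)  # 4: GQA cache is 4x smaller
```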

forward(x: torch.Tensor, mask: torch.Tensor | None = None) → torch.Tensor

Forward pass of Grouped Query Attention.

  • Parameters:
  • x (torch.Tensor) – Input tensor of shape [batch, seq_len, embed_dim].
  • mask (torch.Tensor , optional) – Attention mask of shape [batch, 1, seq_len, seq_len] or [batch, seq_len, seq_len]. Defaults to None.
  • Returns: Output tensor of shape [batch, seq_len, embed_dim].
  • Return type: torch.Tensor
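A mask with the documented `[batch, 1, seq_len, seq_len]` shape can be built as below. This numpy sketch assumes a boolean convention where True marks positions that may be attended to; olm's internal convention may differ.

```python
import numpy as np

def causal_mask(batch, seq_len):
    """Boolean causal mask: position i may attend to positions j <= i.

    Shape [batch, 1, seq_len, seq_len]; the singleton dim broadcasts
    over attention heads.
    """
    m = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    return np.broadcast_to(m, (batch, 1, seq_len, seq_len))

m = causal_mask(2, 4)
assert m.shape == (2, 1, 4, 4)
assert not m[0, 0, 0, 1]  # cannot attend to a future position
assert m[0, 0, 3, 0]      # can attend to any past position
```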

class olm.nn.attention.gqa.Linear(*args: Any, **kwargs: Any)

Bases: Linear

forward(x)

class olm.nn.attention.gqa.RMSNorm(*args: Any, **kwargs: Any)

Bases: NormBase

RMSNorm (Root Mean Square Layer Normalization) layer.

Implements RMSNorm as described in “Root Mean Square Layer Normalization” (https://arxiv.org/abs/1910.07467), a simplified variant of LayerNorm that normalizes by the root mean square alone, keeping LayerNorm’s re-scaling invariance without mean re-centering.

  • Parameters:
  • d_model (int) – The dimension of the model to normalize.
  • eps (float , optional) – Small constant for numerical stability. Defaults to 1e-5.
  • device (torch.device , optional) – Target device.
  • dtype (torch.dtype , optional) – Target data type.

weight

Learnable scale parameter.

  • Type: nn.Parameter

forward(x: torch.Tensor) → torch.Tensor

Forward pass of RMSNorm.

  • Parameters: x (torch.Tensor) – Input tensor of shape (batch_size, sequence_length, d_model).
  • Returns: Normalized output tensor of the same shape.
  • Return type: torch.Tensor
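The computation reduces to dividing by the root mean square over `d_model` and scaling by the learnable weight. A numpy sketch of the formula (not the olm implementation):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    """x: (batch, seq_len, d_model); normalized over the last axis only."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

x = np.random.randn(2, 3, 16)
y = rms_norm(x, np.ones(16))
assert y.shape == x.shape
# After normalization the RMS along d_model is ~1 (up to eps):
assert np.allclose(np.sqrt(np.mean(y * y, axis=-1)), 1.0, atol=1e-3)
```

Note that, unlike LayerNorm, no mean is subtracted and no bias is added; only the scale `weight` is learned.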

class olm.nn.attention.gqa.RotaryPositionalEmbedding(*args: Any, **kwargs: Any)

Bases: PositionalEmbeddingBase

Rotary Positional Embedding (RoPE) as described in “RoFormer: Enhanced Transformer with Rotary Position Embedding” (arXiv 2104.09864).

This module precomputes sin/cos rotation frequencies for a given head dimension and applies them to query/key representations by rotating interleaved real/imaginary parts (equivalently, pairs of dimensions).

forward(x: torch.Tensor, seq_positions: torch.LongTensor | None = None) → torch.Tensor

Apply rotary positional embedding to input tensor x.

  • Parameters:
  • x – shape (batch_size, seq_len, num_heads, head_dim)
  • seq_positions – optional tensor of shape (batch_size, seq_len) with position indices. If None, assumes positions are 0..seq_len-1 for each batch.
  • Returns: Tensor of same shape as x, with RoPE applied.
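Applying RoPE amounts to rotating each pair of dimensions by a position-dependent angle. An illustrative numpy sketch using the common rotate-half layout (an assumption; olm's exact pairing may differ), with positions assumed to be 0..seq_len-1:

```python
import numpy as np

def rope(x, theta=10000.0):
    """x: (batch, seq_len, num_heads, head_dim), head_dim even."""
    _, seq_len, _, d = x.shape
    inv_freq = 1.0 / theta ** (np.arange(0, d, 2) / d)    # (d/2,)
    angles = np.arange(seq_len)[:, None] * inv_freq[None]  # (seq_len, d/2)
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    # Rotate each (x1, x2) pair by its position/frequency angle.
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)

x = np.random.randn(1, 5, 2, 8)
y = rope(x)
assert y.shape == x.shape
# Rotations preserve each pair's norm, hence the per-head vector norm:
assert np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1))
```

Because only relative angles matter in the resulting q·k dot products, the attention scores depend on relative position, which is the key property of RoPE.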