
olm.nn.embeddings.token_embed

Classes

  • Embedding(*args, **kwargs) – Token Embedding layer.

class olm.nn.embeddings.token_embed.Embedding(*args: Any, **kwargs: Any)

Bases: Module

Token Embedding layer.

Wraps the standard PyTorch nn.Embedding with a clean interface, mapping integer token indices to dense vectors.

  • Parameters:
  • vocab_size (int) – Size of the vocabulary.
  • embedding_dim (int) – Dimensionality of the token embeddings.

embedding

The underlying PyTorch embedding layer.

  • Type: nn.Embedding

forward(x: torch.Tensor) → torch.Tensor

Forward pass of the Embedding layer.

  • Parameters: x (torch.Tensor) – Input tensor of shape (batch_size, seq_len) containing token IDs.
  • Returns: Output tensor of shape (batch_size, seq_len, embedding_dim).
  • Return type: torch.Tensor
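The shape contract above can be illustrated with a minimal pure-Python sketch (no PyTorch required). The nested list `table` here stands in for the layer's learned weight matrix, and `embed` mirrors what `forward` does: a per-index table lookup that turns a `(batch_size, seq_len)` grid of token IDs into `(batch_size, seq_len, embedding_dim)` vectors. The names `vocab_size`, `embedding_dim`, and `token_ids` follow the documentation; everything else is illustrative.

```python
import random

vocab_size, embedding_dim = 10, 4

# One dense vector per vocabulary entry (randomly initialized,
# analogous to nn.Embedding's weight matrix of shape (vocab_size, embedding_dim)).
table = [
    [random.gauss(0.0, 1.0) for _ in range(embedding_dim)]
    for _ in range(vocab_size)
]

def embed(token_ids):
    """Map (batch_size, seq_len) token IDs to
    (batch_size, seq_len, embedding_dim) vectors by table lookup."""
    return [[table[tok] for tok in seq] for seq in token_ids]

token_ids = [[1, 2, 3], [4, 5, 6]]   # shape (batch_size=2, seq_len=3)
out = embed(token_ids)

assert len(out) == 2                  # batch_size preserved
assert len(out[0]) == 3               # seq_len preserved
assert len(out[0][0]) == embedding_dim
```

In the real layer, the lookup is a single `self.embedding(x)` call on a `torch.Tensor`, and the table entries are trainable parameters updated by backpropagation.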