OLM API Reference

`olm.nn.feedforward.base`

Source: src/olm/nn/feedforward/base.py:1

Classes

FeedForwardBase(embed_dim: int, **kwargs)

Bases: Module, ABC

Source: src/olm/nn/feedforward/base.py:5

Abstract base class for feedforward networks in a transformer block.

Defines the interface for FFNs/MLPs. Subclasses must implement the forward method.

Attributes

  • embed_dim (int): Size of the last (embedding) dimension of both the input and output tensors.
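As an illustration of the interface described above, the following is a minimal sketch of how a concrete subclass might look. It does not import from olm: the FeedForwardBase defined here is a stand-in written only from the documented signature and bases, and it assumes forward is declared abstract and that embed_dim is stored as an attribute. TinyMLP, hidden_dim, up, and down are hypothetical names introduced for the example.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn
from torch.nn import functional as F


class FeedForwardBase(nn.Module, ABC):
    """Stand-in mirroring the documented signature (assumption, not the olm source)."""

    def __init__(self, embed_dim: int, **kwargs):
        super().__init__()
        # Assumption: the real class stores embed_dim; extra kwargs are ignored here.
        self.embed_dim = embed_dim

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class TinyMLP(FeedForwardBase):
    """Hypothetical subclass: expand to hidden_dim, apply GELU, project back."""

    def __init__(self, embed_dim: int, hidden_dim: int, **kwargs):
        super().__init__(embed_dim, **kwargs)
        self.up = nn.Linear(embed_dim, hidden_dim)
        self.down = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.gelu(self.up(x)))


# With forward left abstract, the stand-in base cannot be instantiated directly.
try:
    FeedForwardBase(embed_dim=8)
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

mlp = TinyMLP(embed_dim=8, hidden_dim=32)
```

Under these assumptions, a subclass only needs to call the base constructor with embed_dim and provide forward; any additional constructor arguments (such as hidden_dim here) are the subclass's own.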

Methods

forward(self, x: torch.Tensor) -> torch.Tensor

Source: src/olm/nn/feedforward/base.py:25

Forward pass of the feedforward network.

Parameters

  • x (torch.Tensor): Input tensor of shape (batch, seq_len, embed_dim).

Returns

  • torch.Tensor: Output tensor of shape (batch, seq_len, embed_dim).
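The shape contract above (input and output both of shape (batch, seq_len, embed_dim)) can be checked with a small helper. This is a sketch under stated assumptions: FeedForwardBase below is a stand-in written from the documented signature rather than the olm class, and LinearFFN and check_shape_contract are hypothetical names used only for illustration.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class FeedForwardBase(nn.Module, ABC):
    """Stand-in for the documented base class (assumption, not the olm source)."""

    def __init__(self, embed_dim: int, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class LinearFFN(FeedForwardBase):
    """Hypothetical subclass: one linear layer mapping embed_dim to embed_dim."""

    def __init__(self, embed_dim: int, **kwargs):
        super().__init__(embed_dim, **kwargs)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last axis, so batch and seq_len pass through unchanged.
        return self.proj(x)


def check_shape_contract(ffn: FeedForwardBase, batch: int = 2, seq_len: int = 5):
    """Run a random (batch, seq_len, embed_dim) input through ffn and return both shapes."""
    x = torch.randn(batch, seq_len, ffn.embed_dim)
    with torch.no_grad():
        y = ffn(x)
    return x.shape, y.shape


in_shape, out_shape = check_shape_contract(LinearFFN(embed_dim=16))
```

A check like this could serve as a unit test for any new subclass, since a transformer block that adds the FFN output back to its input (a residual connection) relies on the two shapes matching.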