olm.nn.feedforward.geglu_ffn

Classes

GeGLUFFN(*args, **kwargs) Feed-Forward Network using GeGLU activation.

class olm.nn.feedforward.geglu_ffn.FeedForwardBase(*args: Any, **kwargs: Any)

Bases: Module, ABC

Abstract base class for feedforward networks in a transformer block.

Defines the interface for FFNs/MLPs. Subclasses must implement the forward method.

embed_dim

The input and output dimension.

  • Type: int

abstractmethod forward(x: torch.Tensor) → torch.Tensor

Forward pass of the feedforward network.

  • Parameters: x (torch.Tensor) – Input tensor of shape (batch, seq_len, embed_dim).
  • Returns: Output tensor of shape (batch, seq_len, embed_dim).
  • Return type: torch.Tensor
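
As an illustration of the interface, the sketch below shows what a conforming subclass might look like. The constructor signature of FeedForwardBase is not documented on this page, so the example defines a stand-in base class instead of importing the real one; the names FeedForwardSketchBase and TinyFFN are hypothetical.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class FeedForwardSketchBase(nn.Module, ABC):
    """Hypothetical stand-in for FeedForwardBase; not the library's class."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim  # input and output dimension

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq_len, embed_dim) to (batch, seq_len, embed_dim)."""


class TinyFFN(FeedForwardSketchBase):
    """Illustrative subclass: one linear layer that preserves embed_dim."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__(embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


x = torch.randn(2, 5, 16)  # (batch, seq_len, embed_dim)
y = TinyFFN(embed_dim=16)(x)  # same shape as x
```

Because forward is abstract, instantiating the stand-in base class directly raises a TypeError; only subclasses that implement forward can be constructed.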

class olm.nn.feedforward.geglu_ffn.GeGLU(*args: Any, **kwargs: Any)

Bases: ActivationBase

GeGLU activation function.

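Implements the GeGLU variant from “GLU Variants Improve Transformer”:

GeGLU(x, W, V) = GELU(xW) * (xV)

where * is the element-wise product. Here the input is split along its last dimension into a gate and a value:

GeGLU(x) = GELU(gate) * value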

  • Parameters:
  • device (torch.device, optional) – Target device.
  • dtype (torch.dtype, optional) – Target data type.

forward(x: torch.Tensor) → torch.Tensor

Forward pass of GeGLU.

  • Parameters: x (torch.Tensor) – Input tensor whose last dimension holds the gate and value halves.
  • Returns: Output tensor whose last dimension is half that of the input.
  • Return type: torch.Tensor
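
The gating described above can be sketched as a small function. This is a minimal illustration, not the library's implementation: geglu_sketch is a hypothetical name, and the choice of the first half as the gate and the second half as the value is an assumption, since the page does not state the order.

```python
import torch
import torch.nn.functional as F


def geglu_sketch(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension in two and gate one half with GELU.

    Assumption: the first half is the gate, the second half is the value.
    """
    gate, value = x.chunk(2, dim=-1)
    return F.gelu(gate) * value


x = torch.randn(2, 5, 8)
out = geglu_sketch(x)  # last dimension shrinks from 8 to 4
```

The last dimension of the input has to be even for the two halves to have matching shapes.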

class olm.nn.feedforward.geglu_ffn.GeGLUFFN(*args: Any, **kwargs: Any)

Bases: FeedForwardBase

Feed-Forward Network using GeGLU activation.

Implements: x = DownProj(GeGLU(UpProj(x))). UpProj expands to 2 * hidden_dim to support splitting for the gate.

  • Parameters:
  • embed_dim (int) – Input and output dimension.
  • hidden_dim (int, optional) – Hidden dimension. If None, defaults to ff_multiplier * embed_dim (4 * embed_dim with the default multiplier).
  • dropout (float, optional) – Dropout probability. Defaults to 0.0.
  • bias (bool, optional) – Whether to use bias in the linear layers. Defaults to True.
  • ff_multiplier (float, optional) – Expansion factor used when hidden_dim is None. Defaults to 4.0.

forward(x)

Forward pass of the feedforward network.

  • Parameters: x (torch.Tensor) – Input tensor of shape (batch, seq_len, embed_dim).
  • Returns: Output tensor of shape (batch, seq_len, embed_dim).
  • Return type: torch.Tensor
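
To make the data flow x = DownProj(GeGLU(UpProj(x))) concrete, here is a hedged re-implementation built from the documented parameters. It is a sketch under assumptions, not the library's source: the class name GeGLUFFNSketch and the attribute names are hypothetical, the half taken as the gate is assumed, and so is the placement of dropout after the down projection.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class GeGLUFFNSketch(nn.Module):
    """Illustrative re-implementation; not the library's source."""

    def __init__(self, embed_dim: int, hidden_dim: Optional[int] = None,
                 dropout: float = 0.0, bias: bool = True,
                 ff_multiplier: float = 4.0) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(ff_multiplier * embed_dim)
        # UpProj produces gate and value together, hence 2 * hidden_dim.
        self.up_proj = nn.Linear(embed_dim, 2 * hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)  # placement is an assumption

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which half acts as the gate is an assumption.
        gate, value = self.up_proj(x).chunk(2, dim=-1)
        hidden = F.gelu(gate) * value  # (batch, seq_len, hidden_dim)
        return self.dropout(self.down_proj(hidden))


ffn = GeGLUFFNSketch(embed_dim=16)
x = torch.randn(2, 5, 16)  # (batch, seq_len, embed_dim)
y = ffn(x)  # (batch, seq_len, embed_dim)
```

With embed_dim=16 and the default multiplier, the up projection in this sketch outputs 128 features (2 * 64) and the down projection maps 64 features back to 16.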

class olm.nn.feedforward.geglu_ffn.Linear(*args: Any, **kwargs: Any)

Bases: Linear

forward(x)