olm.nn.activations.swiglu

Classes

SwiGLU(*args, **kwargs) – SwiGLU activation function.

class olm.nn.activations.swiglu.ActivationBase(*args: Any, **kwargs: Any)

Bases: Module, ABC

Abstract base class for all activation functions.

Ensures a consistent interface for activation layers, handling device and dtype initialization. Subclasses must implement the forward method.

device

The device the module is on.

  • Type: torch.device, optional

dtype

The data type of the module parameters.

  • Type: torch.dtype

abstractmethod forward(x: torch.Tensor) → torch.Tensor

Apply activation to x.
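For illustration, a concrete subclass only needs to implement forward. The sketch below uses torch.nn.Module as a stand-in for ActivationBase (whose constructor details are not shown here), and the ReLU body is purely illustrative:

```python
import torch
import torch.nn as nn


class ReLUActivation(nn.Module):
    """Illustrative activation; nn.Module stands in for ActivationBase."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Implementing forward satisfies the abstract contract.
        return torch.relu(x)


act = ReLUActivation()
print(act(torch.tensor([-1.0, 0.0, 2.0])))  # tensor([0., 0., 2.])
```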

class olm.nn.activations.swiglu.SwiGLU(*args: Any, **kwargs: Any)

Bases: ActivationBase

SwiGLU activation function.

Implements the SwiGLU activation as described in “GLU Variants Improve Transformer”. It applies the SiLU activation to one half of the input (the gate) and multiplies it by the other half (the value).

Equation:

  SwiGLU(x, W, V) = Swish_1(xW) * (xV)

Here, the input x is assumed to be already projected and concatenated, so the layer only needs to chunk it along the last dimension:

  SwiGLU(x) = x_1 * SiLU(x_2), where x = [x_1, x_2]

  • Parameters:
  • device (torch.device , optional) – Target device.
  • dtype (torch.dtype , optional) – Target data type.
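A quick numeric check of the chunked form, written in plain Python with SiLU(v) = v * sigmoid(v); the function names here are illustrative, not part of olm:

```python
import math


def silu(v: float) -> float:
    # SiLU (Swish with beta = 1): v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def swiglu(xs: list[float]) -> list[float]:
    # Chunk the vector into value (x1) and gate (x2) halves,
    # then compute x1 * SiLU(x2) elementwise.
    half = len(xs) // 2
    x1, x2 = xs[:half], xs[half:]
    return [a * silu(b) for a, b in zip(x1, x2)]


print(swiglu([1.0, 2.0, 0.0, 0.0]))  # [0.0, 0.0], since SiLU(0) = 0
```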

forward(x: torch.Tensor) → torch.Tensor

Forward pass of SwiGLU.

  • Parameters: x (torch.Tensor) – Input tensor. Expected to have an even last dimension size.
  • Returns: Output tensor whose last dimension is half that of the input.
  • Return type: torch.Tensor
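The shape contract above can be sketched directly in PyTorch; swiglu_forward below is an illustrative stand-in for the layer's forward, not part of olm:

```python
import torch
import torch.nn.functional as F


def swiglu_forward(x: torch.Tensor) -> torch.Tensor:
    # Split the even-sized last dimension into value and gate halves,
    # then gate the value half with SiLU, halving the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return x1 * F.silu(x2)


x = torch.randn(4, 8)
print(swiglu_forward(x).shape)  # torch.Size([4, 4])
```

Note that passing an input with an odd last dimension would chunk unevenly, which is why the docstring requires an even size.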