OLM API Reference

`olm.nn.activations.swiglu`

Source: src/olm/nn/activations/swiglu.py:1

Classes

SwiGLU(*, device: torch.device | None = None, dtype: torch.dtype | None = None) -> None

Bases: olm.nn.activations.base.ActivationBase

Source: src/olm/nn/activations/swiglu.py:7

SwiGLU activation function.

Implements the SwiGLU activation as described in "GLU Variants Improve Transformer". It applies the SiLU activation to one half of the input (the gate) and multiplies it by the other half (the value).

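Equation (from the paper): $\mathrm{SwiGLU}(x, W, V) = \mathrm{Swish}_1(xW) \otimes (xV)$, where $\mathrm{Swish}_1$ (Swish with $\beta = 1$) is the same function as SiLU and $\otimes$ denotes elementwise multiplication.

This module assumes the input has already been projected, so that $x = [x_1, x_2]$ is the concatenation of the value and gate projections. It splits $x$ in half along the last dimension and computes:

$\mathrm{SwiGLU}(x) = x_1 \otimes \mathrm{SiLU}(x_2)$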
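To show why splitting a pre-projected input reproduces the paper's two-projection form, here is a minimal sketch in plain PyTorch (it does not use the olm API). The matrices `W` and `V` and the sizes are arbitrary illustrative choices, and placing `V` before `W` in the fused weight is an assumption made so that the first half is the value and the second half is the gate, as in x = [x_1, x_2].

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d_model, d_hidden = 4, 6
x = torch.randn(3, d_model)

# Separate projections as in the paper: W feeds the gate, V feeds the value.
W = torch.randn(d_model, d_hidden)
V = torch.randn(d_model, d_hidden)

# Paper form: Swish_1(xW) * (xV), where Swish with beta = 1 is SiLU.
paper_out = F.silu(x @ W) * (x @ V)

# Fused form (assumed ordering): project once with [V | W], then split
# the last dimension in half. x_1 is the value and x_2 is the gate,
# matching SwiGLU(x) = x_1 * SiLU(x_2).
fused = x @ torch.cat([V, W], dim=-1)
x_1, x_2 = fused.chunk(2, dim=-1)
fused_out = x_1 * F.silu(x_2)

print(torch.allclose(paper_out, fused_out, atol=1e-6))  # True
```

Fusing the two weights into one matrix is a common design choice: a single larger matrix multiply replaces two smaller ones, and the activation step then only has to split and gate.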

Parameters

  • device (torch.device, optional): Target device.
  • dtype (torch.dtype, optional): Target data type.

Methods

forward(self, x: torch.Tensor) -> torch.Tensor

Source: src/olm/nn/activations/swiglu.py:26

Forward pass of SwiGLU.

Parameters

  • x (torch.Tensor): Input tensor. Its last dimension is expected to have an even size, so that it can be split into two equal halves.

Returns

  • torch.Tensor: Output tensor whose last dimension is half the size of the input's last dimension; all leading dimensions are unchanged.
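The documented forward behavior can be approximated by the sketch below. `SwiGLUSketch` is a hypothetical class written only from the description on this page, not from olm's source. In particular, the explicit `ValueError` for an odd last dimension is an assumption of the sketch; the real module may fail differently or not check at all.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLUSketch(nn.Module):
    """Hypothetical stand-in for olm's SwiGLU, based only on its documented behavior."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed input check: the last dimension must split into two equal halves.
        if x.shape[-1] % 2 != 0:
            raise ValueError(f"last dimension must be even, got {x.shape[-1]}")
        # x = [x_1, x_2]: x_1 is the value, x_2 is the gate.
        x_1, x_2 = x.chunk(2, dim=-1)
        return x_1 * F.silu(x_2)


act = SwiGLUSketch()
x = torch.randn(2, 5, 8)  # (batch, seq, 2 * hidden)
y = act(x)
print(y.shape)  # torch.Size([2, 5, 4])
```

Only the last dimension is halved; batch and sequence dimensions pass through untouched, which is why a preceding projection typically outputs twice the desired hidden width.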