
olm.nn.structure.combinators.parallel

Classes

Parallel(*args, **kwargs) – Apply multiple blocks to the same input and merge their outputs.

class olm.nn.structure.combinators.parallel.BaseCombinator(*args: Any, **kwargs: Any)

Bases: Module, ABC

Abstract base class for combinator modules.

Subclasses implement forward to define how inputs are combined.

abstractmethod forward(x: torch.Tensor) → torch.Tensor

Compute the combinator output from an input tensor.

  • Parameters: x – Input tensor.
  • Returns: Output tensor produced by the combinator.
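The following is a minimal sketch of how a concrete subclass might implement the abstract `forward`. It does not import olm. Instead it defines a stand-in `BaseCombinator`, assuming only what the reference states: the class derives from `nn.Module` and `ABC` and declares `forward` abstract. `SumCombinator` is a hypothetical subclass that applies each of its blocks to the input and sums the results.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseCombinator(nn.Module, ABC):
    """Stand-in mirroring the documented base: an nn.Module with an abstract forward."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the combinator output from an input tensor."""


class SumCombinator(BaseCombinator):
    """Hypothetical combinator: applies every block to x and sums the outputs."""

    def __init__(self, *blocks: nn.Module) -> None:
        super().__init__()
        # ModuleList registers the blocks so their parameters are tracked.
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = [block(x) for block in self.blocks]
        # Stack on a new leading axis, then reduce it by summation.
        return torch.stack(outputs, dim=0).sum(dim=0)


combo = SumCombinator(nn.Identity(), nn.Identity())
result = combo(torch.ones(2, 3))  # every entry equals 2.0
```

Because of `ABC`, calling `BaseCombinator()` directly raises `TypeError`. Only subclasses that define `forward` can be instantiated.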

class olm.nn.structure.combinators.parallel.Parallel(*args: Any, **kwargs: Any)

Bases: BaseCombinator

Apply multiple blocks to the same input and merge their outputs.

The merge function receives the list of block outputs together with a dimension argument and returns a single tensor.

  • Parameters:
      • blocks – Modules applied in parallel to the same input.
      • merge – Function that takes the list of block outputs and a dimension and returns the merged tensor.
      • dim – Dimension passed to the merge function, when applicable.
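The merge contract, a list of tensors plus a dimension mapped to one tensor, is already satisfied by `torch.cat` and `torch.stack`. A custom reducer can follow the same shape. The sketch below assumes merge is called as `merge(outputs, dim)`. The `MergeFn` alias and the `mean_merge` helper are hypothetical illustrations and are not part of olm.

```python
from typing import Callable, List

import torch

# Assumed shape of a merge function: (list of tensors, dim) -> tensor.
MergeFn = Callable[[List[torch.Tensor], int], torch.Tensor]


def mean_merge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
    """Hypothetical merge: stack along a new axis `dim`, then average it away."""
    return torch.stack(tensors, dim=dim).mean(dim=dim)


outputs = [torch.zeros(2, 3), torch.ones(2, 3)]

concatenated = torch.cat(outputs, dim=-1)  # joins along an existing axis: shape (2, 6)
stacked = torch.stack(outputs, dim=0)      # adds a new axis: shape (2, 2, 3)
averaged = mean_merge(outputs, dim=0)      # shape (2, 3), every entry 0.5
```

`torch.cat` requires the outputs to agree on every axis except `dim`. `torch.stack` requires identical shapes, so the two options place different constraints on the blocks.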

blocks

ModuleList storing the parallel blocks.

merge

Merge function used to combine outputs.

dim

Dimension passed to the merge function.

forward(x: torch.Tensor) → torch.Tensor

Apply all blocks in parallel and merge their outputs.

  • Parameters: x – Input tensor.
  • Returns: Merged output tensor.
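The documented behaviour of `forward` is to run every block on the same `x` and then merge the results. It can be sketched as a small standalone module. The class name `ParallelSketch`, the constructor argument order, and the defaults `merge=torch.cat` and `dim=-1` are illustrative assumptions, not olm's actual signature. Only the attribute names `blocks`, `merge` and `dim` come from the reference above.

```python
from typing import Callable, List, Sequence

import torch
from torch import nn


class ParallelSketch(nn.Module):
    """Hypothetical re-implementation of the documented Parallel behaviour."""

    def __init__(
        self,
        blocks: Sequence[nn.Module],
        merge: Callable[[List[torch.Tensor], int], torch.Tensor] = torch.cat,  # assumed default
        dim: int = -1,  # assumed default
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(blocks)  # registered so all branch parameters are tracked
        self.merge = merge
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every block sees the same input x.
        outputs = [block(x) for block in self.blocks]
        # Combine the per-block outputs along self.dim.
        return self.merge(outputs, self.dim)


layer = ParallelSketch([nn.Linear(8, 4), nn.Linear(8, 6)])
y = layer(torch.randn(5, 8))  # y.shape == torch.Size([5, 10])
```

Because `blocks` is an `nn.ModuleList`, the parameters of every branch appear in `layer.parameters()`. An optimizer built from the combinator therefore trains all branches together.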
