OLM API Reference

`olm.nn.structure.combinators.parallel`

Source: src/olm/nn/structure/combinators/parallel.py:1

Classes

Parallel(blocks: List[torch.nn.modules.module.Module], merge: Callable = None, dim: int = -1)

Bases: olm.nn.structure.combinators.base.BaseCombinator

Source: src/olm/nn/structure/combinators/parallel.py:6

Apply multiple blocks to the same input and merge their outputs.

The merge function receives the list of block outputs (one tensor per block) and the dimension `dim`.

Parameters

  • blocks: Modules applied in parallel to the same input.
  • merge: Function that combines the block outputs. It takes the list of output tensors and a dimension.
  • dim: Dimension passed to the merge function, used when applicable (default: -1).
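
The merge contract described above can be illustrated with two small functions. This is a sketch only: the names `concat_merge` and `sum_merge` are hypothetical and not part of olm, and it assumes the merge function is called positionally as `merge(outputs, dim)`, which this page does not state explicitly.

```python
import torch


def concat_merge(outputs, dim):
    # Hypothetical merge: concatenate block outputs along `dim`.
    # All outputs must have matching sizes on every other axis.
    return torch.cat(outputs, dim=dim)


def sum_merge(outputs, dim):
    # Hypothetical merge: element-wise sum of identically shaped outputs.
    # `dim` is accepted to satisfy the contract but is not used here.
    return torch.stack(outputs, dim=0).sum(dim=0)


a = torch.ones(2, 3)
b = torch.full((2, 3), 2.0)

concat_out = concat_merge([a, b], dim=-1)  # last axis grows from 3 to 6
sum_out = sum_merge([a, b], dim=-1)        # shape stays (2, 3)
```

Because `torch.cat` itself accepts a sequence of tensors followed by a dimension, it would likely fit the same contract without a wrapper, under the calling assumption above.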

Attributes

  • blocks: ModuleList storing the parallel blocks.
  • merge: Merge function used to combine outputs.
  • dim: Dimension passed to the merge function.

Methods

forward(self, x: torch.Tensor) -> torch.Tensor

Source: src/olm/nn/structure/combinators/parallel.py:37

Apply every block to the same input and merge their outputs.

Parameters

  • x: Input tensor passed to every block.

Returns

Merged output tensor.
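
The behavior documented for `forward` can be approximated by the stand-in below. It is a minimal sketch, not the olm source: `ParallelSketch` is a hypothetical class written only from the description on this page, it assumes the merge function is called as `merge(outputs, dim)`, and it requires `merge` explicitly because the behavior of the documented default (`None`) is not described here.

```python
from typing import Callable, List

import torch
from torch import nn


class ParallelSketch(nn.Module):
    """Illustrative stand-in for the documented behavior (not the olm class)."""

    def __init__(self, blocks: List[nn.Module], merge: Callable, dim: int = -1):
        super().__init__()
        # ModuleList registers each block so its parameters are tracked.
        self.blocks = nn.ModuleList(blocks)
        self.merge = merge
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every block receives the same input tensor.
        outputs = [block(x) for block in self.blocks]
        # Assumed calling convention: list of outputs, then the dimension.
        return self.merge(outputs, self.dim)


# Two linear blocks with different output widths, concatenated on the last axis.
layer = ParallelSketch([nn.Linear(4, 3), nn.Linear(4, 5)], merge=torch.cat, dim=-1)
x = torch.randn(2, 4)
y = layer(x)  # last axis has width 3 + 5
```

With concatenation as the merge, the output width along `dim` is the sum of the block output widths, so the blocks may differ in size on that axis but must agree on all others.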