Source: src/olm/nn/structure/combinators/parallel.py:1
Classes
Parallel(blocks: List[torch.nn.modules.module.Module], merge: Callable = None, dim: int = -1)
Bases: olm.nn.structure.combinators.base.BaseCombinator
Source: src/olm/nn/structure/combinators/parallel.py:6
Apply multiple blocks to the same input and merge their outputs.
The merge function takes a list of tensors and a dimension argument.
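The two merge functions below follow that convention. This is a minimal sketch that assumes `Parallel` calls `merge(outputs, dim)` and passes the outputs as a Python list. The names `concat_merge` and `sum_merge` are hypothetical. `torch.cat` and `torch.stack` already accept `(tensors, dim)`, so they could likely be passed directly.

```python
from typing import List

import torch


def concat_merge(outputs: List[torch.Tensor], dim: int) -> torch.Tensor:
    # Join branch outputs end to end along `dim`; sizes add up on that axis.
    return torch.cat(outputs, dim=dim)


def sum_merge(outputs: List[torch.Tensor], dim: int) -> torch.Tensor:
    # Stack on a new axis at `dim`, then reduce over it: an elementwise sum.
    # All branch outputs must share the same shape for this to work.
    return torch.stack(outputs, dim=dim).sum(dim=dim)


a = torch.ones(2, 3)
b = torch.full((2, 3), 2.0)

print(concat_merge([a, b], dim=-1).shape)  # torch.Size([2, 6])
print(sum_merge([a, b], dim=-1))           # every entry is 3.0
```

Concatenation lets branches differ in size along `dim`. Summation requires every branch to produce the same shape.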
Parameters
blocks: Modules applied in parallel to the same input.
merge: Function that combines the list of block outputs, given a dimension.
dim: Dimension passed to the merge function, where applicable.
Attributes
blocks: ModuleList storing the parallel blocks.
merge: Merge function used to combine outputs.
dim: Dimension passed to the merge function.
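The parameters and attributes above suggest roughly the structure sketched below. It is a hypothetical re-implementation (`ParallelSketch`), not the olm source. It subclasses `nn.Module` rather than `BaseCombinator`. It also assumes that a missing `merge` falls back to `torch.cat`, which the documentation above does not state.

```python
from typing import Callable, List, Optional

import torch
from torch import nn


class ParallelSketch(nn.Module):
    """Hypothetical stand-in for Parallel; not the olm implementation."""

    def __init__(
        self,
        blocks: List[nn.Module],
        merge: Optional[Callable] = None,
        dim: int = -1,
    ) -> None:
        super().__init__()
        # ModuleList registers each block so its parameters are tracked.
        self.blocks = nn.ModuleList(blocks)
        # Assumption: fall back to concatenation when no merge is given.
        self.merge = merge if merge is not None else torch.cat
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every block receives the same input tensor.
        outputs = [block(x) for block in self.blocks]
        # Assumption: merge is called as merge(outputs, dim).
        return self.merge(outputs, self.dim)


torch.manual_seed(0)
layer = ParallelSketch([nn.Linear(8, 4), nn.Linear(8, 6)], dim=-1)
y = layer(torch.randn(5, 8))
print(y.shape)                                     # torch.Size([5, 10])
print(sum(p.numel() for p in layer.parameters()))  # 90
```

The blocks are stored in an `nn.ModuleList` rather than a plain list, which makes their parameters visible to `.parameters()`, optimizers, and `.to(device)`.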
Methods
forward(self, x: torch.Tensor) -> torch.Tensor
Source: src/olm/nn/structure/combinators/parallel.py:37
Apply all blocks in parallel and merge their outputs.
Parameters
x: Input tensor.
Returns
Merged output tensor.
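The forward pass amounts to running every block on the same `x` and passing the list of results to `merge`. The sketch below uses that reading with a hypothetical `parallel_forward` helper and summation as the merge. Under those assumptions, pairing `nn.Identity` with another block yields a residual connection.

```python
from typing import Callable, List

import torch
from torch import nn


def parallel_forward(
    x: torch.Tensor,
    blocks: List[nn.Module],
    merge: Callable[[List[torch.Tensor], int], torch.Tensor],
    dim: int,
) -> torch.Tensor:
    # Hypothetical helper mirroring the documented forward: apply every block to x, then merge.
    outputs = [block(x) for block in blocks]
    return merge(outputs, dim)


def sum_merge(outputs: List[torch.Tensor], dim: int) -> torch.Tensor:
    # Elementwise sum of same-shaped outputs.
    return torch.stack(outputs, dim=dim).sum(dim=dim)


torch.manual_seed(0)
linear = nn.Linear(4, 4)
x = torch.randn(3, 4)

# Identity in one branch and a Linear in the other, merged by summation,
# gives x + linear(x).
y = parallel_forward(x, [nn.Identity(), linear], sum_merge, dim=0)
print(torch.allclose(y, x + linear(x)))  # True
```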