olm.nn.structure.combinators¶
class olm.nn.structure.combinators.BaseCombinator(*args: Any, **kwargs: Any)¶
Bases: Module, ABC
Abstract base class for combinator modules.
Subclasses implement forward to define how inputs are combined.
abstractmethod forward(x: torch.Tensor) → torch.Tensor¶
Compute the combinator output from an input tensor.
- Parameters: x – Input tensor.
- Returns: Output tensor produced by the combinator.
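The subclassing contract described above can be sketched without importing the library. `BaseCombinatorSketch` and `Scaled` are hypothetical stand-ins written for this example, not the actual olm classes; the sketch only assumes what is stated here, namely that the base class derives from `Module` and `ABC` and leaves `forward` abstract.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseCombinatorSketch(nn.Module, ABC):
    """Hypothetical stand-in for the abstract base: forward is left abstract."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class Scaled(BaseCombinatorSketch):
    """Hypothetical combinator: applies a block and scales its output."""

    def __init__(self, block: nn.Module, factor: float) -> None:
        super().__init__()
        self.block = block
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factor * self.block(x)


scaled = Scaled(nn.Identity(), factor=2.0)
x = torch.ones(2, 3)
y = scaled(x)  # every element is doubled

# The abstract base itself cannot be instantiated.
try:
    BaseCombinatorSketch()
    base_is_abstract = False
except TypeError:
    base_is_abstract = True
```

Because `forward` is abstract, a subclass that forgets to implement it fails at construction time rather than at the first call.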
class olm.nn.structure.combinators.Parallel(*args: Any, **kwargs: Any)¶
Bases: BaseCombinator
Apply multiple blocks to the same input and merge their outputs.
The merge function receives the list of block outputs and a dimension argument, and returns the merged tensor.
- Parameters:
- blocks – Modules applied in parallel to the same input.
- merge – Function that combines the block outputs; it is called with the list of outputs and dim.
- dim – Dimension used by the merge function when applicable.
blocks¶
ModuleList storing the parallel blocks.
merge¶
Merge function used to combine outputs.
dim¶
Dimension passed to the merge function.
forward(x: torch.Tensor) → torch.Tensor¶
Apply all blocks in parallel and merge their outputs.
- Parameters: x – Input tensor.
- Returns: Merged output tensor.
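As a rough illustration of the behaviour described for Parallel, the following sketch re-implements the idea in plain PyTorch. `ParallelSketch` is a hypothetical stand-in, not the library class; the constructor argument order and the default `dim=-1` are assumptions made for the example. `torch.cat` is used as the merge function because it accepts a list of tensors followed by a dimension.

```python
from typing import Callable, List, Sequence

import torch
from torch import nn


class ParallelSketch(nn.Module):
    """Hypothetical stand-in: run every block on the same input, then merge."""

    def __init__(
        self,
        blocks: Sequence[nn.Module],
        merge: Callable[[List[torch.Tensor], int], torch.Tensor],
        dim: int = -1,  # assumed default, not taken from the library
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(blocks)  # registers the blocks' parameters
        self.merge = merge
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = [block(x) for block in self.blocks]
        return self.merge(outputs, self.dim)


branches = [nn.Linear(8, 4), nn.Linear(8, 6)]
parallel = ParallelSketch(branches, merge=torch.cat, dim=-1)

x = torch.randn(2, 8)
y = parallel(x)  # the two branch outputs are concatenated on the last dim
```

With concatenation as the merge, the output width is the sum of the branch widths (4 + 6 here). A merge that stacks or sums the outputs would need branches with matching shapes.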
class olm.nn.structure.combinators.Repeat(*args: Any, **kwargs: Any)¶
Bases: BaseCombinator
Repeat a module a fixed number of times in sequence.
module_func should return a new module instance on each call.
- Parameters:
- module_func – Callable returning a new module instance.
- num_repeat – Number of times to repeat the module.
module¶
Factory callable used to create new modules.
num_repeat¶
Number of repeats.
stack¶
ModuleList containing the repeated modules.
forward(x: torch.Tensor) → torch.Tensor¶
Apply the repeated modules in sequence.
- Parameters: x – Input tensor.
- Returns: Output tensor after all repeats.
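The factory pattern described for Repeat can be sketched as follows. `RepeatSketch` is a hypothetical stand-in written in plain PyTorch, not the library class, and the argument order is an assumption. The attribute names mirror the ones listed above.

```python
from typing import Callable

import torch
from torch import nn


class RepeatSketch(nn.Module):
    """Hypothetical stand-in: build num_repeat modules and apply them in order."""

    def __init__(self, module_func: Callable[[], nn.Module], num_repeat: int) -> None:
        super().__init__()
        self.module = module_func  # the factory itself, not a module instance
        self.num_repeat = num_repeat
        # Call the factory once per repeat so each entry is a separate instance.
        self.stack = nn.ModuleList([module_func() for _ in range(num_repeat)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.stack:
            x = layer(x)
        return x


repeat = RepeatSketch(lambda: nn.Linear(8, 8), num_repeat=3)

x = torch.randn(2, 8)
y = repeat(x)  # x passed through three independent Linear layers
```

In this sketch, passing a factory rather than a module instance means each repeat owns its own parameters. Since the modules are chained, each one must accept the output shape of the previous one.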
class olm.nn.structure.combinators.Residual(*args: Any, **kwargs: Any)¶
Bases: BaseCombinator
Residual wrapper that adds the block output to its input.
- Parameters: block – Module applied to the input before residual addition.
block¶
Module used for the residual transformation.
forward(x: torch.Tensor) → torch.Tensor¶
Apply the block and add the result to the input.
- Parameters: x – Input tensor.
- Returns: Output tensor with residual connection applied.
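The residual connection described above amounts to computing `x + block(x)`. A minimal sketch in plain PyTorch follows; `ResidualSketch` is a hypothetical stand-in, not the library class.

```python
import torch
from torch import nn


class ResidualSketch(nn.Module):
    """Hypothetical stand-in: add the block's output to the block's input."""

    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)


block = nn.Linear(8, 8)  # output shape matches the input shape
residual = ResidualSketch(block)

x = torch.randn(2, 8)
y = residual(x)
```

For the addition to be valid, the block's output has to be broadcast-compatible with its input, which in practice usually means the block preserves the input shape.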
Modules¶
- base
- parallel
- repeat
- residual