Source: src/olm/nn/attention/base.py:1
Classes
AttentionBase(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True)
Bases: Module, ABC
Source: src/olm/nn/attention/base.py:8
Abstract base class for attention mechanisms.
Provides the common structure for attention layers, including QKV projections
and output projection. Subclasses must implement the specific attention logic
in compute_attention.
Attributes
embed_dim (int): Total dimension of the model.
num_heads (int): Number of parallel attention heads.
head_dim (int): Dimension of each attention head.
scale (float): Scaling factor for dot products (1 / sqrt(head_dim)).
dropout (nn.Dropout): Dropout layer applied to attention weights.
q_proj (Linear): Linear projection for Query.
k_proj (Linear): Linear projection for Key.
v_proj (Linear): Linear projection for Value.
out_proj (Linear): Linear projection for Output.
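To make the documented attributes concrete, here is a minimal sketch of how a class with this shape might be wired up. `SketchAttentionBase` and `IdentityAttention` are hypothetical names, not part of olm, and the divisibility check on `embed_dim` and `num_heads` is an assumption; only the attribute names and the 1 / sqrt(head_dim) scale come from the documentation above.

```python
import math
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn


class SketchAttentionBase(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented attributes (not the olm source)."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        # Assumption: heads must split embed_dim evenly.
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)  # 1 / sqrt(head_dim), as documented
        self.dropout = nn.Dropout(dropout)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @abstractmethod
    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Subclasses supply the attention kernel."""


class IdentityAttention(SketchAttentionBase):
    # Trivial subclass used only to show that a concrete subclass can be built.
    def compute_attention(self, q, k, v, mask=None):
        return v


layer = IdentityAttention(embed_dim=512, num_heads=8)
print(layer.head_dim, layer.scale)  # 64 0.125
```

Because `ABC` is one of the bases, instantiating `SketchAttentionBase` directly raises `TypeError` until `compute_attention` is overridden, which is how the "subclasses must implement" contract is enforced in this sketch.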
Methods
compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor
Source: src/olm/nn/attention/base.py:57
Computes the attention scores and output.
Parameters
q (torch.Tensor): Query tensor [batch, heads, seq, head_dim].
k (torch.Tensor): Key tensor [batch, heads, seq, head_dim].
v (torch.Tensor): Value tensor [batch, heads, seq, head_dim].
mask (torch.Tensor, optional): Attention mask. Defaults to None.
Returns
torch.Tensor: The attention output [batch, heads, seq, head_dim].
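One plausible `compute_attention` is standard scaled dot-product attention, softmax(q @ k^T * scale) @ v. The documentation does not state the mask convention, so this sketch assumes a boolean mask where `True` marks positions that may be attended to (the convention used by `torch.nn.functional.scaled_dot_product_attention`). The function name `compute_attention_sketch` is hypothetical.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def compute_attention_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    # q, k, v: [batch, heads, seq, head_dim]
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [batch, heads, seq_q, seq_k]
    if mask is not None:
        # Assumed convention: boolean mask, True = may attend, False = blocked.
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # normalize over the key axis
    weights = F.dropout(weights, p=dropout_p, training=training)
    return torch.matmul(weights, v)  # [batch, heads, seq_q, head_dim]


q = k = v = torch.randn(2, 4, 5, 8)
causal = torch.ones(5, 5, dtype=torch.bool).tril()  # lower-triangular: no attending ahead
out = compute_attention_sketch(q, k, v, mask=causal)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```

The `[seq, seq]` mask broadcasts over the batch and head axes, so one mask serves every head.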
forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor
Source: src/olm/nn/attention/base.py:73
Standard forward pass for attention layers.
Projects input to Q, K, V, calls compute_attention, and projects output.
Parameters
x (torch.Tensor): Input tensor [batch, seq, embed_dim].
mask (torch.Tensor, optional): Attention mask. Defaults to None.
Returns
torch.Tensor: Output tensor [batch, seq, embed_dim].
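The forward pass described above can be sketched as follows. The head split is an assumption about how `[batch, seq, embed_dim]` is mapped to the `[batch, heads, seq, head_dim]` layout that `compute_attention` expects (reshape into `num_heads` chunks of `head_dim`, then swap the seq and head axes). `split_heads`, `merge_heads`, and `SketchSelfAttention` are hypothetical names, and PyTorch's built-in `scaled_dot_product_attention` stands in for a subclass's `compute_attention`.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # [batch, seq, embed_dim] -> [batch, heads, seq, head_dim]
    batch, seq, embed_dim = x.shape
    return x.view(batch, seq, num_heads, embed_dim // num_heads).transpose(1, 2)


def merge_heads(x: torch.Tensor) -> torch.Tensor:
    # [batch, heads, seq, head_dim] -> [batch, seq, embed_dim]
    batch, heads, seq, head_dim = x.shape
    return x.transpose(1, 2).reshape(batch, seq, heads * head_dim)


class SketchSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def compute_attention(self, q, k, v, mask=None):
        # Stand-in for the subclass hook: PyTorch's scaled dot-product attention.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        q = split_heads(self.q_proj(x), self.num_heads)
        k = split_heads(self.k_proj(x), self.num_heads)
        v = split_heads(self.v_proj(x), self.num_heads)
        out = self.compute_attention(q, k, v, mask)  # [batch, heads, seq, head_dim]
        return self.out_proj(merge_heads(out))  # [batch, seq, embed_dim]


layer = SketchSelfAttention(embed_dim=32, num_heads=4)
print(layer(torch.randn(2, 7, 32)).shape)  # torch.Size([2, 7, 32])
```

`merge_heads` uses `reshape` rather than `view` because the transposed tensor is no longer contiguous.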
AttentionwithRoPEBase(embed_dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.0, bias: bool = True, rope_theta: float = 10000.0)
Bases: Module, ABC
Source: src/olm/nn/attention/base.py:95
Abstract base class for attention mechanisms with Rotary Positional Embedding (RoPE).
Provides the common structure for attention layers, including QKV projections
and output projection. Subclasses must implement the specific attention logic
in compute_attention.
Attributes
embed_dim (int): Total dimension of the model.
num_heads (int): Number of parallel attention heads.
head_dim (int): Dimension of each attention head.
scale (float): Scaling factor for dot products (1 / sqrt(head_dim)).
dropout (nn.Dropout): Dropout layer applied to attention weights.
q_proj (Linear): Linear projection for Query.
k_proj (Linear): Linear projection for Key.
v_proj (Linear): Linear projection for Value.
out_proj (Linear): Linear projection for Output.
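The `max_seq_len` and `rope_theta` constructor arguments suggest the usual RoPE setup: one rotation frequency per dimension pair, rope_theta ** (-2i / head_dim), with a cos/sin table precomputed for positions 0 to max_seq_len - 1. The sketch below assumes that, and also assumes the "rotate-half" pairing (dimension i paired with i + head_dim / 2); the original RoPE paper pairs adjacent dimensions instead, and the documentation does not say which layout olm uses. All function names here are hypothetical.

```python
import torch


def build_rope_cache(head_dim: int, max_seq_len: int, rope_theta: float = 10000.0):
    # One frequency per dimension pair: rope_theta ** (-2i / head_dim)
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # [max_seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # both halves of a pair share one angle
    return emb.cos(), emb.sin()  # each [max_seq_len, head_dim]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1): a 90-degree turn of each (i, i + head_dim / 2) pair
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; tables are sliced to the current length
    seq = x.size(-2)
    return x * cos[:seq] + rotate_half(x) * sin[:seq]


cos, sin = build_rope_cache(head_dim=8, max_seq_len=16)
q = torch.randn(2, 4, 10, 8)
q_rot = apply_rope(q, cos, sin)  # same shape; each pair rotated by a position-dependent angle
```

Since each pair is rotated rather than scaled, vector norms are unchanged, and position 0 (angle zero) is left as-is.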
Methods
compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor
Source: src/olm/nn/attention/base.py:148
Computes the attention scores and output.
Parameters
q (torch.Tensor): Query tensor [batch, heads, seq, head_dim].
k (torch.Tensor): Key tensor [batch, heads, seq, head_dim].
v (torch.Tensor): Value tensor [batch, heads, seq, head_dim].
mask (torch.Tensor, optional): Attention mask. Defaults to None.
Returns
torch.Tensor: The attention output [batch, heads, seq, head_dim].
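If, as is usual, queries and keys are rotated before they reach `compute_attention`, the scores computed there become position-aware in a relative way: the dot product of a query rotated to position m and a key rotated to position n depends only on m - n. The check below assumes that ordering (the documentation does not state it) and writes the rotation in complex-number form with adjacent dimensions paired; `rope_complex` and `score` are hypothetical helpers.

```python
import torch


def rope_complex(x: torch.Tensor, positions: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # x: [seq, head_dim]; adjacent pairs (x[2i], x[2i+1]) are treated as complex numbers
    head_dim = x.size(-1)
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    angles = torch.outer(positions.to(torch.float64), inv_freq)  # [seq, head_dim // 2]
    rot = torch.polar(torch.ones_like(angles), angles)  # unit complex numbers e^{i*angle}
    xc = torch.view_as_complex(x.to(torch.float64).reshape(*x.shape[:-1], head_dim // 2, 2))
    return torch.view_as_real(xc * rot).flatten(-2)  # back to [seq, head_dim]


def score(q: torch.Tensor, k: torch.Tensor, m: int, n: int) -> float:
    # Dot product of a query at position m with a key at position n (single head_dim vectors)
    q_rot = rope_complex(q.unsqueeze(0), torch.tensor([m]))
    k_rot = rope_complex(k.unsqueeze(0), torch.tensor([n]))
    return float((q_rot * k_rot).sum())


torch.manual_seed(0)
q, k = torch.randn(8), torch.randn(8)
print(score(q, k, 5, 3), score(q, k, 12, 10))  # equal up to rounding: both have m - n = 2
```

Float64 is used only to keep the comparison tight; the property itself holds in any precision up to rounding.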
forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor
Source: src/olm/nn/attention/base.py:164
Standard forward pass for attention layers.
Projects input to Q, K, V, calls compute_attention, and projects output.
Parameters
x (torch.Tensor): Input tensor [batch, seq, embed_dim].
mask (torch.Tensor, optional): Attention mask. Defaults to None.
Returns
torch.Tensor: Output tensor [batch, seq, embed_dim].
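A sketch of how the RoPE forward pass might differ from the plain one: queries and keys are rotated after projection and head splitting, values are left unrotated, and inputs longer than `max_seq_len` cannot be served from the precomputed table. The rotate-half layout, the non-persistent buffers, the length check, and the class name `SketchRoPESelfAttention` are all assumptions, not taken from the olm source; PyTorch's `scaled_dot_product_attention` again stands in for `compute_attention`.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class SketchRoPESelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, max_seq_len: int,
                 bias: bool = True, rope_theta: float = 10000.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.max_seq_len = max_seq_len
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        inv_freq = 1.0 / (rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)  # [max_seq_len, head_dim]
        # Non-persistent buffers move with .to(device) but are not saved in state_dict.
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    def _split(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        return x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, s, e = x.shape
        if s > self.max_seq_len:
            raise ValueError(f"sequence length {s} exceeds max_seq_len {self.max_seq_len}")
        cos, sin = self.cos[:s], self.sin[:s]
        q, k, v = self._split(self.q_proj(x)), self._split(self.k_proj(x)), self._split(self.v_proj(x))
        # Rotate queries and keys only; values carry no positional rotation.
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.out_proj(out.transpose(1, 2).reshape(b, s, e))


layer = SketchRoPESelfAttention(embed_dim=32, num_heads=4, max_seq_len=16)
print(layer(torch.randn(2, 10, 32)).shape)  # torch.Size([2, 10, 32])
```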