
BaseAttention Abstract Class


The BaseAttention class is an abstract base class that defines the common interface for all attention mechanisms. It stores the shared configuration passed to its constructor and declares a forward method that every concrete attention class must implement.

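from abc import abstractmethod

import torch.nn as nn


class BaseAttention(nn.Module):
    def __init__(self, dim, heads, dim_head):
        super().__init__()  # must run first so nn.Module can set up its internal state
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head

    @abstractmethod
    def forward(self, x, context=None, mask=None):
        # nn.Module does not use ABCMeta, so @abstractmethod alone does not block
        # instantiation; raising here catches subclasses that forget to override forward.
        raise NotImplementedError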

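The FlashAttentionTwo class extends the BaseAttention abstract base class and implements a concrete attention mechanism. Judging by its parameters, it processes queries and keys in chunks (q_bucket_size and k_bucket_size) and supports causal masking (causal).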

class FlashAttentionTwo(BaseAttention):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        parallel = False,
        mixed_precision = False
    ):
        super().__init__(dim, heads, dim_head)
        self.causal = causal
        self.parallel = parallel
        self.mixed_precision = mixed_precision
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
        # ... rest of the implementation ...

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,  # None presumably falls back to self.q_bucket_size
        k_bucket_size = None,  # None presumably falls back to self.k_bucket_size
    ):
        # ... implementation of the forward method ...
        ...
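The q_bucket_size and k_bucket_size arguments suggest that attention is computed block by block rather than by materializing the full score matrix. The sketch below illustrates that general technique: tiled attention with a running ("online") softmax, the idea popularized by FlashAttention. It is a simplified, hypothetical example, not zeta's implementation. The function name `chunked_attention` is invented here, and the sketch handles a single head with no masking.

```python
import torch


def chunked_attention(q, k, v, q_bucket_size=512, k_bucket_size=1024):
    # Hypothetical sketch of bucketed attention with a running ("online") softmax.
    # q: (n, d), k: (m, d), v: (m, d_v); single head, no mask.
    # Only a (q_bucket_size, k_bucket_size) block of scores exists at any time.
    scale = q.shape[-1] ** -0.5
    outputs = []
    for q_chunk in q.split(q_bucket_size, dim=0):
        rows = q_chunk.shape[0]
        # Per-row running max, running softmax denominator, and weighted sum of values.
        row_max = torch.full((rows, 1), float("-inf"), dtype=q.dtype, device=q.device)
        row_sum = torch.zeros(rows, 1, dtype=q.dtype, device=q.device)
        acc = torch.zeros(rows, v.shape[-1], dtype=q.dtype, device=q.device)
        for k_chunk, v_chunk in zip(k.split(k_bucket_size, dim=0), v.split(k_bucket_size, dim=0)):
            scores = (q_chunk @ k_chunk.T) * scale
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            # Rescale the totals accumulated under the old max before adding this block.
            correction = torch.exp(row_max - new_max)
            p = torch.exp(scores - new_max)
            row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ v_chunk
            row_max = new_max
        outputs.append(acc / row_sum)
    return torch.cat(outputs, dim=0)


q = torch.randn(100, 64, dtype=torch.float64)
k = torch.randn(300, 64, dtype=torch.float64)
v = torch.randn(300, 64, dtype=torch.float64)
out = chunked_attention(q, k, v, q_bucket_size=32, k_bucket_size=128)
reference = torch.softmax(q @ k.T * 64 ** -0.5, dim=-1) @ v
print(torch.allclose(out, reference))  # True
```

The bucket sizes change only how much memory each step needs, not the result. The running max and the rescaling by `correction` make the blockwise sums equal to an ordinary softmax over all keys, up to floating-point rounding.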

Rules for Using the BaseAttention Class

  1. Any class that extends BaseAttention must implement the forward method, which defines how the attention mechanism computes its output.

  2. The __init__ method of BaseAttention takes three parameters: dim, heads, and dim_head. Any class that extends BaseAttention should pass these parameters to the base class by calling super().__init__(dim, heads, dim_head).

  3. The forward method of BaseAttention takes three parameters: x, context, and mask. Any class that extends BaseAttention should accept the same parameters in its own forward method.
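Rules 1 and 3 can be checked mechanically with Python's `inspect` module, for example in a unit test. `check_attention_subclass` below is a hypothetical helper, not part of the library. BaseAttention is re-declared as a minimal stand-in so the block is self-contained. Rule 2 is not checked, because verifying it would require knowing each subclass's constructor arguments.

```python
import inspect

import torch.nn as nn


class BaseAttention(nn.Module):
    # Minimal stand-in for the interface above, so this sketch runs on its own.
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.dim, self.heads, self.dim_head = dim, heads, dim_head

    def forward(self, x, context=None, mask=None):
        raise NotImplementedError


def check_attention_subclass(cls):
    # Hypothetical helper: list violations of rule 1 (forward overridden)
    # and rule 3 (forward accepts x, context, and mask).
    problems = []
    if cls.forward is BaseAttention.forward:
        problems.append("rule 1: forward is not overridden")
    params = inspect.signature(cls.forward).parameters
    for name in ("x", "context", "mask"):
        if name not in params:
            problems.append(f"rule 3: forward has no '{name}' parameter")
    return problems


class Good(BaseAttention):
    def forward(self, x, context=None, mask=None):
        return x


class MissingMask(BaseAttention):
    def forward(self, x, context=None):
        return x


print(check_attention_subclass(Good))         # []
print(check_attention_subclass(MissingMask))  # ["rule 3: forward has no 'mask' parameter"]
```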

Example of Using the FlashAttentionTwo Class

import torch

from zeta.nn.attention import FlashAttentionTwo

# Create an instance of the FlashAttentionTwo class
attention = FlashAttentionTwo(dim=512, heads=8, dim_head=64)

# Create some input data: (batch, sequence length, dim)
x = torch.randn(1, 10, 512)

# Apply the attention mechanism
out = attention(x)

In this example, we first create an instance of the FlashAttentionTwo class. We then create some input data x, whose last dimension matches dim, and apply the attention mechanism by calling the attention instance, which runs its forward method.
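Calling a module instance, as in attention(x), is the usual way to run forward. `nn.Module.__call__` invokes forward and also runs any registered hooks, whereas calling `.forward(x)` directly skips them. The short check below uses `nn.Linear` as a stand-in module, since showing this behavior does not require the zeta class.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(512, 512)
# Forward hooks run after forward() whenever the module is called, e.g. for logging.
layer.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(1, 10, 512)
out_call = layer(x)            # goes through __call__: the hook fires
out_direct = layer.forward(x)  # bypasses __call__: the hook does not fire

print(len(calls))                             # 1
print(torch.allclose(out_call, out_direct))   # True
```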