Module Name: FlashAttentionTwo

The FlashAttentionTwo class is a PyTorch module that implements a memory-efficient variant of the attention mechanism, a key component of many state-of-the-art models in natural language processing and other fields. It optionally supports parallel computation across multiple GPUs and mixed-precision arithmetic for improved performance.

Class Definition


class FlashAttentionTwo(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        parallel = False,
        mixed_precision = False
    ):

Parameters

  • dim (int): The dimensionality of the input data.
  • heads (int, optional): The number of attention heads. Default is 8.
  • dim_head (int, optional): The dimensionality of each attention head. Default is 64.
  • causal (bool, optional): If True, the attention mechanism is causal. Default is False.
  • q_bucket_size (int, optional): The bucket (chunk) size along the query sequence dimension; queries are processed this many positions at a time. Default is 512.
  • k_bucket_size (int, optional): The bucket (chunk) size along the key sequence dimension. Default is 1024 (see the sketch after this list).
  • parallel (bool, optional): If True, the computation is performed in parallel across multiple GPUs. Default is False.
  • mixed_precision (bool, optional): If True, the computation is performed in mixed precision for improved performance. Default is False.
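
As a quick illustration, a hypothetical instantiation that overrides several of these defaults (the specific values are for illustration only) might look like this:

from zeta.nn import FlashAttentionTwo

# 512-dimensional inputs, 8 heads of size 64,
# with smaller buckets to reduce peak memory
attn = FlashAttentionTwo(
    dim=512,
    heads=8,
    dim_head=64,
    q_bucket_size=256,
    k_bucket_size=512,
)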

Methods

forward

def forward(
    self,
    x,
    context = None,
    mask = None,
    q_bucket_size = None,
    k_bucket_size = None,
):

Performs the forward pass of the attention mechanism.

Parameters
  • x (Tensor): The input tensor of shape (batch, seq_len, dim).
  • context (Tensor, optional): The context for the attention mechanism. If not provided, the input data x is used as the context.
  • mask (Tensor, optional): An optional mask for the attention mechanism.
  • q_bucket_size (int, optional): Overrides the query bucket size for this call. If not provided, the value specified during initialization is used.
  • k_bucket_size (int, optional): Overrides the key bucket size for this call. If not provided, the value specified during initialization is used; see the sketch after the Returns list.

Returns
  • out (Tensor): The output of the attention mechanism.
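
For instance, the bucket sizes chosen at construction can be overridden for a single call; a minimal sketch (the values are illustrative only):

import torch

from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 2048, 512)

# use smaller buckets for this call only, e.g. to fit a longer sequence in memory
out = model(x, q_bucket_size=256, k_bucket_size=512)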

Usage Examples


Example 1: Basic Usage

import torch

from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 10, 512)
out = model(x)

Example 2: Using a Mask

import torch

from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 10, 512)
mask = torch.ones(1, 10).bool()  # boolean mask marking which positions may be attended to
out = model(x, mask=mask)

Example 3: Using a Context

import torch

from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 10, 512)
context = torch.randn(1, 10, 512)
out = model(x, context=context)
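
Example 4: Causal Self-Attention

The following is a minimal sketch of causal (autoregressive) self-attention; it assumes nothing beyond the causal flag documented above.

import torch

from zeta.nn import FlashAttentionTwo

# each position may only attend to itself and earlier positions
model = FlashAttentionTwo(dim=512, causal=True)
x = torch.randn(1, 10, 512)
out = model(x)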

Mathematical Formula


The attention mechanism can be described by the following formula:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V

where Q, K, and V are the query, key, and value matrices, and d_k is the dimensionality of each attention head (dim_head). The softmax normalizes each row of attention weights so that it sums to 1, and multiplying these weights by the values gives the output of the attention mechanism.
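
For reference, the snippet below is a minimal, non-flash sketch of this formula in plain PyTorch; it materializes the full attention matrix, which FlashAttentionTwo deliberately avoids by processing queries and keys in buckets.

import torch

def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5
    # similarity scores QK^T / sqrt(d_k)
    scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
    # softmax over the key dimension, then weight the values
    attn = scores.softmax(dim=-1)
    return torch.einsum("b h i j, b h j d -> b h i d", attn, v)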

Additional Information


Beyond the core attention computation, the FlashAttentionTwo class exposes several options for trading off memory usage and speed:

  • The parallel parameter allows the computation to be performed in parallel across multiple GPUs. This can significantly speed up the computation for large models or large datasets.

  • The mixed_precision parameter allows the computation to be performed in mixed precision. This means that some operations are performed in lower precision (e.g., float16) and some in higher precision (e.g., float32). This can significantly speed up the computation and reduce memory usage on modern GPUs that support mixed precision.

  • The q_bucket_size and k_bucket_size parameters control the chunk (bucket) sizes used for the queries and keys in the tiled attention computation. They trade off memory usage against speed: smaller buckets reduce peak memory but add per-chunk overhead, while larger buckets are usually faster at the cost of higher memory usage; see the sketch after this list.
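
A minimal configuration sketch combining these options (it assumes a CUDA-capable machine; parallel=True additionally assumes multiple GPUs and is left commented out here):

import torch

from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(
    dim=512,
    q_bucket_size=256,      # smaller buckets -> lower peak memory, some extra overhead
    k_bucket_size=512,
    mixed_precision=True,   # run parts of the computation in lower precision
    # parallel=True,        # uncomment when multiple GPUs are available
)

x = torch.randn(2, 2048, 512)
out = model(x)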

Common Issues


  • If you encounter out-of-memory errors, you can try reducing the q_bucket_size and k_bucket_size parameters, or enabling mixed precision computation by setting mixed_precision=True.

  • If you encounter slow computation, you can try increasing the q_bucket_size and k_bucket_size parameters, or enabling parallel computation by setting parallel=True (if you have multiple GPUs available).
