# Module Name: FlashAttentionTwo

The `FlashAttentionTwo` class is a PyTorch module that implements a memory-efficient variant of the attention mechanism, a key component of many state-of-the-art models in natural language processing and other fields. It optionally supports parallel computation across multiple GPUs and mixed-precision arithmetic for improved performance.
## Class Definition

```python
class FlashAttentionTwo(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads=8,
        dim_head=64,
        causal=False,
        q_bucket_size=512,
        k_bucket_size=1024,
        parallel=False,
        mixed_precision=False
    ):
        ...
```
## Parameters

- `dim` (int): The dimensionality of the input data.
- `heads` (int, optional): The number of attention heads. Default is 8.
- `dim_head` (int, optional): The dimensionality of each attention head. Default is 64.
- `causal` (bool, optional): If True, the attention mechanism is causal. Default is False.
- `q_bucket_size` (int, optional): The bucket size for the queries in the attention mechanism. Default is 512.
- `k_bucket_size` (int, optional): The bucket size for the keys in the attention mechanism. Default is 1024.
- `parallel` (bool, optional): If True, the computation is performed in parallel across multiple GPUs. Default is False.
- `mixed_precision` (bool, optional): If True, the computation is performed in mixed precision for improved performance. Default is False.
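The snippet below spells out every constructor argument from the signature above. It is a sketch; the non-default value chosen for `causal` is illustrative, not a recommendation.

```python
from zeta.nn import FlashAttentionTwo

# Illustrative configuration; every argument mirrors the signature above.
attn = FlashAttentionTwo(
    dim=512,
    heads=8,
    dim_head=64,
    causal=True,           # each position attends only to earlier positions
    q_bucket_size=512,
    k_bucket_size=1024,
    parallel=False,
    mixed_precision=False,
)
```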
## Methods

### forward

Performs the forward pass of the attention mechanism.
#### Parameters

- `x` (Tensor): The input data.
- `context` (Tensor, optional): The context for the attention mechanism. If not provided, the input data `x` is used as the context.
- `mask` (Tensor, optional): An optional mask for the attention mechanism.
- `q_bucket_size` (int, optional): The bucket size for the queries. If not provided, the value specified during initialization is used.
- `k_bucket_size` (int, optional): The bucket size for the keys. If not provided, the value specified during initialization is used.
#### Returns

- `out` (Tensor): The output of the attention mechanism.
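The following sketch exercises the optional `forward` arguments, assuming the signature documented above; the mask shape (one boolean per context position) and the printed output shape are assumptions rather than documented contracts.

```python
import torch
from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(2, 128, 512)                 # (batch, sequence length, dim)
context = torch.randn(2, 64, 512)            # optional cross-attention context
mask = torch.ones(2, 64, dtype=torch.bool)   # assumed: one flag per context position

# Bucket sizes can also be overridden per call.
out = model(x, context=context, mask=mask, q_bucket_size=256, k_bucket_size=512)
print(out.shape)  # expected: torch.Size([2, 128, 512])
```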
## Usage Examples

### Example 1: Basic Usage

```python
import torch
from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 10, 512)  # (batch, sequence length, dim)
out = model(x)
```
### Example 2: Using a Mask

```python
import torch
from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 10, 512)
mask = torch.ones(1, 10)  # mask over the 10 positions
out = model(x, mask=mask)
```
### Example 3: Using a Context

```python
import torch
from zeta.nn import FlashAttentionTwo

model = FlashAttentionTwo(dim=512)
x = torch.randn(1, 10, 512)
context = torch.randn(1, 10, 512)  # cross-attention context
out = model(x, context=context)
```
## Mathematical Formula

The attention mechanism can be described by the following formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where Q, K, and V are the query, key, and value matrices and d_k is the dimensionality of each head. The softmax ensures that the attention weights for each query sum to 1, and multiplying the weights by the values gives the output of the attention mechanism.
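For reference, here is a direct, non-bucketed rendering of this formula in PyTorch. `FlashAttentionTwo` computes the same quantity, but in a memory-efficient way.

```python
import torch

batch, heads, seq_len, dim_head = 1, 8, 10, 64
q = torch.randn(batch, heads, seq_len, dim_head)
k = torch.randn(batch, heads, seq_len, dim_head)
v = torch.randn(batch, heads, seq_len, dim_head)

scores = q @ k.transpose(-2, -1) / dim_head ** 0.5  # QK^T / sqrt(d_k)
weights = scores.softmax(dim=-1)                    # each row sums to 1
out = weights @ v                                   # weighted sum of the values
print(out.shape)  # torch.Size([1, 8, 10, 64])
```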
## Additional Information

- The `FlashAttentionTwo` class is designed to be memory-efficient and optionally supports parallel computation and mixed precision for improved performance.
- The `parallel` parameter allows the computation to be performed in parallel across multiple GPUs. This can significantly speed up the computation for large models or large datasets.
- The `mixed_precision` parameter allows the computation to be performed in mixed precision: some operations run in lower precision (e.g., float16) and others in higher precision (e.g., float32). This can significantly speed up the computation and reduce memory usage on modern GPUs that support mixed precision.
- The `q_bucket_size` and `k_bucket_size` parameters control the bucket sizes for the queries and keys, respectively. They trade memory usage against speed: smaller buckets reduce peak memory but may be slower, while larger buckets are faster but use more memory. See the sketch after this list for how bucketing keeps the attention matrix small.
## Common Issues

- If you encounter out-of-memory errors, try reducing the `q_bucket_size` and `k_bucket_size` parameters, or enable mixed-precision computation by setting `mixed_precision=True`.
- If computation is slow, try increasing the `q_bucket_size` and `k_bucket_size` parameters, or enable parallel computation by setting `parallel=True` (if you have multiple GPUs available). Both situations are illustrated in the configuration sketch below.
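Two hedged example configurations for these situations, using only the constructor arguments documented above; the specific values are illustrative.

```python
from zeta.nn import FlashAttentionTwo

# Low-memory configuration: smaller buckets plus mixed precision.
low_mem_attn = FlashAttentionTwo(
    dim=512,
    q_bucket_size=256,     # half the default query bucket
    k_bucket_size=512,     # half the default key bucket
    mixed_precision=True,  # run most operations in float16 where supported
)

# Throughput-oriented configuration for a multi-GPU machine.
fast_attn = FlashAttentionTwo(
    dim=512,
    q_bucket_size=1024,
    k_bucket_size=2048,
    parallel=True,         # spread the computation across available GPUs
)
```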
## References and Resources

- Attention Is All You Need: the paper that introduced the Transformer, an architecture built entirely on the attention mechanism.
- PyTorch Documentation: the official PyTorch documentation provides detailed information about the PyTorch library and its modules.
- Efficient Attention: Attention with Linear Complexities: a paper on attention variants designed to improve memory efficiency.
- Mixed Precision Training: the paper that introduced mixed-precision training, which can significantly speed up computation and reduce memory usage on modern GPUs.
- PyTorch Tutorials: the official PyTorch tutorials provide many examples of how to use PyTorch for various tasks.