FlashAttention
The FlashAttention module performs efficient attention computation and is designed to leverage hardware capabilities of certain NVIDIA GPUs. It can optionally use "flash" attention, which is computationally faster on supported GPU architectures.
Class Definition:
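A minimal sketch of the class skeleton with the constructor signature documented below (the internal attribute names, such as attn_dropout, are illustrative assumptions, not guaranteed to match the actual source):

```python
from torch import nn

class FlashAttention(nn.Module):
    def __init__(self, causal=False, dropout=0., flash=True):
        super().__init__()
        self.causal = causal
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)   # used by the non-flash einsum path
        self.flash = flash

        # flash attention relies on torch.nn.functional.scaled_dot_product_attention,
        # which is only available from PyTorch 2.0 onwards
        assert not (flash and not hasattr(nn.functional, 'scaled_dot_product_attention')), \
            'flash attention requires PyTorch 2.0 or above'
```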
Parameters:
- causal (bool, optional): Determines whether to apply causal masking. Default: False.
- dropout (float, optional): Dropout probability. Default: 0.
- flash (bool, optional): Whether to use flash attention. Requires PyTorch version 2.0 or above. Default: True.
Methods:
__init__(self, causal=False, dropout=0., flash=True)

Initializes the FlashAttention module.
get_mask(self, i, j, device)

Generates a mask for attention computation.
Parameters:
- i (int): Length of the query sequence.
- j (int): Length of the key sequence.
- device (torch.device): Device on which to place the mask tensor.

Returns:
torch.Tensor: Mask tensor of shape (i, j).
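One plausible implementation, matching the shape contract above and assuming the convention that True marks positions to be masked out (keys lying in the future of a query), is:

```python
import torch

def get_mask(self, i, j, device):
    # upper-triangular boolean mask: True where the key index lies "in the future"
    # relative to the query index, aligning the last query with the last key
    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
```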
flash_attn(self, q, k, v, mask=None, attn_bias=None)

Performs flash attention computation.
Parameters:
- q (torch.Tensor): Query tensor of shape (batch, heads, q_len, dim).
- k (torch.Tensor): Key tensor of shape (batch, heads, k_len, dim).
- v (torch.Tensor): Value tensor of shape (batch, heads, v_len, dim).
- mask (torch.Tensor, optional): Mask tensor of shape (batch, heads, q_len, k_len). Default: None.
- attn_bias (torch.Tensor, optional): Attention bias tensor of shape (batch, heads, q_len, k_len). Default: None.

Returns:
torch.Tensor: Output tensor of shape (batch, heads, q_len, dim).
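In PyTorch 2.0+, the fused kernels are exposed through F.scaled_dot_product_attention, so a simplified sketch of flash_attn could look like the following. The mask convention (boolean, True = allowed to attend), the assumption that attn_bias shares q's dtype, and the way everything is folded into a single additive mask are assumptions, not the actual implementation:

```python
import torch
import torch.nn.functional as F

def flash_attn(self, q, k, v, mask=None, attn_bias=None):
    q_len, k_len = q.shape[-2], k.shape[-2]
    mask_value = -torch.finfo(q.dtype).max

    # scaled_dot_product_attention takes a single attn_mask argument, so the
    # optional boolean mask, additive bias and causal constraint are folded
    # into one additive float mask
    attn_mask = None

    if attn_bias is not None:
        attn_mask = attn_bias

    if mask is not None:
        # `mask` is assumed boolean, True at positions allowed to attend
        base = attn_mask if attn_mask is not None else \
            torch.zeros((q_len, k_len), dtype=q.dtype, device=q.device)
        attn_mask = torch.where(mask, base, mask_value)

    if self.causal and attn_mask is not None:
        causal_mask = self.get_mask(q_len, k_len, q.device)  # True at future positions
        attn_mask = torch.where(causal_mask, mask_value, attn_mask)

    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask = attn_mask,
        dropout_p = self.dropout if self.training else 0.,
        is_causal = self.causal and attn_mask is None
    )
```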
forward(self, q, k, v, mask=None, attn_bias=None)

Performs the attention computation using Einstein summation (einsum) notation.
Parameters:
- q (torch.Tensor): Query tensor of shape (batch, heads, q_len, dim).
- k (torch.Tensor): Key tensor of shape (batch, heads, k_len, dim).
- v (torch.Tensor): Value tensor of shape (batch, heads, v_len, dim).
- mask (torch.Tensor, optional): Mask tensor of shape (batch, heads, q_len, k_len). Default: None.
- attn_bias (torch.Tensor, optional): Attention bias tensor of shape (batch, heads, q_len, k_len). Default: None.

Returns:
torch.Tensor: Attention output tensor.
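A sketch of a forward pass that dispatches to flash_attn when flash attention is enabled and otherwise runs an einsum-based computation. The dim ** -0.5 scaling and the attn_dropout attribute are assumptions consistent with the sketches above:

```python
import torch
from torch import einsum

def forward(self, q, k, v, mask=None, attn_bias=None):
    # use the fused kernel when flash attention is enabled
    if self.flash:
        return self.flash_attn(q, k, v, mask=mask, attn_bias=attn_bias)

    scale = q.shape[-1] ** -0.5
    mask_value = -torch.finfo(q.dtype).max

    # similarity: (batch, heads, q_len, dim) x (batch, heads, k_len, dim) -> (batch, heads, q_len, k_len)
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if attn_bias is not None:
        sim = sim + attn_bias

    if mask is not None:
        # `mask` is assumed boolean, True at positions allowed to attend
        sim = sim.masked_fill(~mask, mask_value)

    if self.causal:
        causal_mask = self.get_mask(*sim.shape[-2:], sim.device)
        sim = sim.masked_fill(causal_mask, mask_value)

    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)

    # aggregate values: (batch, heads, q_len, k_len) x (batch, heads, k_len, dim) -> (batch, heads, q_len, dim)
    return einsum('b h i j, b h j d -> b h i d', attn, v)
```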
Usage Examples:

Basic Usage:
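A minimal, hypothetical example (the import path for FlashAttention depends on the package layout and is omitted; tensor sizes are arbitrary):

```python
import torch

attn = FlashAttention(causal=False, dropout=0., flash=True)

q = torch.randn(1, 8, 1024, 64)   # (batch, heads, q_len, dim)
k = torch.randn(1, 8, 1024, 64)   # (batch, heads, k_len, dim)
v = torch.randn(1, 8, 1024, 64)   # (batch, heads, v_len, dim)

out = attn(q, k, v)
print(out.shape)   # torch.Size([1, 8, 1024, 64])
```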
Using Flash Attention with Masking:
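A sketch assuming the boolean mask convention described above (True at positions allowed to attend):

```python
import torch

attn = FlashAttention(flash=True)

q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

# boolean mask of shape (batch, heads, q_len, k_len)
mask = torch.ones(1, 8, 512, 512, dtype=torch.bool)
mask[..., 256:] = False   # e.g. hide the second half of the keys

out = attn(q, k, v, mask=mask)
```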
Using Causal Flash Attention with Dropout:
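A sketch combining causal masking with attention dropout; dropout is assumed to be active only in training mode:

```python
import torch

attn = FlashAttention(causal=True, dropout=0.1, flash=True)
attn.train()   # enable dropout

q = torch.randn(2, 8, 256, 64)
k = torch.randn(2, 8, 256, 64)
v = torch.randn(2, 8, 256, 64)

out = attn(q, k, v)   # (2, 8, 256, 64), causally masked
```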
Additional Tips:
- The FlashAttention module is optimized for NVIDIA A100 GPUs. On these GPUs, using flash=True is recommended for faster computation.
- Ensure that your PyTorch version is 2.0 or above when enabling flash attention.
- The mask generated by the get_mask method is useful for attention computations where certain positions need to be masked out.
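As a quick sanity check before enabling flash=True, you can verify that the PyTorch 2.0+ attention kernel is available:

```python
import torch

# scaled_dot_product_attention is the fused kernel that flash attention relies on;
# it shipped with PyTorch 2.0
assert hasattr(torch.nn.functional, 'scaled_dot_product_attention'), \
    'flash attention requires PyTorch 2.0 or above'
print(torch.__version__)
```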
References:
- Original Attention Mechanism: Attention Is All You Need (Vaswani et al., 2017)