LocalMHA: Local Multi-Head Attention for PyTorch¶
Overview¶
The LocalMHA module is a local multi-head attention mechanism that processes sequences in small, fixed-size windows, allowing it to handle long sequences more efficiently. It is especially useful when global attention mechanisms become computationally expensive, and it combines local attention with multi-head attention to capture information from different representation subspaces.
Key Concepts:
- Local Attention: Instead of attending to all positions in the input sequence, local attention restricts each position to a small, fixed-size window around it (see the sketch after this list).
- Multi-Head Attention: The input is split into multiple heads, allowing the network to attend to information from different representation subspaces simultaneously.
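To make the windowing idea concrete, here is a minimal, self-contained sketch of block-local attention in plain PyTorch. It is illustrative only and is not the LocalMHA implementation (which additionally handles multiple heads, masking, and the options documented below); the function name and tensor sizes are made up for this example.
import torch

def naive_windowed_attention(q, k, v, window_size):
    # q, k, v: [batch, seq_len, dim]; seq_len is assumed divisible by window_size
    b, n, d = q.shape
    w = window_size
    # Split the sequence into non-overlapping windows: [batch, num_windows, window_size, dim]
    q, k, v = (t.reshape(b, n // w, w, d) for t in (q, k, v))
    # Scaled dot-product attention computed independently inside each window
    scores = q @ k.transpose(-2, -1) / d**0.5      # [batch, num_windows, w, w]
    attn = scores.softmax(dim=-1)
    out = attn @ v                                 # [batch, num_windows, w, dim]
    return out.reshape(b, n, d)

q = k = v = torch.randn(2, 20, 64)
out = naive_windowed_attention(q, k, v, window_size=5)
print(out.shape)  # torch.Size([2, 20, 64])
Because each position only computes scores against the other positions in its window, the cost scales with sequence_length × window_size rather than sequence_length², which is what makes local attention practical for long sequences.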
Class Definition¶
Parameters:¶
- dim (int): Dimensionality of the input sequence.
- window_size (int): The size of the local attention window. The module attends to this fixed-size window around each position.
- dim_head (int, optional): Dimensionality of each attention head. Default is 64.
- heads (int, optional): Number of attention heads. Default is 8.
- dropout (float, optional): Dropout probability applied after the attention mechanism. Default is 0.0.
- causal (bool, optional): If set to True, the attention mechanism is causal, ensuring that each position only attends to previous positions. Default is False.
- prenorm (bool, optional): If set to True, layer normalization is applied before the multi-head attention mechanism. Default is False.
- qk_rmsnorm (bool, optional): If set to True, root mean square (RMS) normalization is applied to the query and key tensors. Default is False.
- qk_scale (int, optional): Scaling factor for queries and keys when qk_rmsnorm is set to True. Default is 8.
- use_xpos (bool, optional): If set to True, the attention mechanism uses relative positional embeddings. Default is False.
- xpos_scale_base (float, optional): Base scaling factor for the relative positional embeddings. If None, it defaults to the square root of the model dimension. Only used when use_xpos is True.
- exact_windowsize (bool, optional): If set to True, the attention window size is strictly adhered to, without any additional padding. Default is True.
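As a quick sketch of how these options combine, the call below instantiates the module using only the parameters documented above; the specific values are illustrative, not recommendations.
from zeta import LocalMHA

# A causal, pre-normalized configuration with relative positional embeddings
local_mha = LocalMHA(
    dim=512,
    window_size=64,
    dim_head=64,
    heads=8,
    dropout=0.1,
    causal=True,
    prenorm=True,
    use_xpos=True,
)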
Method: forward¶
This method performs the forward pass of the LocalMHA module.
Parameters:¶
- x (torch.Tensor): The input tensor with shape [batch_size, sequence_length, dim].
- mask (torch.Tensor, optional): A boolean mask tensor with shape [batch_size, sequence_length]. Positions with True values are masked and will not be attended to.
- attn_bias (torch.Tensor, optional): Additional bias added to the attention scores before the softmax.
Returns:¶
torch.Tensor: The output tensor after local multi-head attention, with shape [batch_size, sequence_length, dim].
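A short sketch of a forward call that also passes the optional mask, following the shapes and masking semantics documented above (True marks positions to be masked out); the tensor sizes are illustrative.
import torch

from zeta import LocalMHA

local_mha = LocalMHA(dim=512, window_size=5)

x = torch.randn(2, 100, 512)                  # [batch_size, sequence_length, dim]
mask = torch.zeros(2, 100, dtype=torch.bool)  # per the docs above, True positions are masked out
mask[:, 90:] = True                           # e.g. mask the last 10 (padding) positions

output = local_mha(x, mask=mask)              # [2, 100, 512]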
Example Usage¶
import torch

from zeta import LocalMHA

# Sample data: a random input tensor with shape [batch_size, sequence_length, dim]
x = torch.randn(2, 100, 512)

# Initialize the LocalMHA module
local_mha = LocalMHA(dim=512, window_size=5)

# Forward pass
output = local_mha(x)  # shape [2, 100, 512]
Mathematical Formula¶
For a given input \( x \):
- Linearly project \( x \) into queries \( Q \), keys \( K \), and values \( V \).
- If qk_rmsnorm is True, apply RMS normalization to \( Q \) and \( K \).
- For each position \( i \) in \( x \), compute attention scores with all positions in the window around \( i \).
- Apply softmax to the scores, then compute the attention output as a weighted sum of \( V \) based on these scores.
- Finally, concatenate all head outputs and linearly project to get the final output.
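Written out for a single head, the window-restricted step above is standard scaled dot-product attention limited to the window \( W(i) \) around position \( i \). As a sketch, with \( d \) the head dimension, \( B_{ij} \) the optional attn_bias term, and \( W(i) \) containing only positions \( j \le i \) when causal is True:
\[
\mathrm{Attention}(Q, K, V)_i = \sum_{j \in W(i)} \alpha_{ij} V_j,
\qquad
\alpha_{ij} = \operatorname{softmax}_{j \in W(i)}\!\left( \frac{Q_i K_j^{\top}}{\sqrt{d}} + B_{ij} \right)
\]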
Additional Information¶
The LocalMHA module provides a balance between computational efficiency and the ability to capture long-range dependencies. While it restricts attention to local windows, the use of multi-head attention allows it to attend to different features within that window. The optional use of RMS normalization and relative positional embeddings further extends its capabilities.
References¶
For a deeper understanding of multi-head attention, see the original Transformer paper, "Attention Is All You Need" (Vaswani et al., 2017). For details on local attention, refer to the literature on efficient transformers and localized attention mechanisms.