SparseAttention
===============
The `SparseAttention` class is a PyTorch module that implements a sparse attention mechanism. It is part of a larger effort to make transformer models more efficient by reducing the computational complexity of the self-attention mechanism.
Overview
--------
In a standard transformer model, the self-attention mechanism computes attention scores for all pairs of tokens in a sequence, resulting in a quadratic computational complexity. This can be problematic for long sequences, as the time and memory requirements can become prohibitively large.
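To make the quadratic cost concrete, here is a minimal sketch (plain PyTorch, not part of this class) that materializes the full score matrix for a single head:

```python
import torch

n_ctx, d_head = 1024, 64
q = torch.randn(n_ctx, d_head)
k = torch.randn(n_ctx, d_head)

# Dense self-attention computes one score per token pair,
# so memory and compute scale as O(n_ctx ** 2).
scores = q @ k.transpose(-1, -2)
print(scores.shape)  # torch.Size([1024, 1024])
```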
The `SparseAttention` class addresses this issue by computing attention scores for only a subset of token pairs, rather than all pairs. This results in a sparse attention matrix, which can be computed more efficiently than a full attention matrix.
The class supports three modes of sparse attention (a mask-construction sketch follows the list):

- `'all'`: All tokens attend to all other tokens (equivalent to standard self-attention).
- `'local'`: Each token attends to a fixed window of adjacent tokens.
- `'strided'`: Each token attends to one out of every `blocksize` tokens.
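To illustrate the three patterns, here is a hypothetical mask-building helper. It is not part of zeta, and the exact pattern computed by `blocksparse_attention_impl` may differ (real block-sparse kernels are typically also causal); this non-causal sketch just makes the three modes visible. `mask[i, j] == True` means token `i` may attend to token `j`:

```python
import torch

def sparse_attention_mask(n_ctx, attn_mode, local_attn_ctx=None, blocksize=32):
    # Hypothetical helper for illustration only.
    i = torch.arange(n_ctx).unsqueeze(1)  # token positions (column vector)
    j = torch.arange(n_ctx).unsqueeze(0)  # token positions (row vector)
    if attn_mode == "all":
        # every token attends to every other token
        return torch.ones(n_ctx, n_ctx, dtype=torch.bool)
    if attn_mode == "local":
        # a fixed window of adjacent tokens
        return (i - j).abs() < local_attn_ctx
    if attn_mode == "strided":
        # one out of every `blocksize` tokens
        return (i - j) % blocksize == 0
    raise ValueError(f"Unknown attn_mode: {attn_mode}")
```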
Class Definition
----------------
```python
import torch.nn as nn

# blocksparse_attention_impl must be supplied separately; see the Note below.
class SparseAttention(nn.Module):
    def __init__(self, heads, attn_mode, local_attn_ctx=None, blocksize=32):
        super(SparseAttention, self).__init__()
        self.heads = heads
        self.attn_mode = attn_mode
        self.local_attn_ctx = local_attn_ctx
        self.blocksize = blocksize

    def forward(self, q, k, v):
        return blocksparse_attention_impl(
            q, k, v, self.heads, self.attn_mode, self.local_attn_ctx
        )
```
Parameters
----------
| Parameter | Type | Description |
|---|---|---|
| `heads` | int | The number of attention heads. |
| `attn_mode` | str | The mode of sparse attention. Can be `'all'`, `'local'`, or `'strided'`. |
| `local_attn_ctx` | int, optional | The context size for local attention. Only used when `attn_mode` is `'local'`. Default is `None`. |
| `blocksize` | int, optional | The block size for strided attention. Only used when `attn_mode` is `'strided'`. Default is `32`. |
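For reference, the constructor can be called like this for each mode (the argument values here are illustrative, not prescribed defaults):

```python
attn_all = SparseAttention(heads=8, attn_mode="all")
attn_local = SparseAttention(heads=8, attn_mode="local", local_attn_ctx=32)
attn_strided = SparseAttention(heads=8, attn_mode="strided", blocksize=32)
```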
Usage
-----
Here is an example of how to use the `SparseAttention` class:
```python
import torch

from zeta.nn.attention import SparseAttention

# Define parameters
n_batch = 4
n_ctx = 1024
n_embd = 256
heads = 4
attn_mode = "all"
local_attn_ctx = 32
blocksize = 32

# Create input tensors of shape (batch, context length, embedding dim)
q = torch.randn(n_batch, n_ctx, n_embd)
k = torch.randn(n_batch, n_ctx, n_embd)
v = torch.randn(n_batch, n_ctx, n_embd)

# Create SparseAttention module
model = SparseAttention(heads, attn_mode, local_attn_ctx, blocksize)

# Forward pass
output = model(q, k, v)

# Print the first batch element of the output
print(output[0])
```
In this example, the `SparseAttention` model is created with 4 attention heads and the `'all'` attention mode. The input tensors `q`, `k`, and `v` are randomly initialized, the forward pass of the model is performed, and the output is printed.
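Switching modes only changes the constructor arguments. For instance, a `'local'` variant of the same example (reusing the tensors above) might look like this:

```python
# Local attention: each token attends to a window of 32 adjacent tokens
local_model = SparseAttention(heads=4, attn_mode="local", local_attn_ctx=32)
local_output = local_model(q, k, v)
```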
Note
----
The `SparseAttention` class relies on the `blocksparse_attention_impl` function for the actual computation of the sparse attention. This function is not defined in the provided code, so you will need to implement it yourself or import it from elsewhere. The function should take the input tensors `q`, `k`, and `v`, as well as the parameters `heads`, `attn_mode`, and `local_attn_ctx`, and return the output of the sparse attention computation.
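If you just need something to run against, below is a hypothetical dense reference sketch of `blocksparse_attention_impl`. It is written for clarity, not speed: it materializes the full score matrix and masks it with the same non-causal patterns as the sketch in the Overview, so it matches the interface described above but not the efficiency of a true block-sparse kernel. Note that `SparseAttention.forward` does not pass `blocksize`, so this sketch always falls back to its default of 32 in `'strided'` mode:

```python
import torch
import torch.nn.functional as F

def blocksparse_attention_impl(q, k, v, heads, attn_mode,
                               local_attn_ctx=None, blocksize=32):
    # Hypothetical dense reference implementation, for illustration only.
    n_batch, n_ctx, n_embd = q.shape
    d_head = n_embd // heads

    def split_heads(x):
        # (batch, ctx, embd) -> (batch, heads, ctx, d_head)
        return x.view(n_batch, n_ctx, heads, d_head).transpose(1, 2)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    scores = qh @ kh.transpose(-1, -2) / d_head**0.5

    i = torch.arange(n_ctx, device=q.device).unsqueeze(1)
    j = torch.arange(n_ctx, device=q.device).unsqueeze(0)
    if attn_mode == "all":
        mask = torch.ones(n_ctx, n_ctx, dtype=torch.bool, device=q.device)
    elif attn_mode == "local":
        mask = (i - j).abs() < local_attn_ctx
    elif attn_mode == "strided":
        mask = (i - j) % blocksize == 0
    else:
        raise ValueError(f"Unknown attn_mode: {attn_mode}")

    # Disallowed pairs get -inf so they vanish under softmax
    scores = scores.masked_fill(~mask, float("-inf"))
    out = F.softmax(scores, dim=-1) @ vh
    # (batch, heads, ctx, d_head) -> (batch, ctx, embd)
    return out.transpose(1, 2).reshape(n_batch, n_ctx, n_embd)
```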