
SparseAttention
===============

The SparseAttention class is a PyTorch module that implements a sparse attention mechanism. This class is part of a larger effort to make transformer models more efficient by reducing the computational complexity of the self-attention mechanism.

Overview
--------

In a standard transformer model, the self-attention mechanism computes attention scores for all pairs of tokens in a sequence, so its time and memory requirements grow quadratically with sequence length and can become prohibitively large for long sequences.

The SparseAttention class addresses this issue by computing attention scores for a subset of token pairs, rather than all pairs. This results in a sparse attention matrix, which can be computed more efficiently than a full attention matrix.

The class supports three modes of sparse attention (a short mask sketch illustrating each pattern follows the list):

  • 'all': All tokens attend to all other tokens (equivalent to standard self-attention).
  • 'local': Each token attends to a fixed window of adjacent tokens.
  • 'strided': Each token attends to one out of every blocksize tokens.
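To make the difference between these patterns concrete, the sketch below builds the boolean attention mask that each mode allows. This is purely illustrative: the function name sparse_attention_mask is hypothetical, and the masks used internally by the block-sparse kernel may differ (for example, it may additionally apply causal masking).

import torch

def sparse_attention_mask(n_ctx, attn_mode, local_attn_ctx=None, blocksize=32):
    # Illustrative sketch: returns a boolean (n_ctx, n_ctx) mask where
    # mask[i, j] is True if query position i may attend to key position j.
    i = torch.arange(n_ctx)[:, None]  # query positions
    j = torch.arange(n_ctx)[None, :]  # key positions
    if attn_mode == "all":
        # every token attends to every other token
        return torch.ones(n_ctx, n_ctx, dtype=torch.bool)
    if attn_mode == "local":
        # a window of adjacent tokens around each position
        return (i - j).abs() < local_attn_ctx
    if attn_mode == "strided":
        # one out of every `blocksize` tokens
        return (i - j) % blocksize == 0
    raise ValueError(f"unknown attn_mode: {attn_mode}")

# Example: strided pattern over 8 positions with blocksize 4
print(sparse_attention_mask(8, "strided", blocksize=4).int())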

Class Definition
----------------

from torch import nn

class SparseAttention(nn.Module):
    def __init__(self, heads, attn_mode, local_attn_ctx=None, blocksize=32):
        super(SparseAttention, self).__init__()
        self.heads = heads  # number of attention heads
        self.attn_mode = attn_mode  # 'all', 'local', or 'strided'
        self.local_attn_ctx = local_attn_ctx  # window size for 'local' mode
        self.blocksize = blocksize  # block size for 'strided' mode

    def forward(self, q, k, v):
        # Delegate the actual sparse attention computation to
        # blocksparse_attention_impl (see the Note section below).
        return blocksparse_attention_impl(
            q, k, v, self.heads, self.attn_mode, self.local_attn_ctx
        )

Parameters
----------

Parameter       Type           Description
heads           int            The number of attention heads.
attn_mode       str            The mode of sparse attention: 'all', 'local', or 'strided'.
local_attn_ctx  int, optional  The context size for local attention. Only used when attn_mode is 'local'. Default is None.
blocksize       int, optional  The block size for strided attention. Only used when attn_mode is 'strided'. Default is 32.
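
As a quick, hedged illustration of how these parameters combine, the snippet below instantiates the module in each mode; the specific values are arbitrary examples, not recommended settings.

from zeta.nn.attention import SparseAttention

# Full attention (equivalent to standard self-attention)
full_attn = SparseAttention(heads=4, attn_mode="all")

# Local attention over a window of 32 adjacent tokens
local_attn = SparseAttention(heads=4, attn_mode="local", local_attn_ctx=32)

# Strided attention with a block size of 64
strided_attn = SparseAttention(heads=4, attn_mode="strided", blocksize=64)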

Usage
-----

Here is an example of how to use the SparseAttention class:

import torch

from zeta.nn.attention import SparseAttention

# Define parameters
n_batch = 4
n_ctx = 1024
n_embd = 256
heads = 4
attn_mode = "all"
local_attn_ctx = 32
blocksize = 32

# Create input tensors
q = torch.randn(n_batch, n_ctx, n_embd)
k = torch.randn(n_batch, n_ctx, n_embd)
v = torch.randn(n_batch, n_ctx, n_embd)

# Create SparseAttention model
model = SparseAttention(heads, attn_mode, local_attn_ctx, blocksize)

# Forward pass
output = model(q, k, v)

# Print output
print(output[0])

In this example, the SparseAttention model is created with 4 attention heads and the 'all' attention mode. The input tensors q, k, and v are randomly initialized. The forward pass of the model is then performed, and the output is printed.

Note
----

The SparseAttention class relies on the blocksparse_attention_impl function for the actual computation of the sparse attention. This function is not defined in the code shown here, so you will need to implement it yourself or import it from elsewhere. The function should take the input tensors q, k, and v, as well as the parameters heads, attn_mode, and local_attn_ctx, and return the output of the sparse attention computation.
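
For reference, here is a minimal dense sketch of what such a function could look like. It is an assumption rather than the original implementation: it materializes the full (n_ctx, n_ctx) score matrix and masks it, so it does not provide the efficiency benefits of a true block-sparse kernel, and its masking choices mirror the illustrative patterns shown earlier.

import torch
import torch.nn.functional as F

def blocksparse_attention_impl(q, k, v, heads, attn_mode, local_attn_ctx=None, blocksize=32):
    # Dense fallback sketch (assumption): a real block-sparse kernel would
    # never materialize the full (n_ctx, n_ctx) score matrix.
    n_batch, n_ctx, n_embd = q.shape
    head_dim = n_embd // heads

    def split_heads(x):
        # (n_batch, n_ctx, n_embd) -> (n_batch, heads, n_ctx, head_dim)
        return x.view(n_batch, n_ctx, heads, head_dim).transpose(1, 2)

    q, k, v = map(split_heads, (q, k, v))
    scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim ** 0.5

    # Build the sparsity mask for the requested mode.
    i = torch.arange(n_ctx, device=scores.device)[:, None]
    j = torch.arange(n_ctx, device=scores.device)[None, :]
    if attn_mode == "all":
        mask = torch.ones(n_ctx, n_ctx, dtype=torch.bool, device=scores.device)
    elif attn_mode == "local":
        mask = (i - j).abs() < local_attn_ctx
    elif attn_mode == "strided":
        mask = (i - j) % blocksize == 0
    else:
        raise ValueError(f"unknown attn_mode: {attn_mode}")

    scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    out = torch.matmul(weights, v)

    # Merge the heads back into the embedding dimension.
    return out.transpose(1, 2).reshape(n_batch, n_ctx, n_embd)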