Skip to content

Module/Function Name: UMambaBlock

UMambaBlock is a 5d Mamba block designed to serve as a building block for 5d visual models. In accordance with the article published on https://arxiv.org/pdf/2401.04722.pdf, this module enables transformation across 5D space-time data for efficient information processing.

The module's core concepts pertain to the input dimension (dim), the depth of the Mamba block, the state dimension (d_state), the expansion factor (expand), the rank of the temporal difference (dt_rank), the dimension of the convolutional kernel (d_conv), and the inclusion of bias in linear and convolutional layers.

Class Definition:

class UMambaBlock(nn.Module):
    """
    UMambaBlock is a 5d Mamba block that can be used as a building block for a 5d visual model
    From the paper: https://arxiv.org/pdf/2401.04722.pdf

    Args:
        dim (int): The input dimension.
        dim_inner (Optional[int]): The inner dimension. If not provided, it is set to dim * expand.
        depth (int): The depth of the Mamba block.
        d_state (int): The state dimension. Default is 16.
        expand (int): The expansion factor. Default is 2.
        dt_rank (Union[int, str]): The rank of the temporal difference (Δ) tensor. Default is "auto".
        d_conv (int): The dimension of the convolutional kernel. Default is 4.
        conv_bias (bool): Whether to include bias in the convolutional layer. Default is True.
        bias (bool): Whether to include bias in the linear layers. Default is False.
    """

    def __init__(
        self,
        dim: int = None,
        depth: int = 5,
        d_state: int = 16,
        expand: int = 2,
        d_conv: int = 4,
        conv_bias: bool = True,
        bias: bool = False,
    ):
        # Class initialization and setup
        ...

    def forward(self, x: Tensor):
        """
        B, C, H, W, D
        """
        # Forward pass implementation
        ...

Detailed Explanation:

The UMambaBlock class serves as a thorough representation of a 5d Mamba block. It encapsulates the input dimension, depth, state dimension, expansion factor, and other key parameters. By instantiating this block, users can process 5D visual data, further taking advantage of hyperparameters to customize the block for specific application requirements.

Usage Examples:

Example 1:

import torch

from zeta.nn import UMambaBlock

# img:         B, C, H, W, D
img_tensor = torch.randn(1, 64, 10, 10, 10)

# Initialize Mamba block
block = UMambaBlock(dim=64, depth=1)

# Forward pass
y = block(img_tensor)
print(y.shape)

Example 2:

import torch

from zeta.nn import UMambaBlock

# img:         B, C, H, W, D
img_tensor = torch.randn(1, 64, 10, 10, 10)

# Initialize Mamba block with custom parameters
block = UMambaBlock(dim=64, depth=3, expand=3)

# Forward pass
y = block(img_tensor)
print(y.shape)

Example 3:

import torch

from zeta.nn import UMambaBlock

# img:         B, C, H, W, D
img_tensor = torch.randn(1, 64, 5, 5, 20)

# Initialize Mamba block with altered state dimension and convolutional kernel size
block = UMambaBlock(dim=64, d_state=32, d_conv=6)

# Forward pass
y = block(img_tensor)
print(y.shape)

Additional Information and Tips:

The user may benefit from customizing various hyperparameters such as the input dimension, depth, and state dimension to tailor the UMambaBlock for specific use cases. Common useful tips include managing the Mamba block's rank parameter and identifying key transformations to optimize for handling high-dimensional spatiotemporal data.

References and Resources:

By following this well-structured and detailed documentation, developers and research practitioners can readily understand and adopt the UMambaBlock module for 5D image and video data processing.