fsdp Documentation

Table of Contents

  1. Introduction
  2. Function: fsdp
      • Initialization
      • Parameters
      • Mixed Precision Modes
      • Sharding Strategies
  3. Usage Examples
      • Basic FSDP Wrapper
      • Automatic Layer Wrapping
      • Advanced Configuration
  4. Additional Information
      • FullyShardedDataParallel (FSDP)
      • Mixed Precision Training
      • Model Sharding
  5. References

1. Introduction

Welcome to the documentation for the Zeta library's fsdp function! fsdp is a utility that wraps a given PyTorch model with PyTorch's FullyShardedDataParallel (FSDP) wrapper. FSDP combines data parallelism with sharding of model state across GPUs, which reduces per-GPU memory during distributed training.

Key Features

  • Data Parallelism: Train across multiple GPUs, with each rank processing its own slice of every batch.
  • Mixed Precision Training: Choose between BFloat16 (bf16), Float16 (fp16), or Float32 (fp32) precision modes.
  • Model Sharding: Shard gradients and optimizer state (SHARD_GRAD), shard parameters as well (FULL_SHARD), or keep a full replica on every GPU (NO_SHARD).

In this documentation, you will learn how to call the fsdp function, what each of its options does, and how to apply it in common scenarios.


2. Function: fsdp

The fsdp function wraps a PyTorch model with FSDP in a single call, exposing the most common choices (automatic layer wrapping, mixed precision, and sharding strategy) as simple arguments.

Initialization

# All keyword arguments below are shown with their default values
model = fsdp(
    model, auto_wrap=False, mp="fp32", shard_strat="NO_SHARD", TransformerBlock=None
)

Parameters

  • model (torch.nn.Module): The original PyTorch model to be wrapped with FSDP.
  • auto_wrap (bool, optional): If True, wraps each instance of TransformerBlock in its own nested FSDP unit using PyTorch's transformer_auto_wrap_policy. If False, the whole model is wrapped as a single FSDP unit. Default is False.
  • mp (str, optional): The mixed precision mode: 'bf16' for BFloat16, 'fp16' for Float16, or 'fp32' for Float32. Default is 'fp32'.
  • shard_strat (str, optional): The sharding strategy: 'SHARD_GRAD' (shard gradients and optimizer state), 'FULL_SHARD' (shard parameters, gradients, and optimizer state), or 'NO_SHARD' (no sharding). Default is 'NO_SHARD'.
  • TransformerBlock (Type[torch.nn.Module], optional): The layer class whose instances are placed in separate FSDP units, typically your transformer block class. Pass the class itself, not an instance. Only used when auto_wrap is True, and should be set in that case. Default is None.
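The accepted values listed above can be checked before any distributed setup runs. The sketch below is a hypothetical helper, check_fsdp_args, that is not part of Zeta; it only mirrors the documented options and returns a list of problems instead of raising, so it runs anywhere without PyTorch.

```python
from typing import List, Optional

# Accepted values, taken from the parameter descriptions in this section.
VALID_MP = {"bf16", "fp16", "fp32"}
VALID_SHARD_STRAT = {"SHARD_GRAD", "FULL_SHARD", "NO_SHARD"}


def check_fsdp_args(
    auto_wrap: bool = False,
    mp: str = "fp32",
    shard_strat: str = "NO_SHARD",
    TransformerBlock: Optional[type] = None,
) -> List[str]:
    """Hypothetical pre-flight check (not part of Zeta).

    Returns a list of human-readable problems; an empty list means the
    arguments match the documented options.
    """
    problems = []
    if mp not in VALID_MP:
        problems.append(f"mp must be one of {sorted(VALID_MP)}, got {mp!r}")
    if shard_strat not in VALID_SHARD_STRAT:
        problems.append(
            f"shard_strat must be one of {sorted(VALID_SHARD_STRAT)}, got {shard_strat!r}"
        )
    if auto_wrap and TransformerBlock is None:
        problems.append("auto_wrap=True needs a TransformerBlock class to wrap")
    return problems
```

For example, check_fsdp_args(mp="fp8") returns a single message about mp. Whether fsdp itself raises on invalid values is not stated here, so treat this as an optional guard rather than a description of Zeta's behavior.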

Mixed Precision Modes

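  • bf16 (BFloat16): Keeps fp32's 8-bit exponent, so it covers the same numeric range, but has only 7 explicit mantissa bits. It reduces memory use and speeds up training on GPUs with native bf16 support (for example, NVIDIA Ampere or newer), usually with minimal loss in accuracy.
  • fp16 (Float16): Has more mantissa precision than bf16 (10 bits) but a much narrower range (the largest finite value is 65504), so large values can overflow and small gradients can underflow to zero. fp16 training therefore usually relies on loss scaling.
  • fp32 (Float32): Standard 32-bit single precision. The most numerically robust option, but without the memory and speed savings of the 16-bit formats.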
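To make the range-versus-precision trade-off between bf16 and fp16 concrete, the pure-Python sketch below emulates both formats with the standard struct module, so no PyTorch or GPU is needed. fp16 uses struct's native half-precision 'e' format. bf16 is emulated by keeping the top 16 bits of a float32 encoding; this truncation is an assumption made for simplicity, since real hardware normally rounds to nearest even.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_in_fp16(x: float) -> bool:
    """True if x is within fp16's finite range; struct refuses to pack it otherwise."""
    try:
        to_fp16(x)
        return True
    except (OverflowError, struct.error):
        return False


def to_bf16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of x's float32 encoding.

    This truncates (rounds toward zero); hardware usually rounds to nearest even.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


small_step = 1 + 2**-10  # needs 10 mantissa bits to represent
print(to_fp16(small_step))  # 1.0009765625: fp16 keeps it exactly
print(to_bf16(small_step))  # 1.0: bf16's 7 mantissa bits drop it
print(fits_in_fp16(70000.0), to_bf16(70000.0))  # False 69632.0
```

The last line shows the flip side: 70000 is beyond fp16's range, while bf16 still represents it, only coarsely.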

Sharding Strategies

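  • SHARD_GRAD (Sharding at Gradient Computation, PyTorch's SHARD_GRAD_OP): Shards gradients and optimizer states across ranks. Parameters are sharded outside computation, but once gathered for the forward pass they stay unsharded until the backward pass finishes. This saves memory relative to NO_SHARD with less communication than FULL_SHARD.
  • FULL_SHARD (Full Model Sharding): Shards parameters, gradients, and optimizer states. Each wrapped unit's parameters are all-gathered just before use and freed right after, giving the lowest memory footprint at the cost of extra communication.
  • NO_SHARD (No Sharding): Every GPU keeps a full replica of parameters, gradients, and optimizer state, and gradients are all-reduced, similar to PyTorch's DistributedDataParallel. Like the other strategies, it still requires an initialized process group, even on a single GPU.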
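To get a feel for the memory trade-off between these strategies, the sketch below estimates per-GPU memory for model state only. It assumes fp32 parameters and gradients plus an Adam-style optimizer holding two fp32 buffers per parameter (16 bytes per parameter in total), and it ignores activations, communication buffers, and the temporary full-parameter copies FSDP gathers during compute. state_bytes_per_gpu is a hypothetical helper, not part of Zeta.

```python
def state_bytes_per_gpu(
    num_params: int,
    world_size: int,
    shard_strat: str,
    param_bytes: int = 4,  # fp32 parameters
    grad_bytes: int = 4,  # fp32 gradients
    optim_bytes: int = 8,  # Adam-style: two fp32 moment buffers per parameter
) -> float:
    """Rough per-GPU bytes for parameters, gradients, and optimizer state."""
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    if shard_strat == "NO_SHARD":
        # Full replica of everything on every GPU.
        return params + grads + optim
    if shard_strat == "SHARD_GRAD":
        # Gradients and optimizer state are sharded; parameters stay unsharded
        # between forward and backward, so count them in full at peak.
        return params + (grads + optim) / world_size
    if shard_strat == "FULL_SHARD":
        # Parameters, gradients, and optimizer state are all sharded.
        return (params + grads + optim) / world_size
    raise ValueError(f"unknown shard_strat: {shard_strat!r}")


for strat in ("NO_SHARD", "SHARD_GRAD", "FULL_SHARD"):
    gb = state_bytes_per_gpu(1_000_000_000, 8, strat) / 1e9
    print(f"{strat:>10}: {gb:.1f} GB per GPU")
```

For a 1-billion-parameter model on 8 GPUs this prints 16.0, 5.5, and 2.0 GB per GPU for NO_SHARD, SHARD_GRAD, and FULL_SHARD respectively. Real usage will be higher once activations and gathered parameters are included.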

3. Usage Examples

Now, let's explore practical examples of using the fsdp function in various scenarios.

Basic FSDP Wrapper

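import os

import torch
import torch.distributed as dist
import torch.nn as nn

# FSDP needs an initialized process group, even for NO_SHARD. This assumes the
# script is launched with torchrun, which sets LOCAL_RANK and the rendezvous
# environment variables. fsdp itself is assumed to be imported from Zeta.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Define your PyTorch model and place it on this rank's GPU
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
).cuda()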

# Wrap the model with FSDP using default settings (no sharding, fp32 precision)
fsdp_model = fsdp(model)

Automatic Layer Wrapping

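import torch.nn as nn

# Assumes the process group is initialized and the CUDA device is set
# (e.g. launched with torchrun), and that fsdp is imported from Zeta.


# A minimal placeholder block: pre-norm feed-forward with a residual connection.
# Replace it with your real transformer layer (attention + MLP).
class TransformerBlock(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()  # required before registering submodules
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.ff(self.norm(x))


# Define your PyTorch model with transformer layers
model = nn.Sequential(
    nn.Linear(784, 256),
    TransformerBlock(dim=256),
    nn.Linear(256, 10),
).cuda()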

# Wrap the model with FSDP and enable automatic layer wrapping
fsdp_model = fsdp(model, auto_wrap=True, TransformerBlock=TransformerBlock)
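Conceptually, auto-wrapping decides which submodules become their own FSDP unit. The pure-Python sketch below models that decision on a hypothetical module tree with one TransformerBlock; Node and fsdp_units are illustrative stand-ins, not PyTorch or Zeta APIs, and this is a simplified model of what transformer_auto_wrap_policy does rather than its implementation.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module: its class name and children."""
    cls: str
    children: List["Node"] = field(default_factory=list)


def fsdp_units(root: Node, layer_cls: Set[str]) -> List[str]:
    """List the modules that end up as their own FSDP unit, in visit order.

    The root is always wrapped. Every descendant whose class is in layer_cls
    gets a nested unit; all other modules' parameters belong to the nearest
    enclosing unit.
    """
    units = [root.cls]

    def visit(node: Node) -> None:
        for child in node.children:
            if child.cls in layer_cls:
                units.append(child.cls)
            visit(child)

    visit(root)
    return units


model = Node("Sequential", [
    Node("Linear"),
    Node("TransformerBlock", [Node("LayerNorm"), Node("Linear")]),
    Node("Linear"),
])
print(fsdp_units(model, {"TransformerBlock"}))  # ['Sequential', 'TransformerBlock']
print(fsdp_units(model, set()))  # ['Sequential']
```

With one unit per block, FSDP can gather and free parameters block by block; with only the root unit (auto_wrap=False), the whole model's parameters are gathered together.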

Advanced Configuration

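import torch.nn as nn

# Assumes the process group is initialized and the CUDA device is set
# (e.g. launched with torchrun), and that fsdp is imported from Zeta.

# Define your PyTorch model
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
).cuda()

# Wrap the model with FSDP with custom settings (full model sharding, bf16 precision).
# bf16 needs hardware support to run efficiently (e.g. NVIDIA Ampere or newer GPUs).
fsdp_model = fsdp(model, mp="bf16", shard_strat="FULL_SHARD")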

These examples demonstrate how to use the fsdp function to wrap your PyTorch models with FSDP for distributed training with various configurations.


4. Additional Information

FullyShardedDataParallel (FSDP)

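FSDP (torch.distributed.fsdp.FullyShardedDataParallel) is PyTorch's wrapper for sharded data parallelism. With full sharding, each rank stores only a shard of every wrapped unit's parameters, gradients, and optimizer state. Just before a unit runs, its full parameters are all-gathered from all ranks and then freed after use; after the backward pass, gradients are reduce-scattered so that each rank keeps only the shard it owns. This trades extra communication for a much smaller per-GPU memory footprint than replicating the whole model on every GPU.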
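The shard, all-gather, and reduce-scatter cycle can be simulated in a single process, with Python lists standing in for ranks. This is a conceptual sketch only: real FSDP flattens each unit's parameters and runs these collectives over the network (for example with NCCL). Averaging in reduce_scatter is an assumption that mirrors the usual data-parallel convention.

```python
from typing import List

Vector = List[float]


def shard(flat: Vector, world_size: int) -> List[Vector]:
    """Split a flat parameter vector into equal shards, zero-padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[Vector], numel: int) -> Vector:
    """Rebuild the full (unpadded) vector from every rank's shard."""
    return [x for s in shards for x in s][:numel]


def reduce_scatter(per_rank_grads: List[Vector]) -> List[Vector]:
    """Average full gradients across ranks, then keep one shard per rank."""
    world_size = len(per_rank_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return shard(averaged, world_size)


params = [1.0, 2.0, 3.0, 4.0, 5.0]
shards = shard(params, world_size=2)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
full = all_gather(shards, numel=len(params))  # full vector used for compute
grads = [[1.0] * 5, [3.0] * 5]  # each rank's local gradient
grad_shards = reduce_scatter(grads)  # [[2.0, 2.0, 2.0], [2.0, 2.0, 0.0]]
```

Between steps, each rank holds only its shard of the parameters and gradients; the trailing 0.0 is padding that keeps shards equal in size.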

Mixed Precision Training

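Mixed precision training runs selected parts of the training pipeline, such as forward and backward computation and optionally gradient communication, in a lower-precision type like bf16 or fp16. This reduces memory use and can speed up training on hardware with fast low-precision arithmetic, usually with little loss in accuracy. A full-precision copy of the weights is typically kept for the optimizer update.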
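One reason to keep a full-precision copy of the weights is that small updates vanish when added directly to a 16-bit value. The pure-Python sketch below, using struct's half-precision format as a stand-in for fp16, applies 100 hypothetical updates of size 1e-4 to a weight of 1.0. PyTorch's FSDP mixed precision follows the same idea, keeping sharded parameters in full precision outside forward and backward; whether Zeta's fsdp changes that default is not stated here.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


update = 1e-4  # size of each (hypothetical) optimizer update
steps = 100

# Weight stored only in fp16: each update is smaller than half the spacing
# between fp16 values near 1.0 (2**-10), so it is rounded away every step.
w_fp16 = 1.0
for _ in range(steps):
    w_fp16 = to_fp16(w_fp16 + update)

# Full-precision master copy: updates accumulate. The low-precision copy used
# for compute is derived from it (shown once, at the end).
w_master = 1.0
for _ in range(steps):
    w_master += update
w_compute = to_fp16(w_master)

print(w_fp16)  # 1.0: the weight never moved
print(w_compute)  # 1.009765625, i.e. about 1.01 after fp16 rounding
```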

Model Sharding

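Model sharding distributes model parameters (and, depending on the strategy, gradients and optimizer states) across multiple GPUs, so each device stores only part of the model state. The main benefit is lower per-GPU memory, which makes room for larger models or batch sizes; because gathering shards adds communication, sharding does not by itself make each training step faster.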


5. References

For further information and research papers related to FSDP, mixed precision training, and model sharding, please refer to the following resources:

Explore these references to gain a deeper understanding of the techniques and concepts implemented in the Zeta library and the fsdp function.

Feel free to reach out to the Zeta community for any questions or discussions regarding this library. Happy deep learning!