# TransformerBlock Documentation
## Table of Contents

- Introduction
- Purpose and Functionality
- Class: `TransformerBlock`
  - Initialization
  - Parameters
  - Attention Mechanism
  - Multi-Head Attention
  - Rotary Embedding
  - Feedforward Network
  - Caching and Optimization
- Usage Examples
  - Basic Usage
  - Fine-Tuning
- Additional Information
  - Layernorm
  - Position Embeddings
- References
## 1. Introduction

Welcome to the Zeta documentation for the `TransformerBlock` class! Zeta is a versatile library that offers tools for efficient training of deep learning models using PyTorch. This documentation provides a comprehensive overview of the `TransformerBlock` class, its architecture, purpose, and usage.
## 2. Purpose and Functionality

The `TransformerBlock` class is a fundamental component of the Zeta library. It is designed to be used within a transformer-based architecture, and its primary purpose is to process input data efficiently. Below, we'll explore the key functionalities and features of the `TransformerBlock` class.
## 3. Class: `TransformerBlock`

The `TransformerBlock` class is the building block of transformer-based models. It performs various operations, including multi-head attention and a feedforward network, to process input data. Let's dive into the details of this class.
### Initialization

To create a `TransformerBlock` instance, you need to specify various parameters and configurations. Here's an example of how to initialize it:

```python
transformer_block = TransformerBlock(
    dim=512,
    dim_head=64,
    causal=True,
    heads=8,
    qk_rmsnorm=False,
    qk_scale=8,
    ff_mult=4,
    attn_dropout=0.0,
    ff_dropout=0.0,
    use_xpos=True,
    xpos_scale_base=512,
    flash_attn=False,
)
```
Parameters
-
dim
(int): The dimension of the input data. -
dim_head
(int): The dimension of each attention head. -
causal
(bool): Whether to use a causal (auto-regressive) attention mechanism. Default isTrue
. -
heads
(int): The number of attention heads. -
qk_rmsnorm
(bool): Whether to apply root mean square normalization to query and key vectors. Default isFalse
. -
qk_scale
(int): Scaling factor for query and key vectors. Used whenqk_rmsnorm
isTrue
. Default is8
. -
ff_mult
(int): Multiplier for the feedforward network dimension. Default is4
. -
attn_dropout
(float): Dropout probability for attention layers. Default is0.0
. -
ff_dropout
(float): Dropout probability for the feedforward network. Default is0.0
. -
use_xpos
(bool): Whether to use positional embeddings. Default isTrue
. -
xpos_scale_base
(int): Scaling factor for positional embeddings. Default is512
. -
flash_attn
(bool): Whether to use Flash Attention mechanism. Default isFalse
.
### Attention Mechanism

The `TransformerBlock` class includes a powerful attention mechanism that allows the model to focus on relevant parts of the input data. It supports both regular and Flash Attention.
### Multi-Head Attention

The class can split the attention mechanism into multiple heads, allowing the model to capture different patterns in the data simultaneously. The number of attention heads is controlled by the `heads` parameter.
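The sketch below illustrates the head-splitting idea in plain PyTorch: an input of dimension `dim` is projected and reshaped into `heads` separate heads of size `dim_head`. The tensor names and shapes here are assumptions chosen for illustration, not part of the `TransformerBlock` API.

```python
import torch

# Illustrative only: how a (batch, seq, dim) input is split into attention heads.
batch, seq, dim, heads, dim_head = 2, 16, 512, 8, 64

x = torch.randn(batch, seq, dim)
q = torch.nn.Linear(dim, heads * dim_head, bias=False)(x)  # project to heads * dim_head
q = q.view(batch, seq, heads, dim_head).transpose(1, 2)    # (batch, heads, seq, dim_head)

print(q.shape)  # torch.Size([2, 8, 16, 64]); each head attends over its own 64-dim slice
```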
### Rotary Embedding
Rotary embeddings are used to enhance the model's ability to handle sequences of different lengths effectively. They are applied to query and key vectors to improve length extrapolation.
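For intuition, here is a minimal, generic implementation of rotary position embeddings applied to a query tensor. It follows the standard RoPE formulation and is not necessarily identical to the rotary embedding used inside `TransformerBlock`; the `rotary_embedding` helper and the tensor shapes are assumptions for illustration.

```python
import torch

def rotary_embedding(x, base=10000):
    """Apply generic rotary position embeddings to x of shape (..., seq, dim). Illustrative only."""
    *_, seq, dim = x.shape
    half = dim // 2
    freqs = 1.0 / (base ** (torch.arange(0, half).float() / half))
    angles = torch.outer(torch.arange(seq).float(), freqs)  # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate pairs of dimensions by a position-dependent angle.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(2, 8, 16, 64)    # (batch, heads, seq, dim_head)
q_rotated = rotary_embedding(q)  # same shape, now position-aware
```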
### Feedforward Network

The `TransformerBlock` class includes a feedforward network that processes the attention output. It can be customized by adjusting the `ff_mult` parameter.
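As a rough sketch, `ff_mult` controls how much the hidden width of the feedforward sub-layer expands relative to `dim`. The module below is a generic two-layer MLP with an assumed GELU activation; zeta's internal feedforward may use a different activation or gating, so treat this only as an illustration of what `ff_mult` does.

```python
import torch.nn as nn

dim, ff_mult, ff_dropout = 512, 4, 0.0

# Generic feedforward sub-layer: expand by ff_mult, apply a nonlinearity, project back.
feedforward = nn.Sequential(
    nn.Linear(dim, dim * ff_mult),  # 512 -> 2048
    nn.GELU(),
    nn.Dropout(ff_dropout),
    nn.Linear(dim * ff_mult, dim),  # 2048 -> 512
)
```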
### Caching and Optimization
The class includes mechanisms for caching causal masks and rotary embeddings, which can improve training efficiency. It also provides options for fine-tuning specific modules within the block.
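The snippet below sketches the caching idea for the causal mask: build the mask once per sequence length and reuse it on later calls. The `get_causal_mask` helper and its cache are illustrative assumptions, not zeta's internal implementation.

```python
import torch

_mask_cache = {}

def get_causal_mask(seq_len, device="cpu"):
    """Return a cached upper-triangular causal mask; rebuild only for unseen lengths."""
    key = (seq_len, device)
    if key not in _mask_cache:
        _mask_cache[key] = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=device
        ).triu(1)
    return _mask_cache[key]

mask = get_causal_mask(16)  # computed once
mask = get_causal_mask(16)  # served from the cache on subsequent calls
```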
## 4. Usage Examples

Now, let's explore some usage examples of the `TransformerBlock` class to understand how to use it effectively.
### Basic Usage

```python
import torch

# Create a TransformerBlock instance (assumes TransformerBlock has already been imported from zeta)
transformer_block = TransformerBlock(dim=512, heads=8)

# Process input data; the (batch, seq_len, dim) input shape is an assumption for illustration
input_data = torch.randn(1, 1024, 512)
output = transformer_block(input_data)
```
### Fine-Tuning

```python
# Create a TransformerBlock instance with fine-tuning modules
lora_q = YourCustomModule()
lora_k = YourCustomModule()
lora_v = YourCustomModule()
lora_o = YourCustomModule()

transformer_block = TransformerBlock(
    dim=512, heads=8, finetune_modules=(lora_q, lora_k, lora_v, lora_o)
)

# Process input data
output = transformer_block(input_data)
```
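If you do not already have a fine-tuning module, a minimal LoRA-style low-rank adapter is one plausible shape for `YourCustomModule`. This is a hypothetical sketch; check the zeta source for the exact interface that `finetune_modules` expects.

```python
import torch.nn as nn

class LoRAAdapter(nn.Module):
    """Minimal LoRA-style low-rank adapter. Hypothetical stand-in for YourCustomModule;
    the interface finetune_modules expects is defined by zeta, not by this sketch."""

    def __init__(self, dim=512, rank=8):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)  # project to a low-rank space
        self.up = nn.Linear(rank, dim, bias=False)    # project back to the model dimension
        nn.init.zeros_(self.up.weight)                # start as a no-op update

    def forward(self, x):
        return self.up(self.down(x))

lora_q = LoRAAdapter(dim=512)  # one adapter per projection you want to fine-tune
```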
## 5. Additional Information
### Layernorm

The `TransformerBlock` class uses layer normalization (layernorm) to normalize input data before processing. This helps stabilize and accelerate training.
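For reference, the snippet below shows a common pre-norm arrangement (normalize the input, apply the sub-layer, add a residual connection), with `nn.Identity()` standing in for the attention or feedforward sub-layer. The exact arrangement inside `TransformerBlock` may differ; this is only an illustration of normalizing before processing.

```python
import torch
import torch.nn as nn

dim = 512
norm = nn.LayerNorm(dim)
x = torch.randn(2, 16, dim)  # (batch, seq_len, dim)

# Pre-norm pattern: normalize, apply the sub-layer, then add a residual connection.
sub_layer = nn.Identity()
out = x + sub_layer(norm(x))
```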
### Position Embeddings

Position embeddings are used to provide the model with information about the position of tokens in the input sequence. They are crucial for handling sequences of different lengths effectively.
## 6. References

- "Attention Is All You Need" (Vaswani et al., 2017), the original Transformer paper
- "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (Dao et al., 2022)
- "Layer Normalization" (Ba et al., 2016)
This documentation provides a comprehensive guide to the `TransformerBlock` class in the Zeta library, explaining its purpose, functionality, parameters, and usage. You can now effectively integrate this class into your deep learning models for various natural language processing tasks and beyond.