YarnEmbedding Documentation ¶
Table of Contents ¶
- Introduction
- Purpose and Functionality
- Class: YarnEmbedding
  - Initialization
  - Parameters
  - Forward Method
- Helpers and Functions
  - find_correction_dim
  - find_correction_range
  - linear_ramp_mask
  - get_mscale
- Usage Examples
  - Using the YarnEmbedding Class
  - Using the Helper Functions
- Additional Information
  - Positional Embeddings in Transformers
- References
1. Introduction ¶
Welcome to the Zeta documentation for the YarnEmbedding class and related functions! Zeta is a powerful library for deep learning in PyTorch, and this documentation provides a comprehensive understanding of the YarnEmbedding class and its associated functions.
2. Purpose and Functionality ¶
The YarnEmbedding class and its related functions generate and apply advanced positional embeddings to input tensors, following a YaRN-style extension of rotary positional embeddings that lets models handle sequences longer than the context length they were trained on. These embeddings are crucial for sequence-to-sequence models, particularly in transformer architectures. Below, we explore their purpose and functionality.
3. Class: YarnEmbedding ¶
The YarnEmbedding class is used to apply advanced positional embeddings to input tensors. It offers a highly configurable approach to generating embeddings tailored to the needs of transformer-based models.
Initialization ¶
To create an instance of the YarnEmbedding class, specify the following parameters (only dim is required; the rest have the defaults shown below). A short configuration example follows the parameter list.
YarnEmbedding(
dim,
max_position_embeddings=2048,
base=10000,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
finetuned=False,
device=None,
)
Parameters ¶
- dim (int): The dimensionality of the positional embeddings.
- max_position_embeddings (int, optional): The maximum number of position embeddings to be generated. Default is 2048.
- base (int, optional): The base value for calculating the positional embeddings. Default is 10000.
- original_max_position_embeddings (int, optional): The original maximum number of position embeddings used for fine-tuning. Default is 2048.
- extrapolation_factor (int, optional): The factor used for extrapolating positional embeddings beyond the original maximum. Default is 1.
- attn_factor (int, optional): A factor affecting the positional embeddings for attention. Default is 1.
- beta_fast (int, optional): A parameter used for interpolation. Default is 32.
- beta_slow (int, optional): A parameter used for interpolation. Default is 1.
- finetuned (bool, optional): Whether to use finetuned embeddings. Default is False.
- device (torch.device, optional): If specified, the device to which tensors will be moved.
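Here is a minimal configuration sketch showing a few non-default settings. The argument names match the parameter list above; the specific values and the device selection are illustrative assumptions, not recommendations.
import torch
from zeta.nn import YarnEmbedding

# Illustrative configuration: extend a model trained on 2048 positions to 4096.
yarn_embedding = YarnEmbedding(
    dim=256,                                 # embedding dimensionality
    max_position_embeddings=4096,            # extended maximum sequence length
    original_max_position_embeddings=2048,   # length used during pretraining
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)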
Forward Method ¶
The forward method of the YarnEmbedding class applies advanced positional embeddings to the input tensor. It takes the following arguments (a call example follows the list):
- input_tensor (Tensor): The input tensor to which positional embeddings will be applied.
- seq_len (int): The length of the sequence for which embeddings should be generated.
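A minimal call sketch, assuming the module was constructed as shown earlier and that seq_len does not exceed max_position_embeddings:
# Apply the embeddings to a (batch, seq_len, dim) tensor; shapes are illustrative.
x = torch.rand(8, 128, 256)
embedded = yarn_embedding(x, seq_len=128)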
4. Helpers and Functions ¶
In addition to the YarnEmbedding class, there are several functions provided for working with positional embeddings.
find_correction_dim ¶
This function calculates the correction dimension based on the number of rotations and other parameters.
- num_rotations (int): The number of rotations.
- dim (int): The dimensionality of the positional embeddings.
- base (int): The base value for calculating the positional embeddings.
- max_position_embeddings (int): The maximum number of position embeddings.
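For intuition, a sketch of the YaRN-style formula this kind of helper typically computes is shown below. It is an assumption about the behavior, not necessarily the exact Zeta implementation; the _sketch suffix marks it as hypothetical.
import math

def find_correction_dim_sketch(num_rotations, dim, base, max_position_embeddings):
    # Embedding dimension at which the rotary frequency completes `num_rotations`
    # full rotations over `max_position_embeddings` positions.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )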
find_correction_range ¶
This function calculates the correction range based on low and high rotation values.
- low_rot (int): The low rotation value.
- high_rot (int): The high rotation value.
- dim (int): The dimensionality of the positional embeddings.
- base (int): The base value for calculating the positional embeddings.
- max_position_embeddings (int): The maximum number of position embeddings.
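As with find_correction_dim, the sketch below shows the usual YaRN-style computation: round the correction dimensions for the low and high rotation counts and clamp them to valid indices. It builds on the hypothetical find_correction_dim_sketch above and is an assumption about the behavior rather than the exact Zeta code.
import math

def find_correction_range_sketch(low_rot, high_rot, dim, base, max_position_embeddings):
    # Lower and upper dimension indices bounding the interpolation ramp.
    low = math.floor(find_correction_dim_sketch(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim_sketch(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)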
linear_ramp_mask ¶
This function generates a linear ramp mask.
- min (float): The minimum value.
- max (float): The maximum value.
- dim (int): The dimensionality of the mask.
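The typical behavior is a per-dimension ramp that is 0 below min, 1 above max, and linear in between. The sketch below (with min/max renamed to avoid shadowing Python built-ins) is an assumption about what linear_ramp_mask computes, not the exact Zeta implementation.
import torch

def linear_ramp_mask_sketch(min_val, max_val, dim):
    # One value per embedding dimension: 0 below min_val, 1 above max_val,
    # and a linear ramp in between.
    if min_val == max_val:
        max_val += 0.001  # avoid division by zero
    ramp = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
    return torch.clamp(ramp, 0, 1)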
get_mscale ¶
This function calculates the scale factor for positional embeddings.
- scale (float): The scale factor.
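A commonly used YaRN heuristic for this factor is sketched below: no rescaling when the context is not extended, and logarithmic growth otherwise. Treat it as an approximation of what get_mscale computes, not the exact Zeta code.
import math

def get_mscale_sketch(scale=1.0):
    # Attention scaling grows logarithmically with the context-extension ratio.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

# Example: get_mscale_sketch(2.0) is roughly 1.07.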
5. Usage Examples ¶
Let's explore some usage examples of the YarnEmbedding class and related functions to understand how to use them effectively.
Using the YarnEmbedding Class ¶
import torch
from zeta.nn import YarnEmbedding
# Create an instance of YarnEmbedding
yarn_embedding = YarnEmbedding(dim=256, max_position_embeddings=2048)
# Apply positional embeddings to an input tensor
input_tensor = torch.rand(16, 32, 256) # Example input tensor
output = yarn_embedding(input_tensor, seq_len=32)
Using the Helper Functions ¶
from zeta.nn import find_correction_dim, find_correction_range, linear_ramp_mask, get_mscale
import torch
# Calculate correction dimension
correction_dim = find_correction_dim(num_rotations=8, dim=256, base=10000, max_position_embeddings=2048)
# Calculate correction range
low, high = find_correction_range(low_rot=16, high_rot=32, dim=256, base=10000, max_position_embeddings=2048)
# Generate linear ramp mask
ramp_mask = linear_ramp_mask(min=0.2, max=0.8, dim=128)
# Calculate mscale
mscale = get_mscale(scale=2.0)
6. Additional Information ¶
Positional Embeddings in Transformers ¶
Positional embeddings play a crucial role in transformer architectures, allowing models to capture the sequential order of data. These embeddings are especially important for tasks involving sequences, such as natural language processing (NLP) and time series analysis.
7. References ¶
For further information on positional embeddings and transformers, you can refer to the following resources:
- Attention Is All You Need (Transformer) - The original transformer paper.
- PyTorch Documentation - Official PyTorch documentation for related concepts and functions.
This documentation provides a comprehensive overview of the Zeta library's YarnEmbedding class and related functions. It aims to help you understand the purpose, functionality, and usage of these components for advanced positional embeddings in your deep learning projects.