
YarnEmbedding Documentation

Table of Contents

  1. Introduction
  2. Purpose and Functionality
  3. Class: YarnEmbedding
       • Initialization
       • Parameters
       • Forward Method
  4. Helpers and Functions
       • find_correction_dim
       • find_correction_range
       • linear_ramp_mask
       • get_mscale
  5. Usage Examples
       • Using the YarnEmbedding Class
       • Using the Helper Functions
  6. Additional Information
       • Positional Embeddings in Transformers
  7. References

1. Introduction

Welcome to the Zeta documentation for the YarnEmbedding class and related functions! Zeta is a powerful library for deep learning in PyTorch, and this documentation will provide a comprehensive understanding of the YarnEmbedding class and its associated functions.


2. Purpose and Functionality

The YarnEmbedding class and its related functions generate and apply YaRN-style (Yet another RoPE extensioN) rotary positional embeddings to input tensors. These embeddings let transformer-based sequence models operate on contexts longer than the window they were originally trained on, which is the problem YaRN is designed to address. Below, we will explore their purpose and functionality.


3. Class: YarnEmbedding

The YarnEmbedding class is used to apply advanced positional embeddings to input tensors. It offers a highly configurable approach to generating embeddings tailored to the needs of transformer-based models.

Initialization

To create an instance of the YarnEmbedding class, specify the following parameters (a configuration sketch follows the parameter list):

YarnEmbedding(
    dim,
    max_position_embeddings=2048,
    base=10000,
    original_max_position_embeddings=2048,
    extrapolation_factor=1,
    attn_factor=1,
    beta_fast=32,
    beta_slow=1,
    finetuned=False,
    device=None,
)

Parameters

  • dim (int): The dimensionality of the positional embeddings.

  • max_position_embeddings (int, optional): The maximum number of position embeddings to be generated. Default is 2048.

  • base (int, optional): The base value for calculating the positional embeddings. Default is 10000.

  • original_max_position_embeddings (int, optional): The original maximum number of position embeddings used for fine-tuning. Default is 2048.

  • extrapolation_factor (int, optional): The factor used for extrapolating positional embeddings beyond the original maximum. Default is 1.

  • attn_factor (int, optional): An additional multiplier applied to the attention scaling factor (see get_mscale). Default is 1.

  • beta_fast (int, optional): Upper rotation-count bound of the YaRN interpolation ramp; dimensions completing more rotations than this over the original context keep their original frequencies. Default is 32.

  • beta_slow (int, optional): Lower rotation-count bound of the ramp; dimensions completing fewer rotations than this are fully interpolated. Default is 1.

  • finetuned (bool, optional): Whether to use finetuned embeddings. Default is False.

  • device (torch.device, optional): If specified, the device to which tensors will be moved.
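
The extension-specific parameters come into play when a model pretrained at one context length is run at a longer one. A minimal configuration sketch; the dimension and lengths below are illustrative, not prescriptive:

from zeta.nn import YarnEmbedding

# Extend a model pretrained on 2048 positions out to 8192 positions.
# original_max_position_embeddings records the pretraining length,
# while max_position_embeddings sets the new target length.
yarn_embedding = YarnEmbedding(
    dim=128,
    max_position_embeddings=8192,
    original_max_position_embeddings=2048,
    finetuned=False,  # set to True for models finetuned at the new length
)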

Forward Method

The forward method of the YarnEmbedding class generates the positional embeddings corresponding to the input tensor and a given sequence length. It can be called as follows (a usage sketch follows the parameter list):

output = yarn_embedding(input_tensor, seq_len)

  • input_tensor (Tensor): The input tensor to which positional embeddings will be applied.

  • seq_len (int): The length of the sequence for which embeddings should be generated.
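
The type and shape of the return value are not fixed by the signature above. In most rotary/YaRN-style implementations, the forward pass returns cached cosine and sine tables covering the first seq_len positions, cast to the input's dtype, rather than a modified copy of the input. A minimal sketch under that assumption:

import torch
from zeta.nn import YarnEmbedding

yarn_embedding = YarnEmbedding(dim=64)

# Dummy activations; (batch, heads, seq_len, head_dim) is a common layout.
x = torch.randn(2, 8, 128, 64)

# Assumption: the module returns a (cos, sin) pair that a separate
# rotary-application function consumes downstream.
cos, sin = yarn_embedding(x, seq_len=128)
print(cos.shape, sin.shape)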


4. Helpers and Functions

In addition to the YarnEmbedding class, there are several functions provided for working with positional embeddings.

find_correction_dim

This function computes the (fractional) dimension index at which a rotary frequency completes a given number of rotations over the original context window; YaRN uses it to locate the boundaries of the interpolation ramp. A reference sketch follows the parameter list.

correction_dim = find_correction_dim(num_rotations, dim, base, max_position_embeddings)

  • num_rotations (int): The target number of full rotations over the context window.

  • dim (int): The dimensionality of the positional embeddings.

  • base (int): The base value for calculating the positional embeddings.

  • max_position_embeddings (int): The maximum number of position embeddings.
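
For reference, the YaRN paper derives this boundary by inverting the rotary wavelength formula. A minimal sketch of the conventional implementation; the library's own definition may differ in detail:

import math

def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    # Solve for the dimension index whose rotary frequency completes
    # `num_rotations` full rotations over the original context window
    # (inverting wavelength_i = 2 * pi * base**(2 * i / dim)).
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )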

find_correction_range

This function converts a pair of rotation thresholds into a clamped integer range of dimension indices over which the interpolation ramp is applied. A reference sketch follows the parameter list.

low, high = find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings)

  • low_rot (int): The low rotation value.

  • high_rot (int): The high rotation value.

  • dim (int): The dimensionality of the positional embeddings.

  • base (int): The base value for calculating the positional embeddings.

  • max_position_embeddings (int): The maximum number of position embeddings.
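
A minimal sketch in the YaRN reference formulation, assuming the find_correction_dim helper sketched above; the fractional boundaries are rounded outward and clamped to valid indices:

import math

def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    # Clamp to the valid dimension-index range [0, dim - 1].
    return max(low, 0), min(high, dim - 1)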

linear_ramp_mask

This function generates a mask whose entries rise linearly from 0 to 1 between the given bounds; YaRN uses it to blend interpolated and extrapolated rotary frequencies across dimensions. A reference sketch follows the parameter list.

ramp_mask = linear_ramp_mask(min, max, dim)

  • min (float): The index at which the ramp leaves 0.

  • max (float): The index at which the ramp reaches 1.

  • dim (int): The dimensionality of the mask.
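
A minimal sketch in the YaRN reference formulation; the parameter names are changed from min/max here only to avoid shadowing Python builtins:

import torch

def linear_ramp_mask(min_val, max_val, dim):
    # Widen a zero-width ramp slightly to avoid division by zero.
    if min_val == max_val:
        max_val += 0.001
    # Entries below min_val clamp to 0, entries above max_val clamp to 1,
    # with a linear ramp in between.
    ramp = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
    return torch.clamp(ramp, 0, 1)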

get_mscale

This function calculates the attention scaling factor ("mscale") that YaRN applies to compensate for an extended context window. A reference sketch follows.

mscale = get_mscale(scale)

  • scale (float): The context extension factor (new context length divided by the original length).
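
A minimal sketch following the YaRN paper, which fits the attention scaling factor as 0.1 * ln(scale) + 1 and applies no scaling while the context is not extended:

import math

def get_mscale(scale=1.0):
    # No rescaling is needed when the context window is not extended.
    if scale <= 1:
        return 1.0
    # Empirically fitted attention temperature from the YaRN paper.
    return 0.1 * math.log(scale) + 1.0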

5. Usage Examples

Let's explore some usage examples of the YarnEmbedding class and related functions to understand how to use them effectively.

Using the YarnEmbedding Class

import torch
from zeta.nn import YarnEmbedding

# Create an instance of YarnEmbedding
yarn_embedding = YarnEmbedding(dim=256, max_position_embeddings=2048)

# Apply positional embeddings to an input tensor
input_tensor = torch.rand(16, 32, 256)  # Example input tensor
output = yarn_embedding(input_tensor, seq_len=32)

Using the Helper Functions

from zeta.nn import find_correction_dim, find_correction_range, linear_ramp_mask, get_mscale
import torch

# Calculate correction dimension
correction_dim = find_correction_dim(num_rotations=8, dim=256, base=10000, max_position_embeddings=2048)

# Calculate correction range
low, high = find_correction_range(low_rot=16, high_rot=32, dim=256, base=10000, max_position_embeddings=2048)

# Generate linear ramp mask
ramp_mask = linear_ramp_mask(min=0.2, max=0.8, dim=128)

# Calculate mscale
mscale = get_mscale(scale=2.0)

6. Additional Information

Positional Embeddings in Transformers

Positional embeddings play a crucial role in transformer architectures, allowing models to capture the sequential order of data. These embeddings are especially important for tasks involving sequences, such as natural language processing (NLP) and time series analysis.


7. References

For further information on positional embeddings and context-window extension, you can refer to the paper this module is named after:

  • Peng, B., Quesnelle, J., Fan, H., and Shippole, E. "YaRN: Efficient Context Window Extension of Large Language Models." arXiv:2309.00071, 2023.

This documentation provides a comprehensive overview of the Zeta library's YarnEmbedding class and related functions. It aims to help you understand the purpose, functionality, and usage of these components for advanced positional embeddings in your deep learning projects.