Skip to content

AdaptiveParameterList Module Documentation


Overview and Introduction

The AdaptiveParameterList class extends PyTorch's nn.ParameterList to provide an adaptive parameter mechanism during the training process. By using adaptation functions, one can adjust or transform parameters based on specific criteria or observations, allowing for dynamic updates outside the traditional gradient-based update rules. This capability can be crucial in certain applications where manual interventions or parameter adjustments based on heuristics are desirable.


Class Definition: AdaptiveParameterList

class AdaptiveParameterList(nn.ParameterList):

Description:

A container module that extends PyTorch's nn.ParameterList to allow the adaptation of its parameters using specific functions. This adaptation can be applied at various stages during training or evaluation to realize sophisticated model behaviors.

Parameters:

  • parameters (List[nn.Parameter], optional): List of parameters to initialize the AdaptiveParameterList. Default: None.

Method: adapt

def adapt(self, adaptation_functions):

Description:

Adapts the parameters of the AdaptiveParameterList using the provided functions.

Parameters:

  • adaptation_functions (Dict[int, Callable]): A dictionary where keys are the indices of the parameters in the list and values are the callable functions that take in an nn.Parameter and return an nn.Parameter.

Raises:

  • ValueError: If adaptation_functions is not a dictionary.
  • ValueError: If an entry in adaptation_functions is not callable.
  • ValueError: If the output tensor of an adaptation function doesn't match the shape of the input parameter.

Usage Examples:

1. Basic Usage

import torch
import torch.nn as nn
from AdaptiveParameterList import AdaptiveParameterList
from shapeless import x  # Placeholder, as actual import statement was not provided


# Define an adaptation function
def adaptation_function(param):
    return param * 0.9


adaptive_params = AdaptiveParameterList([nn.Parameter(torch.randn(10, 10))])

# Create a dictionary with adaptation functions for the desired indices
adapt_funcs = {0: adaptation_function}

adaptive_params.adapt(adapt_funcs)

2. Using Multiple Adaptation Functions

import torch
import torch.nn as nn
from AdaptiveParameterList import AdaptiveParameterList
from shapeless import x


# Define multiple adaptation functions
def adaptation_function1(param):
    return param * 0.9


def adaptation_function2(param):
    return param + 0.1


adaptive_params = AdaptiveParameterList(
    [nn.Parameter(torch.randn(10, 10)), nn.Parameter(torch.randn(10, 10))]
)

# Apply different adaptation functions to different parameters
adapt_funcs = {0: adaptation_function1, 1: adaptation_function2}

adaptive_params.adapt(adapt_funcs)

3. Handling Errors with Adaptation Functions

import torch
import torch.nn as nn
from AdaptiveParameterList import AdaptiveParameterList
from shapeless import x


# Incorrect adaptation function (not returning a tensor of the same shape)
def wrong_adaptation_function(param):
    return param[0]


adaptive_params = AdaptiveParameterList([nn.Parameter(torch.randn(10, 10))])

try:
    adaptive_params.adapt({0: wrong_adaptation_function})
except ValueError as e:
    print(f"Error: {e}")

Mathematical Representation:

Given an AdaptiveParameterList with parameters ( P = [p_1, p_2, ... , p_n] ) and an adaptation function ( f_i ) for parameter ( p_i ), the adapted parameter ( p_i' ) is computed as:

[ p_i' = f_i(p_i) ]

Where ( f_i: \mathbb{R}^{m \times n} \rightarrow \mathbb{R}^{m \times n} ) is a function that takes a tensor (parameter) as input and returns a tensor of the same shape.


Additional Information and Tips:

  1. Ensure that the adaptation functions are defined correctly and return tensors of the same shape as their input.
  2. Adaptation can be applied at different intervals, for example, after every epoch, or after specific events during training.
  3. Care must be taken when designing adaptation functions to avoid unintentional model behaviors.

References and Resources: