FilmConditioning Module

Introduction: The FilmConditioning module implements FiLM (Feature-wise Linear Modulation). It predicts a per-channel scale and shift from a conditioning tensor and applies them to the input feature maps as a feature-wise affine transformation. This lets a convolutional network modulate its intermediate features with external information without changing the convolution weights themselves.
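To make "feature-wise affine transformation" concrete, the sketch below applies a per-channel scale `gamma` and shift `beta` to a batch of feature maps by hand, without the module. The shapes are illustrative assumptions, and a channel-first (batch, C, H, W) layout is assumed.

```python
import torch

# Illustrative shapes (assumed): batch of 2, 3 channels, 4x4 feature maps
x = torch.randn(2, 3, 4, 4)
gamma = torch.randn(2, 3)  # one scale per (sample, channel)
beta = torch.randn(2, 3)   # one shift per (sample, channel)

# Reshape to (batch, C, 1, 1) so each value broadcasts over its channel's H x W map
y = gamma[:, :, None, None] * x + beta[:, :, None, None]

# Every spatial position in channel c of sample n uses the same gamma and beta
n, c = 1, 2
assert torch.allclose(y[n, c], gamma[n, c] * x[n, c] + beta[n, c])
print(y.shape)  # torch.Size([2, 3, 4, 4])
```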

Args: num_channels (int): Number of channels in the input tensor. The last dimension of the conditioning tensor must have the same size.
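Both projections are nn.Linear(num_channels, num_channels) layers, so they only accept conditioning tensors whose last dimension equals num_channels; any other size makes PyTorch raise a RuntimeError. A minimal check with a stand-alone projection and an assumed num_channels of 3:

```python
import torch
import torch.nn as nn

num_channels = 3  # assumed value for illustration
projection = nn.Linear(num_channels, num_channels)

ok = projection(torch.randn(10, num_channels))  # last dim 3 == in_features 3
print(ok.shape)  # torch.Size([10, 3])

try:
    projection(torch.randn(10, num_channels + 1))  # last dim 4 != in_features 3
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```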

Attributes: num_channels (int): Number of channels in the input tensor. _projection_add (nn.Linear): Linear layer that maps the conditioning tensor to the additive (shift) term. _projection_mult (nn.Linear): Linear layer that maps the conditioning tensor to the multiplicative (scale) term.
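Each projection holds a num_channels x num_channels weight matrix plus a bias vector of length num_channels, so the two layers together have 2 * (C*C + C) trainable parameters for C channels. The sketch below checks this with two stand-alone layers rather than the module itself; the helper name `film_param_count` is hypothetical.

```python
import torch.nn as nn


def film_param_count(num_channels: int) -> int:
    """Hypothetical helper: count parameters of the two FiLM projections."""
    projection_add = nn.Linear(num_channels, num_channels)
    projection_mult = nn.Linear(num_channels, num_channels)
    return sum(
        p.numel()
        for layer in (projection_add, projection_mult)
        for p in layer.parameters()
    )


# C = 3: 2 * (3*3 + 3) = 24
print(film_param_count(3))  # 24
```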

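Class Definition:

import torch
import torch.nn as nn


class FilmConditioning(nn.Module):
    def __init__(self, num_channels: int, *args, **kwargs):
        super().__init__()
        self.num_channels = num_channels
        self._projection_add = nn.Linear(num_channels, num_channels)
        self._projection_mult = nn.Linear(num_channels, num_channels)

    def forward(self, conv_filters: torch.Tensor, conditioning: torch.Tensor):
        # (batch, num_channels) -> per-channel shift (add) and scale (mult)
        projected_cond_add = self._projection_add(conditioning)
        projected_cond_mult = self._projection_mult(conditioning)
        if conv_filters.dim() == 4:
            # Assumes channel-first (batch, C, H, W); (batch, C, 1, 1) broadcasts over H, W
            projected_cond_add = projected_cond_add[:, :, None, None]
            projected_cond_mult = projected_cond_mult[:, :, None, None]
        # Assumed FiLM form: scale by (1 + mult), then shift by add
        return (1 + projected_cond_mult) * conv_filters + projected_cond_add

Functionality and Usage: __init__ stores num_channels and creates two linear layers that project the conditioning tensor into an additive term and a multiplicative term. forward computes both projections and applies them as a feature-wise affine transformation: every channel of conv_filters is scaled and shifted by its own conditioning-dependent values, and the output has the same shape as conv_filters.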
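If the forward pass uses the common (1 + scale) form of FiLM, then projections that output zeros leave the input unchanged, which makes zero initialization a natural identity starting point. The sketch below zero-initializes two stand-alone projections and checks that property. The function name `film` and the zero initialization are assumptions for illustration, not confirmed behavior of zeta's FilmConditioning.

```python
import torch
import torch.nn as nn


def film(x: torch.Tensor, cond: torch.Tensor,
         proj_add: nn.Linear, proj_mult: nn.Linear) -> torch.Tensor:
    """Hypothetical functional FiLM for channel-first (batch, C, H, W) input."""
    shift = proj_add(cond)[:, :, None, None]
    scale = proj_mult(cond)[:, :, None, None]
    return (1 + scale) * x + shift


num_channels = 3
proj_add = nn.Linear(num_channels, num_channels)
proj_mult = nn.Linear(num_channels, num_channels)
# Assumed zero initialization: both projections now output zeros for any input
for layer in (proj_add, proj_mult):
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)

x = torch.randn(2, num_channels, 8, 8)
cond = torch.randn(2, num_channels)
out = film(x, cond, proj_add, proj_mult)
print(torch.equal(out, x))  # True: (1 + 0) * x + 0 == x
```

Once training updates the projections, the scale and shift move away from zero and the layer starts to modulate its input.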

Usage Examples:

Usage Example 1: Three channels on 32x32 feature maps

import torch

from zeta.nn import FilmConditioning

# Define input tensors
conv_filters = torch.randn(10, 3, 32, 32)
conditioning = torch.randn(10, 3)

# Create an instance of FilmConditioning
film_conditioning = FilmConditioning(3)

# Applying film conditioning
result = film_conditioning(conv_filters, conditioning)
print(result.shape)

Usage Example 2: Four channels on 20x20 feature maps

import torch

from zeta.nn import FilmConditioning

# Define input tensors
conv_filters = torch.randn(5, 4, 20, 20)
conditioning = torch.randn(5, 4)

# Create an instance of FilmConditioning
film_conditioning = FilmConditioning(4)

# Applying film conditioning
result = film_conditioning(conv_filters, conditioning)
print(result.shape)

Usage Example 3: Two channels on 50x50 feature maps

import torch

from zeta.nn import FilmConditioning

# Define input tensors
conv_filters = torch.randn(8, 2, 50, 50)
conditioning = torch.randn(8, 2)

# Create an instance of FilmConditioning
film_conditioning = FilmConditioning(2)

# Applying film conditioning
result = film_conditioning(conv_filters, conditioning)
print(result.shape)
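Inside a convolutional network, one common placement for FiLM is between a convolution and its nonlinearity. The sketch below shows that placement with a hypothetical `ConditionedConvBlock`. To stay self-contained it uses a local stand-in, `LocalFiLM`, with the same two-projection structure described above (channel-first layout and (1 + scale) form assumed) instead of importing from zeta.

```python
import torch
import torch.nn as nn


class LocalFiLM(nn.Module):
    """Stand-in FiLM layer (assumed channel-first input, (1 + scale) form)."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.proj_add = nn.Linear(num_channels, num_channels)
        self.proj_mult = nn.Linear(num_channels, num_channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        shift = self.proj_add(cond)[:, :, None, None]
        scale = self.proj_mult(cond)[:, :, None, None]
        return (1 + scale) * x + shift


class ConditionedConvBlock(nn.Module):
    """Hypothetical block: conv -> FiLM -> ReLU."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.film = LocalFiLM(out_channels)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond needs out_channels features because FiLM acts on the conv output
        return self.act(self.film(self.conv(x), cond))


block = ConditionedConvBlock(in_channels=3, out_channels=8)
images = torch.randn(4, 3, 32, 32)
cond = torch.randn(4, 8)
out = block(images, cond)
print(out.shape)  # torch.Size([4, 8, 32, 32])
```

Note that the conditioning size is tied to the channels FiLM modulates (here the conv's 8 output channels), not to the 3 input channels.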
