Skip to content

Module/Function Name: Film

Provides a Feature-wise Linear Modulation (FiLM) module which applies feature-wise linear modulation to the input features based on the conditioning tensor to adapt them to the given conditions.

Arguments

  • dim (int): The dimension of the input features.
  • hidden_dim (int): The dimension of the hidden layer.
  • expanse_ratio (int, optional): The expansion ratio for the hidden layer (default = 4).
  • conditions (Tensor): The conditioning tensor.
  • hiddens (Tensor): The input features to be modulated.

Usage Examples

import torch
from zeta.nn import Film

# Initialize the Film layer
film_layer = Film(dim=128, hidden_dim=64, expanse_ratio=4)

# Create dummy data for conditions and hiddens
conditions = torch.randn(10, 128)  # Batch size is 10, feature size is 128
hiddens = torch.randn(10, 1, 128)  # Batch size is 10, sequence length is 1, feature size is 128

# Pass the data through the Film layer
modulated_features = film_layer(conditions, hiddens)

# Print the shape of the output
print(modulated_features.shape)  # Output shape will be [10, 1, 128]

References and Resources