Module/Function Name: Film¶
Provides a Feature-wise Linear Modulation (FiLM) module which applies feature-wise linear modulation to the input features based on the conditioning tensor to adapt them to the given conditions.
Arguments¶
dim
(int): The dimension of the input features.hidden_dim
(int): The dimension of the hidden layer.expanse_ratio
(int, optional): The expansion ratio for the hidden layer (default = 4).conditions
(Tensor): The conditioning tensor.hiddens
(Tensor): The input features to be modulated.
Usage Examples¶
import torch
from zeta.nn import Film
# Initialize the Film layer
film_layer = Film(dim=128, hidden_dim=64, expanse_ratio=4)
# Create dummy data for conditions and hiddens
conditions = torch.randn(10, 128) # Batch size is 10, feature size is 128
hiddens = torch.randn(10, 1, 128) # Batch size is 10, sequence length is 1, feature size is 128
# Pass the data through the Film layer
modulated_features = film_layer(conditions, hiddens)
# Print the shape of the output
print(modulated_features.shape) # Output shape will be [10, 1, 128]
References and Resources¶
- Paper: Link to the paper discussing FiLM module.
- PyTorch Documentation: PyTorch Documentation ```