MixtureOfSoftmaxes Documentation¶

The `MixtureOfSoftmaxes` module improves the modeling capacity of the softmax function by combining multiple softmax distributions. It takes an input tensor and computes a weighted sum of the outputs of several softmax layers, with mixture weights that are learned during training, allowing the model to adapt to the characteristics of the data.
The primary use case of the MoS (Mixture of Softmaxes) module is in scenarios where a single softmax cannot capture the complex relationships between input features and output classes. By combining multiple softmax distributions with learned mixture weights, the module provides a flexible way to handle such situations.
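Conceptually, the module gates several independent softmax distributions with learned mixture weights. The following is a minimal sketch of that idea, assuming per-mixture linear projections and a learned gating layer; the names (`MoSSketch`, `projections`, `gate`) are illustrative, and the actual `zeta.ops` implementation may differ:

```python
import torch
from torch import nn


class MoSSketch(nn.Module):
    """Illustrative mixture-of-softmaxes; not the actual zeta.ops source."""

    def __init__(self, num_mixtures: int, input_size: int, num_classes: int):
        super().__init__()
        # One linear projection per mixture component (assumed design)
        self.projections = nn.ModuleList(
            [nn.Linear(input_size, num_classes) for _ in range(num_mixtures)]
        )
        # Gating layer that produces one weight per mixture, per example
        self.gate = nn.Linear(input_size, num_mixtures)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(self.gate(x), dim=-1)  # (batch, K)
        dists = torch.stack(
            [torch.softmax(proj(x), dim=-1) for proj in self.projections],
            dim=1,
        )  # (batch, K, num_classes)
        # Weighted sum of the K softmax distributions
        return (weights.unsqueeze(-1) * dists).sum(dim=1)  # (batch, num_classes)
```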
Once you have the dependencies (PyTorch and the `zeta` package) installed, you can import the module in your Python code.
Usage ¶
Initialization ¶
To use the `MixtureOfSoftmaxes` module, create an instance of it by providing the following arguments during initialization:

- `num_mixtures` (int): The number of softmax mixtures.
- `input_size` (int): The size of the input feature dimension.
- `num_classes` (int): The number of classes in the output dimension.

Here's an example of how to initialize the module:
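```python
import torch
from zeta.ops import MixtureOfSoftmaxes

# Three softmax mixtures over 128-dimensional inputs and 10 output classes
mos = MixtureOfSoftmaxes(num_mixtures=3, input_size=128, num_classes=10)
```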
Forward Pass ¶
Once you've initialized the `MixtureOfSoftmaxes` module, you can perform the forward pass by passing an input tensor `x` to it. The forward pass calculates the combined output from the mixture of softmaxes; the resulting `output` tensor contains the combined result from the mixture of softmax distributions.
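For example, continuing from the initialization above:

```python
# x: (batch_size, input_size) -> output: (batch_size, num_classes)
x = torch.randn(32, 128)
output = mos(x)
```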
Examples ¶
Basic Example ¶
Here's a simple example of how to use the `MixtureOfSoftmaxes` module for a classification task:
```python
import torch
from zeta.ops import MixtureOfSoftmaxes

# Initialize the module
mos = MixtureOfSoftmaxes(num_mixtures=3, input_size=128, num_classes=10)

# Generate random input data
x = torch.randn(32, 128)

# Perform the forward pass
output = mos(x)

print(output.shape)  # Expected output shape: torch.Size([32, 10])
```
In this example, we create an instance of `MixtureOfSoftmaxes` with three mixtures, an input size of 128, and ten output classes. We then generate random input data and perform a forward pass to get the output.
Complex Task ¶
In more complex scenarios, the MoS module can be applied to tasks where a traditional softmax may not be sufficient. For example, in natural language processing (NLP), the MoS module can be used to model complex relationships between words and their meanings.
```python
import torch
from zeta.ops import MixtureOfSoftmaxes

# Initialize the module with a large vocabulary size
mos = MixtureOfSoftmaxes(num_mixtures=5, input_size=128, num_classes=10000)

# Generate input data (word embeddings)
x = torch.randn(32, 128)

# Perform the forward pass
output = mos(x)

print(output.shape)  # Expected output shape: torch.Size([32, 10000])
```
In this example, we initialize the MoS module with five mixtures and a large vocabulary size (10,000 classes). This demonstrates the module's ability to handle complex tasks with a significant number of output classes.
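If the module returns probabilities rather than raw logits (as the sketch earlier suggests), a training step would take the log of the output before applying a negative log-likelihood loss. This is a hedged sketch; `targets` is a hypothetical batch of class indices:

```python
import torch.nn.functional as F

# Hypothetical ground-truth class indices for the batch
targets = torch.randint(0, 10000, (32,))

# nll_loss expects log-probabilities; clamp_min guards against log(0)
loss = F.nll_loss(torch.log(output.clamp_min(1e-9)), targets)
loss.backward()
```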
Parameters ¶
Here are the parameters that can be passed during the initialization of the `MixtureOfSoftmaxes` module:

| Parameter | Description | Data Type | Default Value |
|---|---|---|---|
| `num_mixtures` | Number of softmax mixtures. | int | - |
| `input_size` | Size of the input feature dimension. | int | - |
| `num_classes` | Number of classes in the output dimension. | int | - |
Return Value ¶
The `forward` method of the `MixtureOfSoftmaxes` module returns a single value:

- `output` (Tensor): The combined output from the mixture of softmaxes, with shape `(batch_size, num_classes)`.
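Because each mixture component is a softmax and the mixture weights themselves sum to 1, each row of the output should sum to 1. A quick sanity check, assuming the module returns probabilities as the examples above suggest:

```python
probs = mos(torch.randn(4, 128))
print(probs.sum(dim=-1))  # expected: each entry close to 1.0
```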
Additional Information ¶
- The MoS module can be used in a variety of deep learning tasks, including classification, natural language processing, and more.
- It is important to fine-tune the number of mixtures and other hyperparameters based on the specific task and dataset.