# PytorchGELUTanh

## Overview
The `PytorchGELUTanh` class provides a fast C implementation of the tanh approximation of the GeLU (Gaussian Error Linear Units) activation function. It is intended to be faster than, and as effective as, other GeLU implementations such as `NewGELU` and `FastGELU`, although it is not an exact numerical match to them because of rounding differences.

This documentation provides an in-depth guide to using the `PytorchGELUTanh` class. It covers general information about the class, the method documentation, and several usage examples.
## Introduction
In neural networks, an activation function determines a neuron's output by applying a non-linear transformation to the weighted sum of its inputs plus a bias. One such activation function is the Gaussian Error Linear Unit (GeLU), which weights each input by the cumulative distribution function of the standard Gaussian distribution and can help networks learn faster during the early phase of training.

The `PytorchGELUTanh` class provides a fast C implementation of the tanh approximation of the GeLU activation function.
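For reference, the tanh approximation described in https://arxiv.org/abs/1606.08415 can be written out directly in a few lines of PyTorch. The sketch below is only illustrative (the helper name `gelu_tanh_reference` is made up for this example); it is not the fused kernel that PyTorch actually runs, but the two should agree closely.

```python
import math

import torch


def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GeLU:
    #   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))


x = torch.randn(5)
# The fused kernel used by PytorchGELUTanh should closely match this reference.
print(torch.allclose(gelu_tanh_reference(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6))
```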
## Class Definition
import torch
from packaging import version
from torch import Tensor, nn


class PytorchGELUTanh(nn.Module):
    """
    A fast C implementation of the tanh approximation of the GeLU activation function. See
    https://arxiv.org/abs/1606.08415.

    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
    match due to rounding errors.
    """

    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0"
                " is required to use PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.gelu(input, approximate="tanh")
## General Information
The `PytorchGELUTanh` class only requires PyTorch version 1.12.0 or higher.

The class contains the following methods:
| Method | Definition |
|---|---|
| `__init__` | Constructor. Initializes the superclass and checks that the installed PyTorch version supports the class; if it does not, an `ImportError` is raised. |
| `forward` | Applies the tanh approximation of the GeLU activation function to the provided input tensor. |
The `forward` method takes a tensor as its input argument and returns a tensor as output. The input and output tensors have the same shape.
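As a quick illustration of this shape-preserving behavior, here is a minimal sketch (assuming the `zeta.nn` import path used in the examples below):

```python
import torch

from zeta.nn import PytorchGELUTanh

act = PytorchGELUTanh()
x = torch.randn(2, 4, 8)  # an arbitrary input shape
y = act(x)                # the activation is applied elementwise
print(y.shape)            # torch.Size([2, 4, 8]) -- same shape as the input
```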
## Usage Examples

### Example 1: Basic Usage
In this basic example, we create an instance of the `PytorchGELUTanh` class and call it on a tensor to apply the tanh approximation of the GeLU function.
# Import the necessary libraries
import torch

from zeta.nn import PytorchGELUTanh

# Create an instance of the PytorchGELUTanh class.
gelutanh = PytorchGELUTanh()

# Create a random tensor.
x = torch.randn(3)

# Print the tensor before and after applying the GeLU Tanh activation function.
print("Before: ", x)
print("After: ", gelutanh(x))
### Example 2: Application to Deep Learning
The `PytorchGELUTanh` class can be used in place of traditional activation functions in deep learning models. Here is an example of its use in a simple feed-forward neural network.
# Import the necessary libraries
import torch
from torch import nn

from zeta.nn import PytorchGELUTanh


# Define a feed-forward neural network with 2 layers and the PytorchGELUTanh activation function
class FeedForwardNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # 10 input features, 20 output features
        self.gelu = PytorchGELUTanh()  # Our custom activation function
        self.fc2 = nn.Linear(20, 1)  # Final layer

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)  # Apply the PytorchGELUTanh activation
        x = self.fc2(x)
        return x


# Instantiate the model
model = FeedForwardNN()

# Print the model architecture
print(model)
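To check that the model runs end to end, you can push a random batch through it. This short sketch reuses the `model` instance defined above:

```python
# A batch of 4 samples, each with 10 input features (matching fc1).
batch = torch.randn(4, 10)
output = model(batch)
print(output.shape)  # torch.Size([4, 1])
```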
This completes the documentation for the `PytorchGELUTanh` class. For further details, feel free to consult the official PyTorch documentation, and make sure the PyTorch version you are using is compatible with this class.