Module Name: StochasticSkipBlock

Overview and Introduction:

Tabular deep learning models sometimes struggle with overfitting on noisy data. StochasticSkipBlock is a PyTorch module designed to combat this problem by introducing stochasticity between network layers. It skips the wrapped layers during training with a defined probability, so that each forward pass effectively samples a thinner sub-network.

Given a set of layers encapsulated in a module, StochasticSkipBlock either applies that module to the input or returns the input unchanged, bypassing the module entirely. The decision to apply or skip is made at random with a user-defined probability. This stochasticity acts as a regularizer that helps prevent overfitting on the training data, and it can also contribute to faster convergence during training and better generalization at prediction time.

Class Definition:

Below is the class definition for the module:

import torch
import torch.nn as nn


class StochasticSkipBlock(nn.Module):
    """
    A module that implements stochastic skip connections in a neural network.

    Args:
        sb1 (nn.Module): The module to be skipped with a certain probability.
        p (float): The probability of skipping the module. Default is 0.5.

    Returns:
        torch.Tensor: The output tensor after applying the stochastic skip connection.
    """

    def __init__(self, sb1, p=0.5):
        super().__init__()
        self.sb1 = sb1
        self.p = p

    def forward(self, x: torch.Tensor):
        """
        Forward pass of the StochasticSkipBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the module.
        """
        if self.training and torch.rand(1).item() < self.p:
            return x  # Skip the sb1
        else:
            return self.sb1(x)

Parameters

Argument    Default     Description
sb1         (required)  The layers, encapsulated in an nn.Module (e.g. an nn.Sequential), to be skipped with a certain probability.
p           0.5         The probability of skipping the module during training.
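
The behaviour is easy to check with a short, self-contained sketch (using the zeta.nn import shown in the use cases below). In training mode the wrapped module is bypassed roughly a fraction p of the time; in evaluation mode it is always applied:

import torch
import torch.nn as nn

from zeta.nn import StochasticSkipBlock

block = StochasticSkipBlock(nn.Linear(8, 8), p=0.5)
x = torch.randn(1, 8)

# In training mode the wrapped Linear layer is skipped with probability p,
# in which case the input tensor is returned unchanged.
block.train()
skips = sum(torch.equal(block(x), x) for _ in range(1000))
print("skipped %d / 1000 forward passes" % skips)  # roughly 500 with p=0.5

# In evaluation mode the wrapped module is always applied.
block.eval()
print(torch.equal(block(x), x))  # False: the Linear layer runs every time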

Use Cases

Use Case 1: Basic Usage

This is a basic example of using StochasticSkipBlock in a feed-forward neural network.

First, import the necessary modules:

import torch
import torch.nn as nn
from torch.nn.functional import relu

from zeta.nn import StochasticSkipBlock

Now, you need to define the architecture of the model:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = StochasticSkipBlock(
            nn.Sequential(nn.Linear(20, 20), nn.ReLU()), p=0.5
        )  # 50% chance to skip this block of layers during training
        self.layer3 = nn.Linear(20, 1)

    def forward(self, x):
        x = relu(self.layer1(x))
        x = self.layer2(x)
        x = self.layer3(x)
        return x

Now, you can instantiate your model:

model = MyModel()
inputs = torch.randn(32, 10)
output = model(inputs)
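
At inference time you would typically switch to evaluation mode; because the skip only happens while self.training is True, layer2 is then always applied and the output is deterministic with respect to the skip behaviour. A minimal continuation of the example above:

model.eval()  # disables the stochastic skipping; layer2 is always applied
with torch.no_grad():
    eval_output = model(inputs)
print(eval_output.shape)  # torch.Size([32, 1])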

Use Case 2: Convolutional Neural Network

This example shows how to embed a StochasticSkipBlock between the convolutional layers of a CNN. Note that the wrapped module must preserve the shape of its input, because the input is passed through unchanged whenever the block is skipped; the skippable convolution below therefore keeps the channel count and spatial size fixed.

import torch.nn.functional as F


class MyCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        # The skippable block preserves the input shape (32 channels, padding=1),
        # so the rest of the network works whether or not it is skipped.
        self.conv2 = StochasticSkipBlock(
            nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU()),
            p=0.6,
        )
        self.fc1 = nn.Linear(32 * 14 * 14, 120)  # sized for 3x32x32 inputs
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(self.conv2(x), 2)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
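
A quick smoke test of the model, assuming 3x32x32 inputs (which is what the fc1 dimension above is sized for):

cnn = MyCNNModel()
images = torch.randn(8, 3, 32, 32)  # hypothetical batch of 8 RGB 32x32 images

cnn.train()
print(cnn(images).shape)  # torch.Size([8, 10]); the conv2 block is skipped ~60% of the time

cnn.eval()
print(cnn(images).shape)  # torch.Size([8, 10]); the conv2 block is always applied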

Use Case 3: Training the model using DataLoader

This example shows how to train a model that uses the StochasticSkipBlock module. It assumes you have a DataLoader named train_dataloader that yields (inputs, labels) batches, and that the model's outputs are class logits suitable for nn.CrossEntropyLoss.

import torch.nn as nn
import torch.optim as optim

# instantiate your model
model = MyModel()

# define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# make sure the model is in training mode so the stochastic skipping is active
model.train()

for epoch in range(50):  # loop over the dataset
    running_loss = 0.0
    for i, data in enumerate(train_dataloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print("Epoch %d loss: %.3f" % (epoch + 1, running_loss))

print("Finished Training")

Additional Tips

To get the most out of StochasticSkipBlock, tune the skipping probability p. A higher value means the wrapped layers are skipped more often during training. Experiment with different values of p to find the one that gives your model the best results.

The StochasticSkipBlock module introduces randomness in your model's training process; therefore, results might vary slightly each time you train your model. Consider setting a seed for your PyTorch application to ensure reproducibility.
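
For example, seeding PyTorch's global random number generator at the start of your script makes the skip decisions (which use torch.rand) repeatable across runs:

import torch

torch.manual_seed(42)  # any fixed seed makes the skip decisions reproducible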

Conclusion

StochasticSkipBlock is a flexible module that makes it easy to introduce stochasticity into your model's architecture, acting as a regularizer that could improve your model's performance. It's important to experiment with this module to see how much randomness helps your specific use case.
