Module/Function Name: StochDepth

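class StochDepth(stochdepth_rate):

Initializes the Stochastic Depth module, a torch.nn.Module that, during training, multiplies the input tensor by a random binary mask drawn once per sample. It is typically applied to the output of a residual branch, so that the entire branch is skipped for the dropped samples.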

Parameters:
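- stochdepth_rate (float): The probability of dropping the input of each sample in the batch (the sample's entire tensor is zeroed, not individual activations). Expected to be in [0, 1).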
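For each sample, the mask value is a Bernoulli draw built from a single uniform random number; this is what the torch.floor step in forward computes. Writing $p$ for stochdepth_rate:

$$
u \sim \mathcal{U}[0, 1), \qquad b = \lfloor u + (1 - p) \rfloor \in \{0, 1\}, \qquad \Pr[b = 1] = \Pr[u \ge p] = 1 - p.
$$

Because $\mathbb{E}[b\,x] = (1 - p)\,x$, the training-time output is smaller in expectation than the evaluation-time identity unless the kept samples are rescaled by $1/(1 - p)$, which is what "inverted" implementations such as torchvision's do.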

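```python
def forward(self, x):
    """
    Forward pass of the Stochastic Depth module. During training, zeroes the
    entire input of each sample with probability stochdepth_rate and scales
    the surviving samples by 1 / (1 - stochdepth_rate), so the expected output
    equals the input. In evaluation mode, returns the input unchanged.

    Args:
    - x (Tensor): The input tensor, with the batch dimension first.

    Returns:
    - Tensor: The output tensor after applying stochastic depth.

    Usage example:

        stoch_depth = StochDepth(stochdepth_rate=0.2)
        output = stoch_depth(torch.randn(8, 64, 32, 32))
    """
    # Identity at inference time, or when nothing is ever dropped
    if not self.training or self.stochdepth_rate == 0.0:
        return x

    # Calculating the keep (survival) probability
    keep_prob = 1.0 - self.stochdepth_rate

    # One uniform random value per sample, broadcastable over all other
    # dimensions (works for any input rank, not only 4-D feature maps)
    shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    rand_tensor = torch.rand(shape, dtype=x.dtype, device=x.device)

    # floor(U + keep_prob) is 1 with probability keep_prob and 0 otherwise
    binary_tensor = torch.floor(rand_tensor + keep_prob)

    # Rescale kept samples so E[output] == x, matching eval-mode behavior
    return x / keep_prob * binary_tensor
```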
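A minimal, self-contained sketch of the whole module follows, for readers who want something runnable. The source documents only forward, so the constructor, the attribute name self.stochdepth_rate, and the [0, 1) range check are assumptions; the sketch uses inverted scaling by 1 / keep_prob so that evaluation mode can be a plain identity. ResidualBlock is a hypothetical example showing where the module is usually placed: on the residual branch, never on the skip connection. If you would rather rely on a maintained library layer, torchvision provides a comparable one as torchvision.ops.StochasticDepth(p, mode="row").

```python
import torch
import torch.nn as nn


class StochDepth(nn.Module):
    """Per-sample stochastic depth (sketch; constructor details are assumed)."""

    def __init__(self, stochdepth_rate: float):
        super().__init__()
        # Assumed constraint: a rate of 1.0 would make keep_prob zero below
        if not 0.0 <= stochdepth_rate < 1.0:
            raise ValueError("stochdepth_rate must be in [0, 1)")
        self.stochdepth_rate = stochdepth_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.stochdepth_rate == 0.0:
            return x
        keep_prob = 1.0 - self.stochdepth_rate
        # One mask entry per sample, broadcast over the remaining dimensions
        shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        rand = torch.rand(shape, dtype=x.dtype, device=x.device)
        mask = torch.floor(rand + keep_prob)  # Bernoulli(keep_prob) in {0, 1}
        # Inverted scaling: E[output] == x, so eval mode returns x unchanged
        return x / keep_prob * mask


class ResidualBlock(nn.Module):
    """Hypothetical residual block whose branch is dropped per sample."""

    def __init__(self, channels: int, stochdepth_rate: float):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.stoch_depth = StochDepth(stochdepth_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection is always kept; only the branch output is masked
        return x + self.stoch_depth(self.branch(x))


if __name__ == "__main__":
    torch.manual_seed(0)
    block = ResidualBlock(channels=16, stochdepth_rate=0.2)
    x = torch.randn(8, 16, 32, 32)

    block.train()  # stochastic: some samples skip the branch
    print(block(x).shape)  # torch.Size([8, 16, 32, 32])

    block.eval()  # deterministic: StochDepth acts as the identity
    with torch.no_grad():
        print(block(x).shape)  # torch.Size([8, 16, 32, 32])
```

Drawing one mask value per sample (rather than per activation, as ordinary dropout does) is what makes this stochastic depth: a dropped sample sees the block as a pure identity mapping for that step.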