Module/Function Name: StochDepth¶
class torch.nn.StochDepth(stochdepth_rate):
Initializes the Stochastic Depth module that applies a stochastic binary mask to the input tensor.
Parameters:
- stochdepth_rate (float): The probability of dropping each input activation.
def forward(x):
"""
Forward pass of the Stochastic Depth module. Applies a stochastic rate of dropout to the input tensor.
Args:
- x (Tensor): The input tensor.
Returns:
- Tensor: The output tensor after applying stochastic depth.
```
if not self.training:
return x
batch_size = x.shape[0]
# Generating random tensor
rand_tensor = torch.rand(
batch_size,
1,
1,
1
).type_as(x)
# Calculating the keep probability
keep_prob = 1 - self.stochdepth_rate
# Construct binary tensor using torch floor function
binary_tensor = torch.floor(rand_tensor + keep_prob)
return x * binary_tensor
```
# Usage example:
stoch_depth = nn.StochDepth(stochdepth_rate=0.2)
output = stoch_depth(input)
"""
```