
PyTorch Hyper-Optimization

This is a quick-reference list of PyTorch features for optimizing models and training pipelines, including torch.compile, TorchDynamo (torch._dynamo), and related modules and decorators. Each entry gives a description, a typical use case, and a short example.

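torch.compile
Description: Compiles a PyTorch function or nn.Module into optimized code. By default, TorchDynamo captures the computation graph and the TorchInductor backend generates fused kernels, which reduces Python overhead and memory traffic.
Use case: Speeding up inference, and often training, with a one-line change. The first call is slower because it triggers compilation.
Example:
```python
import torch

@torch.compile
def model(x):
    return x + x
```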
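
torch._dynamo (TorchDynamo)
Description: The graph-capture front end used by torch.compile. It hooks into CPython's frame evaluation API (PEP 523) to extract PyTorch operations into FX graphs, falls back to regular Python at unsupported code (a "graph break"), and hands each graph to a backend compiler such as TorchInductor. It does not produce TorchScript.
Use case: Choosing or debugging compiler backends. It is a private module (note the underscore) that can change between releases, so torch.compile is the supported entry point for everyday use.
Example:
```python
import torch
import torch._dynamo

@torch._dynamo.optimize("inductor")  # roughly what @torch.compile does by default
def model(x):
    return x.mm(x)
```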
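
To see what TorchDynamo actually captures, torch.compile accepts a callable as its backend. The sketch below follows the custom-backend pattern from the PyTorch compiler documentation: the backend, named recording_backend here purely for illustration, receives each captured FX GraphModule and returns its forward method unchanged, so the function runs exactly as captured with no further optimization.

```python
import torch
import torch.fx

captured_graphs = []  # FX GraphModules handed to the backend by TorchDynamo

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: TorchDynamo calls it once per captured graph.
    captured_graphs.append(gm)
    # Returning gm.forward runs the captured graph as-is (no extra optimization).
    return gm.forward

@torch.compile(backend=recording_backend)
def fn(x):
    y = x.mm(x)
    return torch.relu(y) + 1

x = torch.randn(4, 4)
out = fn(x)
expected = torch.relu(x.mm(x)) + 1

# Inspect the captured graph node by node.
for node in captured_graphs[0].graph.nodes:
    print(node.op, node.target)
```

Adding Python control flow that depends on a tensor's value (for example `if x.sum() > 0:`) generally causes a graph break, in which case the backend is called once per captured segment.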
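
torch.fx
Description: A toolkit for capturing and transforming PyTorch programs. symbolic_trace records a module's forward pass as a Graph intermediate representation wrapped in a GraphModule, which can be inspected, edited, and turned back into Python code.
Use case: Program capture and graph-level transformations such as operator replacement, fusion passes, or preparing models for quantization.
Example:
```python
import torch
import torch.fx

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

graph_module = torch.fx.symbolic_trace(Net())  # a GraphModule; graph_module.graph holds the IR
```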
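
Beyond capture, the point of torch.fx is editing the graph. This sketch (the AddRelu module is a hypothetical example) traces a module, replaces every addition node with a multiplication, and regenerates the Python code.

```python
import operator

import torch
import torch.fx

class AddRelu(torch.nn.Module):
    # Hypothetical module used only to demonstrate a graph rewrite.
    def forward(self, x, y):
        return torch.relu(x + y)

# symbolic_trace records `x + y` as a call_function node targeting operator.add.
gm = torch.fx.symbolic_trace(AddRelu())

# Transformation: swap every addition for a multiplication.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == operator.add:
        node.target = operator.mul

gm.graph.lint()  # sanity-check the edited graph
gm.recompile()   # regenerate gm.forward from the edited graph

x = torch.tensor([1.0, -2.0, 3.0])
y = torch.tensor([2.0, 2.0, -1.0])
result = gm(x, y)  # now computes relu(x * y)
print(gm.code)
```

gm.recompile() is what makes the edit take effect, since a GraphModule's forward is generated code rather than a live view of the graph.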
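
torch.jit
Description: The TorchScript compiler. torch.jit.script compiles a statically typed subset of Python, including control flow; torch.jit.trace instead records the operations executed for an example input.
Use case: Serializing models so they can run without a Python interpreter, for example from C++ through LibTorch; some graph-level optimizations also apply.
Example:
```python
import torch

@torch.jit.script
def fn(x, y):
    # Unannotated arguments are assumed to be Tensors.
    return x + y
```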
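
The main reason to prefer torch.jit.script over torch.jit.trace is data-dependent control flow. This sketch, similar to the decision-gate example in PyTorch's Introduction to TorchScript tutorial (the module name here is illustrative), scripts a module with a branch, serializes it to an in-memory buffer, and loads it back.

```python
import io

import torch

class DecisionGate(torch.nn.Module):
    # Illustrative module: scripting keeps this branch, while tracing would record one path only.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x
        else:
            return -x

scripted = torch.jit.script(DecisionGate())
print(scripted.code)  # TorchScript source, including the if/else

# Serialize to an in-memory buffer (a file path works the same way) and load it back.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

pos = loaded(torch.tensor([1.0, 2.0]))    # sum > 0: returned unchanged
neg = loaded(torch.tensor([-1.0, -2.0]))  # sum <= 0: negated
```

A file saved this way can also be loaded from C++ with torch::jit::load in LibTorch.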
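
torch.nn.utils.prune
Description: Utilities for pruning parameters by masking individual entries (unstructured, e.g. random_unstructured, l1_unstructured) or entire channels along a dimension (structured, e.g. ln_structured).
Use case: Producing sparse models for compression or efficiency work. Pruned weights are zeroed through a mask but still stored densely, so actual size or speed gains need further steps such as sparse storage or physically removing pruned channels.
Example:
```python
import torch
import torch.nn.utils.prune as prune

module = torch.nn.Linear(10, 5)
# Randomly zero out 30% of the weight entries via a mask.
prune.random_unstructured(module, name='weight', amount=0.3)
```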
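
The sketch below uses l1_unstructured (magnitude-based rather than random pruning) on a Linear layer with 200 weights, shows that pruning is implemented as a weight_orig/weight_mask reparametrization, and then calls prune.remove to make it permanent. The layer size and the 30% amount are arbitrary.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(20, 10)  # 20 * 10 = 200 weights

# Zero out the 30% of weights with the smallest absolute value (60 of 200).
prune.l1_unstructured(layer, name="weight", amount=0.3)

# Pruning is a reparametrization: weight = weight_orig * weight_mask.
has_mask = hasattr(layer, "weight_mask")
zeros = int((layer.weight == 0).sum())
sparsity = zeros / layer.weight.numel()

# Make pruning permanent: drop weight_orig and weight_mask, keep the zeroed tensor.
prune.remove(layer, "weight")
param_names = [name for name, _ in layer.named_parameters()]
print(f"sparsity={sparsity:.2f}", param_names)
```

After prune.remove the zeros remain in an ordinary dense tensor, which is why pruning alone does not shrink a saved model or speed up dense matrix multiplication.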
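
torch.nn.utils.fusion
Description: Helpers that fold a BatchNorm layer into the preceding layer; fuse_conv_bn_eval merges a BatchNorm into the preceding convolution, producing a single convolution with adjusted weights and bias.
Use case: Removing the BatchNorm computation at inference time in CNNs. Both modules must be in eval mode, because the fusion uses BatchNorm's running statistics.
Example:
```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = torch.nn.Conv2d(3, 16, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(16).eval()
fused_module = fuse_conv_bn_eval(conv, bn)  # a single Conv2d
```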
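
To check that fusion preserves the output, the sketch below first gives BatchNorm non-trivial running statistics, then compares the fused convolution with conv followed by bn in eval mode. Shapes and the number of warm-up batches are arbitrary.

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = torch.nn.BatchNorm2d(8)

# Update BatchNorm's running mean/var with a few batches in training mode.
bn.train()
with torch.no_grad():
    for _ in range(3):
        bn(conv(torch.randn(4, 3, 16, 16)))

# Fusion is only valid for inference, so both modules must be in eval mode.
conv.eval()
bn.eval()
fused = fuse_conv_bn_eval(conv, bn)  # returns a new Conv2d; conv and bn are left unchanged

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    fused_out = fused(x)

max_diff = (reference - fused_out).abs().max().item()
print(type(fused).__name__, max_diff)
```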
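
torch.utils.checkpoint
Description: Activation (gradient) checkpointing. Intermediate activations inside the wrapped segment are not stored during the forward pass and are recomputed during backward.
Use case: Reducing activation memory when training large models, at the cost of extra compute.
Example:
```python
import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
x = torch.randn(8, 64, requires_grad=True)
# Pass use_reentrant explicitly; the PyTorch docs recommend use_reentrant=False.
output = checkpoint(model, x, use_reentrant=False)
```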
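
In practice checkpoint is usually called inside a model's forward on individual blocks. The sketch below (CheckpointedMLP is a hypothetical model; width and depth are arbitrary) checkpoints each block during training and verifies that the gradients match an ordinary forward and backward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedMLP(torch.nn.Module):
    # Hypothetical model: a stack of Linear+ReLU blocks.
    def __init__(self, width=64, depth=4, use_checkpoint=True):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.ReLU())
            for _ in range(depth)
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `block` are not kept; they are recomputed in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
model = CheckpointedMLP()
x = torch.randn(8, 64)

# Gradients with checkpointing...
model.use_checkpoint = True
model(x).sum().backward()
grads_ckpt = [p.grad.clone() for p in model.parameters()]

# ...match gradients without it; only peak memory and compute differ.
model.zero_grad()
model.use_checkpoint = False
model(x).sum().backward()
grads_plain = [p.grad.clone() for p in model.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(grads_ckpt, grads_plain))
print(same)
```

The non-reentrant variant also handles blocks whose input does not require grad, as with the first block here; with the reentrant variant, that block's parameters would receive no gradients.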
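
torch.utils.bottleneck
Description: A command-line tool that runs a script under both the Python profiler and PyTorch's autograd profiler and prints a summary of each.
Use case: A first step in finding where a training or inference script spends its time. It is invoked on a whole script rather than called on a model from Python.
Example (command line):
```shell
python -m torch.utils.bottleneck /path/to/source/script.py [args]
```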
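
torch.utils.bottleneck has to be run on a whole script. To profile a specific region from inside Python, the newer torch.profiler API can be used directly; this sketch (the model and the "train_step" label are illustrative) wraps a few training steps and prints operators sorted by total CPU time.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Illustrative model and batch; sizes are arbitrary.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10)
)
x = torch.randn(64, 256)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        # record_function adds a named range that appears in the results.
        with record_function("train_step"):
            model(x).sum().backward()

# Aggregate events by name and show the most expensive operators.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)

event_names = {evt.key for evt in prof.key_averages()}
```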
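
torch.utils.data.DataLoader
Description: Wraps a Dataset in an iterable that yields batches, with optional shuffling, custom sampling, and collation.
Use case: Loading, batching, and shuffling data for training and inference. num_workers loads batches in parallel worker processes, and pin_memory=True speeds up host-to-GPU copies.
Example:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```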
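
DataLoader works with any map-style Dataset that implements __len__ and __getitem__. The sketch below uses a hypothetical SquaresDataset and shows the performance-related arguments. num_workers is left at 0 so it runs anywhere, because worker processes on platforms that use the spawn start method also require an `if __name__ == "__main__":` guard.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2) as 1-element float tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor([float(idx)])
        return x, x ** 2

dataset = SquaresDataset(100)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    drop_last=False,  # keep the final partial batch
    num_workers=0,    # > 0 loads batches in parallel worker processes
    pin_memory=torch.cuda.is_available(),  # page-locked memory for faster GPU copies
)

batch_sizes = [xb.shape[0] for xb, _ in loader]
xb, yb = next(iter(loader))  # default collation stacks samples into (32, 1) tensors
print(batch_sizes, xb.shape, yb.shape)
```

Because drop_last=False, 100 samples with batch_size=32 give batches of 32, 32, 32, and 4.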

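Each of these features targets a different part of the workflow: compilation (torch.compile, TorchDynamo, TorchScript), program transformation (torch.fx), model compression (pruning and Conv-BatchNorm fusion), memory (checkpointing), profiling (bottleneck), and data loading (DataLoader). The examples are deliberately minimal and show the basic call pattern rather than a complete workflow.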