PyTorch Hyper-Optimization
This is a quick-reference list of PyTorch features for optimizing models, covering torch.compile, TorchDynamo (torch._dynamo), and other modules and decorators. The table below gives a description, a typical use case, and a short Python example for each feature:
Feature | Description | Use Case | Python Example |
---|---|---|---|
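torch.compile | Captures PyTorch code into graphs with TorchDynamo and compiles them with a backend (TorchInductor by default) that fuses operations into optimized kernels. | Speeding up inference, and often training, with little code change by fusing operations and removing Python overhead; the first call pays a one-time compilation cost. | `@torch.compile` on `def model(x): return x + x`, or `model = torch.compile(model)` |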
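torch._dynamo (TorchDynamo) | The Python-level graph-capture engine behind torch.compile: it intercepts Python bytecode execution, extracts PyTorch operations into FX graphs, and passes them to a compiler backend. It does not produce TorchScript. | Normally used indirectly through torch.compile; direct use is mostly for debugging, e.g. investigating graph breaks or clearing compiled-code caches. | `import torch._dynamo; torch._dynamo.reset()` (the older `@torch._dynamo.optimize(backend)` decorator predates torch.compile) |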
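torch.fx | A toolkit for capturing and transforming PyTorch programs, made up of a symbolic tracer, a graph intermediate representation, and Python code generation. | Program capture, graph rewriting, and custom optimizations such as swapping or fusing operators; symbolic tracing cannot follow control flow that depends on tensor values. | `import torch.fx; graph_module = torch.fx.symbolic_trace(model)` where `model` is an `nn.Module` |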
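torch.jit | The TorchScript compiler, which converts a statically typed subset of Python and PyTorch code into TorchScript by scripting (`torch.jit.script`) or tracing (`torch.jit.trace`). | Producing serialized models that can be loaded and run without Python (for example from C++ via LibTorch), with some graph-level optimizations applied. | `@torch.jit.script` on `def fn(x, y): return x + y` |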
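torch.nn.utils.prune | Utilities for pruning parameters by masking entries to zero, either unstructured or structured (along a whole dimension, e.g. channels), chosen randomly or by magnitude. | Sparsifying models for compression; the masked weights are still stored densely, so real size or speed gains need sparse storage, structured removal, or sparsity-aware kernels. | `import torch.nn.utils.prune as prune; prune.random_unstructured(module, name="weight", amount=0.3)` |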
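torch.nn.utils.fusion | Folds an eval-mode BatchNorm layer's running statistics and affine parameters into the weights and bias of the preceding convolution, producing a single layer. | Removing BatchNorm from inference models, particularly CNNs; not valid during training, when BatchNorm uses batch statistics. | `from torch.nn.utils.fusion import fuse_conv_bn_eval; fused_conv = fuse_conv_bn_eval(conv, bn)` with `conv` and `bn` in `eval()` mode |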
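torch.utils.checkpoint | Activation (gradient) checkpointing: keeps only the inputs of a wrapped segment during the forward pass and recomputes its intermediate activations during the backward pass. | Reducing memory use when training large models by trading extra compute for memory. | `from torch.utils.checkpoint import checkpoint; output = checkpoint(block, input, use_reentrant=False)` |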
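torch.utils.bottleneck | A command-line tool that runs a script under the Python profiler (cProfile) and the autograd profiler and summarizes where time is spent. | A first diagnosis of slowdowns in a training or inference script. | Run from a shell rather than imported: `python -m torch.utils.bottleneck /path/to/source/script.py [args]` |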
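torch.utils.data.DataLoader | Wraps a dataset in an iterable that batches, shuffles, and optionally loads samples in parallel worker processes. | Efficient loading, batching, and shuffling of data for training and inference; `num_workers` and `pin_memory` help keep the GPU supplied with data. | `from torch.utils.data import DataLoader; dataloader = DataLoader(dataset, batch_size=32, shuffle=True)` |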
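Each of these features covers a different part of the workflow: graph capture and compilation (torch.compile, TorchDynamo, torch.fx, torch.jit), model slimming (pruning and Conv-BN fusion), training memory (checkpointing), profiling (torch.utils.bottleneck), and data loading (DataLoader). The examples are deliberately minimal. They show how each feature is invoked and assume that objects such as model, module, conv, bn, and dataset are already defined.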
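To see what TorchDynamo does, it can help to plug a custom backend into torch.compile. In this sketch, `recording_backend` and `fn` are hypothetical names. The backend signature (an `fx.GraphModule` plus example inputs, returning a callable) follows the custom-backend convention in the PyTorch compiler documentation. Dynamo hands each captured FX graph to the backend, which here only records the graph and runs it unchanged.

```python
import torch
import torch._dynamo
import torch.fx

captured_graphs: list[torch.fx.GraphModule] = []


def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo passes each captured FX graph (plus example inputs) to the backend.
    captured_graphs.append(gm)
    # Returning gm.forward runs the captured graph as-is, with no extra optimization.
    return gm.forward


def fn(x: torch.Tensor) -> torch.Tensor:
    return (torch.sin(x) + torch.cos(x)) * 2


torch._dynamo.reset()  # clear any code Dynamo has already compiled and cached
compiled_fn = torch.compile(fn, backend=recording_backend)

x = torch.randn(10)
out = compiled_fn(x)

print(f"graphs captured: {len(captured_graphs)}")
print(captured_graphs[0].code)  # Python source regenerated from the FX graph
```

Suppose `fn` branched on a tensor value, for example `if x.sum() > 0:`. Dynamo would then typically split the function at a graph break, and the backend would receive more than one graph.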
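The torch.fx row can be made concrete by tracing a small module and applying a graph rewrite. The `ConvReLU` module and the ReLU-to-GELU swap are hypothetical, chosen only to show the capture, edit, and recompile cycle. The node-editing pattern itself follows the FX documentation: match on `node.op` and `node.target`, then call `lint()` and `recompile()`.

```python
import torch
import torch.fx


class ConvReLU(torch.nn.Module):
    """Hypothetical module used only to have something to trace."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


model = ConvReLU().eval()
graph_module = torch.fx.symbolic_trace(model)
print(graph_module.graph)  # IR nodes: placeholder, call_module, call_function, output

# Transformation: retarget every torch.relu call to torch.nn.functional.gelu.
for node in graph_module.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = torch.nn.functional.gelu

graph_module.graph.lint()  # sanity-check the edited graph
graph_module.recompile()   # regenerate graph_module.code and forward from the graph
print(graph_module.code)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    out = graph_module(x)
```

`symbolic_trace` runs `forward` with proxy objects instead of real tensors. A `forward` that branches on tensor values therefore cannot be traced this way.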
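For torch.jit, a slightly larger sketch shows what scripting adds over a plain Python function. The loop is compiled as real control flow, and the result can be saved and reloaded without the original Python class. The `RepeatedScale` module is hypothetical, and `io.BytesIO` stands in for a model file on disk.

```python
import io

import torch


class RepeatedScale(torch.nn.Module):
    """Hypothetical module: returns x * 1 + x * 2 + ... + x * n using a Python loop."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.n = n

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x)
        # Scripting (unlike tracing) keeps this loop as control flow in the compiled code.
        for i in range(self.n):
            out = out + x * (i + 1)
        return out


scripted = torch.jit.script(RepeatedScale(4))
print(scripted.code)  # TorchScript source for forward

# Serialize to an in-memory buffer and load it back, as a deployment would with a file.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

result = loaded(torch.ones(3))  # 1 + 2 + 3 + 4 = 10 for every element
```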
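The pruning row can be checked end to end on a hypothetical `Linear(100, 50)` layer. This sketch shows the reparametrization that torch.nn.utils.prune installs. It also confirms that 30% of the weights end up zero while the weight tensor keeps its full dense shape.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(100, 50)  # hypothetical layer; its weight has 5000 entries

# Zero out 30% of the weight entries, chosen at random.
prune.random_unstructured(layer, name="weight", amount=0.3)

# Pruning reparametrizes the layer: `weight_orig` is the trainable parameter,
# `weight_mask` is a buffer, and `weight` is their elementwise product.
param_names = {name for name, _ in layer.named_parameters()}
buffer_names = {name for name, _ in layer.named_buffers()}
sparsity = (layer.weight == 0).float().mean().item()
print(param_names, buffer_names, f"sparsity={sparsity:.3f}", tuple(layer.weight.shape))

# Make the pruning permanent: drop the mask and keep the zeroed weight as a parameter.
prune.remove(layer, "weight")
final_param_names = {name for name, _ in layer.named_parameters()}
print(final_param_names)
```

The zeros are still stored in a dense (50, 100) tensor. On its own, this does not shrink the saved model or speed up the matrix multiply.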
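For Conv-BN fusion, a sketch can confirm numerically that the folded convolution reproduces conv followed by bn in eval mode. The layer sizes, and the warm-up loop that gives BatchNorm non-trivial running statistics, are illustrative assumptions.

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
bn = torch.nn.BatchNorm2d(16)

# A few training-mode batches update BatchNorm's running mean and variance.
with torch.no_grad():
    for _ in range(5):
        bn(conv(torch.randn(8, 3, 32, 32)))

# Fusion is only valid in eval mode (fuse_conv_bn_eval asserts this).
conv.eval()
bn.eval()
fused_conv = fuse_conv_bn_eval(conv, bn)  # returns a new Conv2d; conv and bn are not modified

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    reference = bn(conv(x))
    fused_out = fused_conv(x)

print(type(fused_conv).__name__, (reference - fused_out).abs().max().item())
```

The fold is a per-channel rescale. Let μ and σ² be the running mean and variance, γ and β the BatchNorm scale and shift, and ε its epsilon. The fused weight is then W·γ/√(σ² + ε), and the fused bias is (b − μ)·γ/√(σ² + ε) + β.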
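Here is a sketch of the checkpointing row. Each block of a hypothetical `DeepMLP` is wrapped in `checkpoint`, and a gradient comparison confirms that recomputation changes memory use, not results. The sketch passes `use_reentrant=False` explicitly, following current PyTorch guidance (recent versions warn when it is omitted). Unlike the reentrant variant, this mode also works when the input tensor itself does not require gradients, as is the case here.

```python
import torch
from torch.utils.checkpoint import checkpoint


class DeepMLP(torch.nn.Module):
    """Hypothetical stack of Linear+ReLU blocks; each block is one checkpointed segment."""

    def __init__(self, depth: int = 4, width: int = 256) -> None:
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.ReLU()) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor, use_checkpoint: bool = True) -> torch.Tensor:
        for block in self.blocks:
            if use_checkpoint and self.training:
                # Activations inside `block` are discarded and recomputed during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
model = DeepMLP()
x = torch.randn(32, 256)

# Same loss with and without checkpointing; the gradients should match.
model(x, use_checkpoint=True).sum().backward()
grads_ckpt = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

model(x, use_checkpoint=False).sum().backward()
grads_plain = [p.grad.clone() for p in model.parameters()]
```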
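The DataLoader row assumes an existing dataset. This self-contained sketch uses a hypothetical `SquaresDataset` to show what the loader yields. It sets `num_workers=0` so that it runs anywhere. Raising that value moves loading into worker processes. On platforms that spawn workers (such as Windows), the loading code must then sit under an `if __name__ == "__main__":` guard.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: item i is (i, i**2) as 1-element float tensors."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int):
        x = torch.tensor([float(idx)])
        return x, x ** 2


loader = DataLoader(
    SquaresDataset(100),
    batch_size=32,
    shuffle=True,                          # new random order every epoch
    num_workers=0,                         # >0 loads batches in worker processes
    pin_memory=torch.cuda.is_available(),  # page-locked memory speeds host-to-GPU copies
)

batch_sizes = []
for xb, yb in loader:
    # Default collation stacks per-item tensors, so xb and yb have shape (batch, 1).
    assert torch.equal(yb, xb ** 2)
    batch_sizes.append(xb.shape[0])

print(batch_sizes)  # → [32, 32, 32, 4]
```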