
BitLinear Module Documentation
==============================

Overview


The BitLinear module is a custom linear layer for neural networks that adds bit quantization. It is designed to work with PyTorch's nn.Module and can be integrated into any PyTorch model architecture.

The BitLinear module applies a linear transformation to the input data, followed by quantization and dequantization. Quantization is handled by the absmax_quantize function, which quantizes its input tensor relative to that tensor's absolute maximum value.

absmax_quantize Function


The absmax_quantize function is a helper function used by the BitLinear module to perform quantization and dequantization of the input tensor.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| x | torch.Tensor | The input tensor to be quantized. |
| bits | int (optional) | The number of bits to use for quantization. Default is 8. |

Returns

| Return Value | Type | Description |
| --- | --- | --- |
| quant | torch.Tensor | The quantized tensor. |
| dequant | torch.Tensor | The dequantized tensor. |
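
To make the arithmetic concrete, here is a minimal sketch of absmax quantization. It is not the library's code: the real absmax_quantize operates on torch.Tensor, while this version uses plain Python lists so the steps are visible without any dependencies. The name absmax_quantize_sketch, the signed-integer range and the all-zero guard are assumptions made for illustration.

```python
def absmax_quantize_sketch(values, bits=8):
    """Illustrative absmax quantization over a flat list of floats.

    Hypothetical stand-in for absmax_quantize; the library version works on
    torch.Tensor and may differ in detail.
    """
    # Largest magnitude of a signed integer of this width, e.g. 127 for 8 bits.
    q_max = 2 ** (bits - 1) - 1
    # Absolute maximum of the input; fall back to 1.0 for an all-zero input
    # (this guard is an assumption, not taken from the library).
    abs_max = max(abs(v) for v in values) or 1.0
    scale = q_max / abs_max
    # Quantize: scale so that abs_max maps to q_max, then round to integers.
    quant = [round(v * scale) for v in values]
    # Dequantize: map the integers back to the original range.
    dequant = [q / scale for q in quant]
    return quant, dequant


quant, dequant = absmax_quantize_sketch([0.4, -1.0, 0.25])
print(quant)    # [51, -127, 32]
print(dequant)  # close to the inputs, within half a quantization step
```

The element with the largest magnitude always maps to the edge of the integer range, and every dequantized value differs from its input by at most half a quantization step (0.5 / scale).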

BitLinear Class

The BitLinear class is a custom implementation of a linear layer that performs bit quantization on the input data.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| in_features | int | The number of input features. |
| out_features | int | The number of output features. |
| groups | int (optional) | The number of groups for group normalization. Default is 1. |

Methods

__init__(self, in_features, out_features, groups=1)

The constructor for the BitLinear class. Initializes the weight parameter and resets it.

reset_parameters(self)

Resets the weight parameter using the Kaiming uniform initialization method.
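
For readers unfamiliar with Kaiming uniform initialization, the sketch below shows the sampling rule in dependency-free Python: weights are drawn uniformly from [-bound, bound], where bound = gain * sqrt(3 / fan_in) and gain = sqrt(2 / (1 + a^2)). The helper name kaiming_uniform_sketch, the (out_features, in_features) weight layout and the value of the slope parameter `a` are assumptions here. The source does not state which values BitLinear uses.

```python
import math
import random


def kaiming_uniform_sketch(out_features, in_features, a=0.0, seed=0):
    """Draw an (out_features x in_features) weight matrix as nested lists.

    Hypothetical helper for illustration only; `a` is the negative slope
    used in the gain, and its value inside BitLinear is an assumption.
    """
    rng = random.Random(seed)
    gain = math.sqrt(2.0 / (1.0 + a * a))
    # fan_in mode: the bound shrinks as the number of input features grows.
    bound = gain * math.sqrt(3.0 / in_features)
    weight = [
        [rng.uniform(-bound, bound) for _ in range(in_features)]
        for _ in range(out_features)
    ]
    return weight, bound


weight, bound = kaiming_uniform_sketch(out_features=20, in_features=10)
print(len(weight), len(weight[0]))  # 20 10
print(round(bound, 4))              # 0.7746
```

In PyTorch itself the equivalent call is torch.nn.init.kaiming_uniform_, which fills a tensor in place using the same bound.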

forward(self, input)

Performs the forward pass of the BitLinear module.
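
The overview describes the forward pass as a linear transformation followed by quantization and dequantization. A rough sketch of that order of operations, written with nested Python lists rather than tensors, might look like the following. All function names are hypothetical. How the groups argument enters the computation is not described in this document, so it is left out. Applying the absmax round trip to the whole output at once is also an assumption.

```python
def linear(x, weight):
    """Compute x @ weight^T for a batch of rows stored as nested lists."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight]
        for row in x
    ]


def quantize_dequantize(rows, bits=8):
    """Absmax round trip over every element of a 2-D nested list."""
    q_max = 2 ** (bits - 1) - 1
    abs_max = max(abs(v) for row in rows for v in row) or 1.0
    scale = q_max / abs_max
    return [[round(v * scale) / scale for v in row] for row in rows]


def bitlinear_forward_sketch(x, weight, bits=8):
    # Step 1: linear transformation. Step 2: quantize, then dequantize.
    return quantize_dequantize(linear(x, weight), bits)


x = [[1.0, 2.0], [0.5, -1.0]]                    # 2 samples, 2 input features
weight = [[0.5, 0.25], [-1.0, 1.0], [0.0, 2.0]]  # 3 output features
out = bitlinear_forward_sketch(x, weight)
print(len(out), len(out[0]))  # 2 3
```

The output keeps the shape of an ordinary linear layer, (batch, out_features), but each value has been snapped to one of the levels representable at the chosen bit width.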

Usage Examples

Example 1: Basic Usage

import torch

from zeta.quant import BitLinear

# Initialize the BitLinear module
linear = BitLinear(10, 20)

# Create a random tensor of size (128, 10)
input = torch.randn(128, 10)

# Perform the forward pass
output = linear(input)

# Print the size of the output
print(output.size())  # torch.Size([128, 20])

Example 2: Using Different Number of Groups

import torch

from zeta.quant import BitLinear

# Initialize the BitLinear module with 2 groups
linear = BitLinear(10, 20, groups=2)

# Create a random tensor of size (128, 10)
input = torch.randn(128, 10)

# Perform the forward pass
output = linear(input)

# Print the size of the output
print(output.size())  # torch.Size([128, 20])

Example 3: Integrating with a PyTorch Model

import torch
from torch import nn

from zeta.quant import BitLinear


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = BitLinear(10, 20)

    def forward(self, x):
        return self.linear(x)


# Initialize the model
model = MyModel()

# Create a random tensor of size (128, 10)
input = torch.randn(128, 10)

# Perform the forward pass
output = model(input)

# Print the size of the output
print(output.size())  # torch.Size([128, 20])

Conclusion


The BitLinear module combines a linear transformation with absmax quantization in a single layer. This can be particularly useful in scenarios where memory efficiency is crucial. Like any other PyTorch module, it can be integrated into a larger model architecture.