
Quantization {QAT} {PTQ} {Mixed-Precision} {Dynamic}

Description

The weights of an LLM are numeric values stored at a given precision, which can be expressed by the number of bits used to represent them, such as float64 or float32. If we lower the number of bits used to represent a value, we get a less accurate result, but we also lower the memory requirements of the model.

Info

Notice the lowered accuracy when we halve the number of bits.
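As a minimal sketch of this trade-off (using NumPy; the exact printed values depend on the library version), the same number loses precision, and shrinks in memory, each time the number of bits is halved:

import numpy as np

value = np.float64(3.14159265358979)  # 64-bit representation

for dtype in (np.float64, np.float32, np.float16):
    x = dtype(value)
    print(f"{np.dtype(dtype).name}: {x}  ({np.dtype(dtype).itemsize} bytes per value)")

# Output (approximately):
# float64: 3.14159265358979   (8 bytes per value)
# float32: 3.1415927          (4 bytes per value)
# float16: 3.14               (2 bytes per value)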

Varieties

Post-Training Quantization (PTQ) is the most straightforward form of quantization and is applied after a model has been fully trained. It doesn't require model retraining and works by converting the high-precision weights and activations into lower-precision formats, typically INT8.

PTQ is ideal for models where retraining is expensive or impractical, and it works best for tasks that are not overly sensitive to precision loss.

Info

Many PTQ methods require a calibration step on a representative dataset to determine optimal quantization parameters such as scaling factors and zero points, capture the activation distributions seen during inference, and minimize the error between the original and quantized model outputs.

This calibration process helps the quantization algorithm understand the numerical range and distribution of weights and activations across the network, allowing a more accurate mapping from higher-precision formats (such as FP32) to lower-precision formats (such as INT8 or INT4). The result preserves model accuracy while reducing the memory footprint and computational requirements for deployment.
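As a rough sketch of what calibration produces (a generic affine scheme, not any particular library's implementation), the code below derives a scale and zero point for unsigned 8-bit quantization from the value range observed on representative data:

import numpy as np

def calibrate_affine(observed, qmin=0, qmax=255):
    """Derive scale and zero point for uint8 affine quantization
    from values observed during calibration."""
    x_min, x_max = float(observed.min()), float(observed.max())
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # keep 0 exactly representable
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return q.astype(np.uint8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

# Calibration: run representative data and record the observed range
activations = np.random.randn(10_000).astype(np.float32) * 3.0
scale, zp = calibrate_affine(activations)

q = quantize(activations, scale, zp)
err = np.abs(dequantize(q, scale, zp) - activations).mean()
print(f"scale={scale:.5f}, zero_point={zp}, mean abs error={err:.5f}")

Including zero in the observed range keeps zero exactly representable, which matters for padding and ReLU outputs.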

Quantization-Aware Training (QAT) goes beyond simple PTQ by incorporating the effects of quantization into the training process itself. This allows the model to learn how to compensate for the quantization-induced noise.

During QAT, quantization of both weights and activations is simulated during the forward pass, while gradient calculations are kept at higher precision (sketched below). This method is particularly useful when the application requires high performance with aggressive quantization.

Info

QAT typically results in better model accuracy compared to PTQ, particularly for more complex or critical applications.
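A minimal sketch of the "fake quantization" idea this relies on (a simplified stand-in, not the torch.quantization implementation): quantization is applied and immediately undone in the forward pass, while the backward pass treats it as the identity (a straight-through estimator):

import torch

class FakeQuantize(torch.autograd.Function):
    """Simulate int8 quantization in the forward pass only."""

    @staticmethod
    def forward(ctx, x, scale, zero_point):
        q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127)
        return (q - zero_point) * scale  # dequantize: training math stays in float

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged
        return grad_output, None, None

x = torch.randn(4, requires_grad=True)
y = FakeQuantize.apply(x, 0.1, 0)  # illustrative fixed scale and zero point
y.sum().backward()
print(x.grad)  # all ones: gradients flow as if quantization were the identity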

Dynamic quantization calculates quantization parameters on the fly during inference based on the actual input values. This adapts better to varying data distributions but introduces some computational overhead compared to static approaches.
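Conceptually (a simplified per-tensor sketch, not how any specific library implements it), each incoming tensor gets its own scale at inference time:

import torch

def dynamic_quantize(x, qmax=127):
    # Symmetric int8, per tensor: the scale comes from the actual input values
    scale = x.abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale

# Each batch gets its own scale, so shifting input ranges are handled
for step in range(3):
    activations = torch.randn(2, 8) * (step + 1)
    q, scale = dynamic_quantize(activations)
    print(f"step {step}: scale = {scale.item():.4f}")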

Mixed-precision quantization is a more flexible approach that leverages multiple levels of numerical precision within a single model. For instance, less critical layers of the model can use INT8, while more sensitive layers remain in FP16 or FP32. This allows greater control over the trade-off between performance and precision.
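One way to express this per-layer choice in PyTorch's eager-mode API (a sketch with a made-up two-layer model; which layer counts as "sensitive" here is an assumption) is to give sensitive modules a qconfig of None so they stay in floating point while the rest is converted to INT8:

import torch
import torch.nn as nn
import torch.quantization as quant

model = nn.Sequential(
    quant.QuantStub(),
    nn.Linear(16, 16),   # less critical layer: quantize to INT8
    nn.ReLU(),
    quant.DeQuantStub(),
    nn.Linear(16, 4),    # sensitive layer: keep in FP32
)
model.eval()

model.qconfig = quant.default_qconfig
model[4].qconfig = None              # None means: leave this module unquantized

prepared = quant.prepare(model)
prepared(torch.randn(8, 16))         # calibration pass with representative data
mixed = quant.convert(prepared)
print(mixed)                         # first Linear is quantized, last stays float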

Example

# Post-Training Quantization (static PTQ)
import torch
import torch.quantization as quant

model = ...  # Load pre-trained model

# Convert model to quantization-ready state
model.eval()
model.qconfig = quant.default_qconfig

# Prepare for static quantization
model_prepared = quant.prepare(model)

# Calibration step: run representative data through the model
# (This example uses random data; replace with real samples)
for _ in range(100):
    sample_input = torch.randn(1, 784)
    model_prepared(sample_input)

# Apply quantization
model_quantized = quant.convert(model_prepared)

# Quantization-Aware Training (QAT)
import torch.quantization as quant

model = ...  # Load pre-trained model

# Set up QAT
model.train()
model.qconfig = quant.get_default_qat_qconfig("fbgemm")

# Prepare for QAT
model_prepared = quant.prepare_qat(model)

# Training loop with fake quantization active (num_epochs, train_one_epoch,
# validate, optimizer and the data loaders are assumed to be defined elsewhere)
for epoch in range(num_epochs):
    train_one_epoch(model_prepared, train_loader, optimizer)
    validate(model_prepared, val_loader)

# Convert to quantized version
model_quantized = quant.convert(model_prepared.eval())

# Dynamic quantization
import torch
import torch.quantization as quant

model = ...  # Pre-trained model
quantized_model = quant.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

Info

We use torch.quantization.quantize_dynamic to dynamically quantize the linear layers of a pre-trained model; the activations' quantization parameters are computed on the fly at inference time.

# Mixed precision (automatic mixed precision with autocast)
from torch.cuda.amp import autocast

model = ...  # Load a normal pre-trained FP32 model

with autocast():  # Use FP16 where possible, fall back to FP32 for sensitive computations
    output = model(input_data)

Info

The autocast context manager is designed to run a standard float32 (full-precision) model in mixed precision: eligible operations inside the context execute in FP16, while precision-sensitive operations fall back to FP32.

If the model is already stored in half precision (float16), there is nothing for autocast to do.

Vs

Strategy | Accuracy | Complexity | Performance | Resources
PTQ | Good for simple models; declines with complexity | Low; minimal setup | 75% storage reduction; 2-4x speedup | Low; minimal compute needed
QAT | Highest; best for sub-8-bit quantization | High; requires extended training | High compression with the best accuracy | High; intensive training needed
Dynamic | Good for RNNs; weak for CNNs | Medium; runtime overhead | Good memory savings; slower compute | Medium; runtime processing
Mixed-Precision | High; flexible precision options | Medium-high; layer-specific tuning | Hardware-dependent speedup | Medium-high during setup