Skip to content

Iterative Pruning [Pruning] [Structured & Unstructured]

Description

Iterative pruning allows you to prune a small fraction of weights at a time over multiple training steps. This method reduces the risk of drastic performance drops and provides more opportunities for the model to recover and adjust to the pruning.

The iterative approach also allows for fine-tuning after each pruning step, enabling the model to "heal" from the weight reduction.

Info

The gradual removal of weights ensures that the model has enough time to adjust between each pruning step.

Example

Iteratively prune 10% of the model after every 10 epochs:

for epoch in range(1, num_epochs+1):
    # Regular training step here

    if epoch % 10 == 0:
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                # prune.l1_unstructured(module, name="weight", amount=0.1)  # Unstructured
                prune.ln_structured(module, name="weight", amount=0.1, n=2, dim=0)  # Structured
                prune.remove(module, "weight")  # Remove pruning mask after each step

    # Regular validation step here