Stochastic Weight Averaging (SWA)

Description

SWA is a technique that improves neural network generalization by averaging the weights from multiple points along the optimization trajectory. The averaged solution tends to sit in a flatter, more robust region of the loss surface, and therefore performs better on unseen data than the sharper minima typically found by conventional optimization.

In practice, SWA averages the weights visited by stochastic gradient descent (SGD) late in training, with the optimizer run under a modified learning rate schedule (typically a constant or cyclical rate rather than a decaying one).
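
Concretely, the average is usually maintained as a running mean, updated each time a new set of weights is collected. A minimal sketch of that update (w_swa, w, and n_models are illustrative names, not part of any library API):

# Fold the current weights w into the equal-weight running average.
w_swa = (w_swa * n_models + w) / (n_models + 1)
n_models += 1

PyTorch's AveragedModel applies this same running mean to every parameter tensor by default.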

Example

import torch
from torch.optim.swa_utils import AveragedModel, SWALR

# Assumes `model`, `optimizer`, `loader`, and `loss_fn` are defined elsewhere.
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    # Standard training step
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss_fn(model(inputs), targets).backward()
        optimizer.step()

    if epoch > 75:  # Start averaging weights after epoch 75
        swa_model.update_parameters(model)
        swa_scheduler.step()

# Recompute batch normalization statistics for the averaged model
torch.optim.swa_utils.update_bn(loader, swa_model)
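
The final update_bn pass is needed because the averaged weights are never used during training, so any batch normalization layers in swa_model still hold activation statistics computed for other weights; update_bn makes one forward pass over the training data to recompute them. At inference time, swa_model is used in place of model.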