Learning without Forgetting (LwF) [Sup] [Catastrophic Forgetting]
Description
LwF tackles catastrophic forgetting using knowledge distillation, preserving the old model's behavior by training the new model to mimic its predictions rather than protecting its weights.
The intuition is a teacher-student setup: the new "student" model learns from both the ground-truth labels of the new task and the "soft label" predictions generated by the old "teacher" model.
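LwF trains on a combined loss: a standard classification loss on the new task plus a distillation loss that keeps the new model's outputs close to the old model's. Its key advantage is that the old task's dataset is not needed for training, because the soft labels are obtained by running the old model on the new-task data.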
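The combined loss can be written out as follows. This is one common formulation rather than the only one; the symbols are assumed notation: $\alpha$ is the mixing weight, $T$ the distillation temperature, $z_s$ and $z_t$ the student and teacher logits, and $y$ the new-task label.

```latex
\mathcal{L}_{\text{LwF}}
  = \alpha \, \mathrm{CE}\big(y,\ \operatorname{softmax}(z_s)\big)
  + (1 - \alpha) \, T^{2} \,
    \mathrm{KL}\big(\operatorname{softmax}(z_t / T) \,\big\|\, \operatorname{softmax}(z_s / T)\big)
```

With $\alpha = 1$ the distillation term vanishes and training reduces to plain fine-tuning; smaller values put more weight on staying close to the old model. The $T^{2}$ factor is the usual knowledge-distillation scaling that keeps the gradient magnitude of the soft-label term roughly independent of $T$.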
Info
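Elastic Weight Consolidation (EWC) needs data from the old task to estimate the Fisher information that weights its penalty (this can be computed once, when training on the old task finishes), whereas LwF needs only the old model and the new-task data.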
Example
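import torch
import torch.nn.functional as F


def distillation_loss(student_outputs, teacher_outputs, temperature=2.0):
    # KL divergence between the temperature-softened teacher and student distributions.
    # F.kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(
        F.log_softmax(student_outputs / temperature, dim=1),
        F.softmax(teacher_outputs / temperature, dim=1),
        reduction="batchmean",
    ) * (temperature ** 2)  # T^2 keeps gradient scale comparable across temperatures


def lwf_loss(student_model, teacher_model, data, labels, temperature=2.0, alpha=0.5):
    # Simplification: both models must output logits of the same shape.
    # Teacher's predictions act as "soft labels"; no gradients are needed for them.
    with torch.no_grad():
        teacher_outputs = teacher_model(data)
    student_outputs = student_model(data)
    loss = F.cross_entropy(student_outputs, labels)  # fit the new-task labels
    dist_loss = distillation_loss(student_outputs, teacher_outputs, temperature)
    # alpha trades off learning the new task against staying close to the teacher.
    return (alpha * loss) + ((1 - alpha) * dist_loss)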
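A minimal sketch of how a loss like this might be used in a training loop. Everything specific here is an assumption made for illustration: the toy two-layer network, the random tensors standing in for new-task data, and the hyperparameter values. The loss is written inline so the snippet runs on its own.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy classifier standing in for a model already trained on the old task.
student = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 5))

# The teacher is a frozen copy of the old model; the student starts from the same weights.
teacher = copy.deepcopy(student).eval()
for p in teacher.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
temperature, alpha = 2.0, 0.5  # assumed values, not tuned

# Synthetic stand-in for new-task data; no old-task data is used anywhere.
data = torch.randn(64, 20)
labels = torch.randint(0, 5, (64,))

losses = []
for step in range(20):
    with torch.no_grad():  # soft labels from the frozen teacher
        teacher_logits = teacher(data)
    student_logits = student(data)

    ce = F.cross_entropy(student_logits, labels)  # new-task term
    kd = F.kl_div(  # distillation term on softened distributions
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * temperature ** 2
    loss = alpha * ce + (1 - alpha) * kd

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Because the student starts as a copy of the teacher, the distillation term is approximately zero on the first step and grows only as the student's predictions drift away from the teacher's.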