Catastrophic Forgetting [Sup] {EWC} {LwF}
Description
Catastrophic forgetting, or catastrophic interference, is a phenomenon in artificial neural networks where the network rapidly forgets previously learned information when it is trained on new data or tasks.
This process is a major obstacle to continual learning, which aims for models to learn incrementally while preserving past knowledge.
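A minimal illustration of the effect, assuming a toy one-weight linear model trained with plain gradient descent (all names are hypothetical). The model first fits task A (y = 2x) and then fits task B (y = -x) only. Because both tasks share the same weight, learning B overwrites what was learned for A, and the loss on A jumps from about 0 to 42.

```python
def mse(w, xs, ys):
    # Mean squared error of the one-weight model y_hat = w * x.
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def mse_grad(w, xs, ys):
    # Derivative of mse with respect to w.
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

def train(w, xs, ys, lr=0.05, steps=200):
    # Plain gradient descent on a single task.
    for _ in range(steps):
        w -= lr * mse_grad(w, xs, ys)
    return w

xs = [1.0, 2.0, 3.0]
task_a = [2.0 * x for x in xs]   # task A: y = 2x
task_b = [-1.0 * x for x in xs]  # task B: y = -x

w = train(0.0, xs, task_a)
loss_a_before = mse(w, xs, task_a)  # ~0: task A has been learned
w = train(w, xs, task_b)            # sequential training on task B only
loss_a_after = mse(w, xs, task_a)   # ~42: task A has been forgotten
print(f"{loss_a_before:.4f} -> {loss_a_after:.4f}")
```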
Varieties
Elastic Weight Consolidation (EWC) mitigates catastrophic forgetting by helping the model accumulate knowledge across tasks instead of overwriting it.
The core idea is to apply an "elastic" constraint that protects the most important weights from a past task, while allowing less important weights to change freely to learn the new one.
EWC estimates weight importance using the Fisher Information Matrix (in practice, only its diagonal). It then enforces the constraint by adding a quadratic penalty term to the loss function, so that moving a parameter far from its old-task value raises the loss in proportion to how important that parameter was for past tasks.
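As it is usually written, with θ*_A the parameters learned on the old task A, L_B the new task's loss, λ the penalty strength, and F_i the i-th diagonal entry of the Fisher Information Matrix, the EWC objective is roughly the following. This is a standard formulation sketched for reference; the λ/2 convention varies between implementations, and a single tunable coefficient such as `ewc_lambda` simply absorbs it. In practice F_i is often approximated by averaged squared gradients of the loss on old-task data (the "empirical Fisher").

$$
\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2}\, F_i \left(\theta_i - \theta^{*}_{A,i}\right)^2,
\qquad
F_i = \mathbb{E}_{x \sim \mathcal{D}_A,\; y \sim p_{\theta^{*}_A}(y \mid x)}\left[\left(\left.\frac{\partial \log p_{\theta}(y \mid x)}{\partial \theta_i}\right|_{\theta = \theta^{*}_A}\right)^{2}\right]
$$

A weight with large F_i is pulled strongly back toward θ*_{A,i}; a weight with F_i near zero is essentially free to change.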
Learning Without Forgetting (LwF) tackles catastrophic forgetting using knowledge distillation, preserving the old model's behavior by training the new model to mimic its predictions rather than protecting its weights.
The intuition is a teacher-student model, where the new "student" model learns from both the correct new labels and the "soft label" predictions generated by the old "teacher" model.
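LwF trains on a combined loss: a standard loss on the new task's labels plus a distillation loss that aligns the new model's predictions with the old model's. Because the old model's soft labels are computed on the new task's inputs, LwF offers the key advantage of not needing the original dataset for training.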
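Written out in the α-weighted form used by this note's `lwf_loss` example, the combined objective is sketched below, with σ the softmax, z_s and z_t the student and teacher logits, y the new-task label, and T the temperature. This is a simplification: as usually described, the original LwF formulation keeps separate output heads for old and new tasks and weights the old-task distillation term with its own coefficient.

$$
\mathcal{L}_{\text{LwF}} = \alpha\, \mathcal{L}_{\text{CE}}\big(y,\ \sigma(z_s)\big) + (1 - \alpha)\, T^{2}\, \mathrm{KL}\big(\sigma(z_t / T)\ \big\|\ \sigma(z_s / T)\big)
$$

Raising T flattens σ(z_t / T), so the student also learns the teacher's relative preferences among wrong classes; the T² factor keeps the distillation gradients on a comparable scale across temperatures.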
Info
While Elastic Weight Consolidation (EWC) needs data from the old task (to estimate weight importance once that task has been learned), LwF only needs the old model.
Example
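```python
import copy

from transformers import Trainer


def compute_importance(model, dataset):
    # Diagonal empirical-Fisher estimate: the average squared gradient of the
    # old-task loss for each parameter. Each batch is assumed to be a dict of
    # tensors (including "labels") that the model accepts as keyword arguments.
    importance = {}
    num_batches = 0
    model.eval()  # disable dropout so gradients reflect the trained model
    for batch in dataset:
        model.zero_grad()
        output = model(**batch)
        output.loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                sq_grad = p.grad.detach().clone().pow(2)
                if n not in importance:
                    importance[n] = sq_grad
                else:
                    importance[n] += sq_grad
        num_batches += 1
    model.zero_grad()
    return {n: imp / max(num_batches, 1) for n, imp in importance.items()}


def ewc_loss(model, old_params, importance, loss, ewc_lambda=0.01):
    # Penalize moving each parameter away from its old-task value, weighted
    # element-wise by that parameter's importance, then sum to a scalar.
    for n, p in model.named_parameters():
        if n in importance:
            old = old_params[n].to(p.device)
            fisher = importance[n].to(p.device)
            loss = loss + ewc_lambda * (fisher * (p - old).pow(2)).sum()
    return loss


class EWCTrainer(Trainer):
    # Trainer does not take a compute_loss argument; override the method instead.
    def __init__(self, *args, old_params=None, importance=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.old_params = old_params
        self.importance = importance

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        loss = outputs.loss
        if self.importance is not None:
            loss = ewc_loss(model, self.old_params, self.importance, loss)
        return (loss, outputs) if return_outputs else loss


def continual_fine_tune(model, tokenizer, old_dataset, new_dataset):
    old_model = copy.deepcopy(model)
    # Snapshot the old-task parameters as detached tensors.
    old_params = {n: p.detach().clone() for n, p in old_model.named_parameters()}
    importance = compute_importance(old_model, old_dataset)
    trainer = EWCTrainer(
        model=model,
        train_dataset=new_dataset["train"],
        eval_dataset=new_dataset["validation"],
        old_params=old_params,
        importance=importance,
        # ...
    )
    trainer.train()
```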
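To see the "elastic" behavior without a real model, here is a minimal sketch on two hypothetical quadratic tasks in plain Python. The importance values are hand-picked stand-ins for the diagonal Fisher (not estimated from data), and the penalty uses the (λ/2)·Σ F_i (θ_i − θ*_i)² convention, whose gradient is λ·F_i·(θ_i − θ*_i). All names are illustrative.

```python
def train(grad_fn, w, lr=0.05, steps=2000):
    # Plain gradient descent on a list of weights.
    w = list(w)
    for _ in range(steps):
        g = grad_fn(w)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w

# Hypothetical task A has its optimum at [1, 1]. Weight 0 matters a lot to
# task A (high importance); weight 1 barely matters (low importance).
THETA_A = [1.0, 1.0]
IMPORTANCE = [10.0, 0.01]  # stands in for the diagonal Fisher F_i

def task_b_grad(w):
    # Gradient of task B's loss sum_i (w_i + 1)^2, optimum at [-1, -1].
    return [2.0 * (wi + 1.0) for wi in w]

def ewc_grad(w, lam=2.0):
    # Task B gradient plus the gradient of (lam / 2) * sum_i F_i * (w_i - theta_A_i)^2.
    return [gb + lam * f * (wi - ai)
            for gb, f, wi, ai in zip(task_b_grad(w), IMPORTANCE, w, THETA_A)]

plain = train(task_b_grad, THETA_A)  # ordinary fine-tuning on task B
ewc = train(ewc_grad, THETA_A)       # fine-tuning on task B with the EWC penalty
print("plain:", [round(v, 3) for v in plain])
print("ewc:  ", [round(v, 3) for v in ewc])
```

Plain fine-tuning moves both weights to task B's optimum [-1, -1]. With the penalty, the important weight settles at 9/11 ≈ 0.818, close to its task A value, while the unimportant weight moves almost all the way to task B's optimum (≈ -0.980).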
Info
old_dataset
(The Past): Represents the old knowledge you want to preserve. It is used to calculate the importance of each weight, determining what to protect.
new_dataset
(The Present): Represents the new skill you want to learn. It provides the main training objective and the loss for the current task, defining what to learn.
```python
import torch
import torch.nn.functional as F


def distillation_loss(student_outputs, teacher_outputs, temperature=2.0):
    # KL(teacher || student) on temperature-softened distributions. The T^2
    # factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(
        F.log_softmax(student_outputs / temperature, dim=1),
        F.softmax(teacher_outputs / temperature, dim=1),
        reduction="batchmean",
    ) * (temperature ** 2)


def lwf_loss(student_model, teacher_model, data, labels, temperature=2.0, alpha=0.5):
    # Assumes both models return logits over the same classes for `data`,
    # and that teacher_model is frozen and in eval() mode.
    with torch.no_grad():  # Teacher's predictions act as "soft labels". No gradients needed
        teacher_outputs = teacher_model(data)
    student_outputs = student_model(data)
    loss = F.cross_entropy(student_outputs, labels)  # hard-label loss on the new task
    dist_loss = distillation_loss(student_outputs, teacher_outputs, temperature)
    return (alpha * loss) + ((1 - alpha) * dist_loss)
```
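To make the arithmetic of `distillation_loss` and `lwf_loss` concrete, here is a pure-Python sketch for a single example (a batch of one, where `reduction="batchmean"` reduces to a plain sum). The helper names are hypothetical. It shows how temperature softens the teacher's distribution into soft labels, and that the distillation term is exactly zero when the student matches the teacher.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss_single(student_logits, teacher_logits, temperature=2.0):
    # KL(teacher || student) on softened distributions, scaled by T^2.
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s))
    return kl * temperature ** 2

def lwf_loss_single(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    ce = -math.log(softmax(student_logits)[label])  # hard-label cross-entropy at T = 1
    dist = distillation_loss_single(student_logits, teacher_logits, temperature)
    return alpha * ce + (1 - alpha) * dist

teacher = [4.0, 1.0, 0.0]
print(softmax(teacher, 1.0))  # sharp distribution
print(softmax(teacher, 2.0))  # softer "soft labels"
print(distillation_loss_single([4.0, 1.0, 0.0], teacher))  # 0.0: student matches teacher
print(lwf_loss_single([0.0, 1.0, 4.0], teacher, label=2))
```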