
Catastrophic Forgetting [Sup] {EWC} {LwF}

Description

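Catastrophic forgetting, or catastrophic interference, is a phenomenon in artificial neural networks where the network rapidly forgets previously learned information when it is trained on new data or tasks: the gradient updates for the new objective overwrite the shared weights that encoded the old knowledge.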

It is a major obstacle to continual learning, which aims to let models learn incrementally while preserving past knowledge.

Varieties

Elastic Weight Consolidation (EWC) mitigates catastrophic forgetting by regularizing training on a new task, so the model accumulates knowledge across tasks instead of simply replacing it.

The core idea is an "elastic" constraint that anchors the weights most important to a past task near their old values, while letting less important weights change more freely to learn the new one.
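A minimal pure-Python sketch of that elastic constraint on a toy problem (all numbers and names are illustrative assumptions, not part of EWC itself): two weights start at their old-task values θ* = (1, 1), the new task's loss Σ(θ_i − t_i)² pulls both toward 0, and each weight carries a penalty (λ/2)·F_i·(θ_i − θ*_i)² with a hand-picked importance F_i. The important weight (F = 100) barely moves, while the unimportant one (F = 0.01) goes almost all the way to the new optimum.

```python
# Toy illustration of EWC's per-weight "elastic" penalty (all numbers are illustrative).
OLD_THETA = [1.0, 1.0]      # weights after the old task (theta*)
IMPORTANCE = [100.0, 0.01]  # hand-picked stand-ins for per-weight Fisher values F_i
TARGET = [0.0, 0.0]         # the new task's loss sum_i (theta_i - t_i)^2 is minimized here
LAMBDA = 1.0                # penalty strength


def total_loss_grad(theta):
    """Gradient of sum_i (theta_i - t_i)^2 + (LAMBDA / 2) * F_i * (theta_i - theta*_i)^2."""
    return [
        2.0 * (w - t) + LAMBDA * f * (w - w_old)
        for w, t, f, w_old in zip(theta, TARGET, IMPORTANCE, OLD_THETA)
    ]


def train(steps=2000, lr=0.01):
    """Plain gradient descent on the new-task loss plus the elastic penalty."""
    theta = list(OLD_THETA)  # start from the old-task solution
    for _ in range(steps):
        grad = total_loss_grad(theta)
        theta = [w - lr * g for w, g in zip(theta, grad)]
    return theta


def closed_form():
    """Per-weight minimizer: theta_i = (2 t_i + LAMBDA F_i theta*_i) / (2 + LAMBDA F_i)."""
    return [
        (2.0 * t + LAMBDA * f * w_old) / (2.0 + LAMBDA * f)
        for t, f, w_old in zip(TARGET, IMPORTANCE, OLD_THETA)
    ]


theta = train()
print([round(w, 3) for w in theta])  # → [0.98, 0.005]
```

Each weight behaves like a spring whose stiffness is λ·F_i: the penalty's gradient λ·F_i·(θ_i − θ*_i) grows with both the importance and the distance from the old value.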

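EWC measures weight importance with the Fisher Information Matrix (in practice only its diagonal, which gives one importance value per parameter). It then enforces the constraint by adding a quadratic penalty term to the loss function, so that moving a parameter identified as critical for past tasks away from its old value increases the loss sharply, while unimportant parameters are penalized only lightly.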
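In its usual formulation, training on a new task B after task A minimizes the objective below, where θ*_A are the weights learned on task A, F_i is the i-th diagonal entry of the Fisher Information Matrix evaluated at θ*_A, and λ controls how strongly the old task is protected. The example code in this section approximates F_i with squared gradients of the training loss on old-task batches (an empirical approximation) and folds the factor 1/2 into ewc_lambda.

$$
\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2}\, F_i \left(\theta_i - \theta^*_{A,i}\right)^2,
\qquad
F_i = \mathbb{E}_{x \sim \mathcal{D}_A,\; y \sim p_{\theta}(y \mid x)}\!\left[\left(\frac{\partial \log p_{\theta}(y \mid x)}{\partial \theta_i}\right)^{2}\right]_{\theta = \theta^*_A}
$$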

Learning Without Forgetting (LwF) tackles catastrophic forgetting using knowledge distillation, preserving the old model's behavior by training the new model to mimic its predictions rather than protecting its weights.

The intuition is a teacher-student setup: the new "student" model learns from both the ground-truth labels of the new task and the "soft label" predictions that the old "teacher" model produces for the same inputs.
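The soft labels come from a temperature-scaled softmax: dividing the teacher's logits by a temperature T > 1 flattens the distribution, so the relative scores of the non-target classes carry more signal for the student. This pure-Python sketch (the logit values are made up for illustration) computes the teacher's soft labels at two temperatures and the KL divergence KL(p_teacher ‖ p_student) that a distillation loss minimizes.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; subtracting the max keeps exp() numerically stable."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) = sum_k p_k * (log p_k - log q_k), assuming q is strictly positive."""
    return sum(pk * (math.log(pk) - math.log(qk)) for pk, qk in zip(p, q) if pk > 0)


teacher_logits = [4.0, 1.0, 0.5]  # illustrative old-model outputs for one input
student_logits = [3.0, 2.0, 0.0]  # illustrative new-model outputs for the same input

for T in (1.0, 4.0):
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    print(f"T={T}: teacher={[round(p, 3) for p in p_teacher]}, "
          f"KL={kl_divergence(p_teacher, p_student):.4f}")
```

At T = 1 the teacher's distribution is dominated by the first class; at T = 4 the second and third classes receive visibly more probability mass, which is the extra information the student is asked to match.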

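LwF minimizes a combined loss: a standard task loss on the new labels plus a distillation loss that keeps the new model's outputs aligned with the old model's. Its key advantage is that it does not need the original (old-task) dataset, because the teacher's soft labels are computed on the new data.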
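Written out, the combined loss computed by the lwf_loss example in this section is shown below, averaged over the batch, where z_s and z_t are the student's and teacher's logits (the teacher's computed without gradients), y is the new-task label, T is the temperature, and α weights the two terms. This form assumes student and teacher share one output space; the original LwF method instead keeps separate output heads for old and new tasks, applying the distillation term to the old heads and the cross-entropy term to the new one.

$$
\mathcal{L}_{\mathrm{LwF}} = \alpha\,\mathrm{CE}\big(y,\ \mathrm{softmax}(z_s)\big) + (1 - \alpha)\,T^2\,\mathrm{KL}\Big(\mathrm{softmax}(z_t / T)\ \Big\|\ \mathrm{softmax}(z_s / T)\Big)
$$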

Info

Elastic Weight Consolidation (EWC) needs access to old-task data (at least once, to estimate weight importance), whereas LwF needs only the old model and the new-task data.

Example

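from transformers import Trainer


def compute_importance(model, dataloader):
    """Estimate per-parameter importance as the average squared gradient of the
    old-task loss (a diagonal, empirical approximation of the Fisher Information).

    Assumes `dataloader` yields dict batches that include labels, so that
    `model(**batch).loss` is defined.
    """
    importance = {}
    num_batches = 0
    device = next(model.parameters()).device
    model.eval()  # disable dropout while measuring gradients

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        model.zero_grad()
        loss = model(**batch).loss
        loss.backward()
        num_batches += 1

        for n, p in model.named_parameters():
            if p.grad is not None:
                sq_grad = p.grad.detach().pow(2)
                if n not in importance:
                    importance[n] = sq_grad.clone()
                else:
                    importance[n] += sq_grad

    model.zero_grad()  # don't leak these gradients into the first training step
    return {n: imp / max(num_batches, 1) for n, imp in importance.items()}


def ewc_loss(model, old_params, importance, loss, ewc_lambda=0.01):
    """Add the EWC penalty: ewc_lambda * sum_i F_i * (theta_i - theta*_i)^2."""
    for n, p in model.named_parameters():
        if n in importance:
            old = old_params[n].to(p.device)
            imp = importance[n].to(p.device)
            # Weight each squared change elementwise by its importance, then sum
            loss = loss + ewc_lambda * (imp * (p - old).pow(2)).sum()
    return loss


class EWCTrainer(Trainer):
    """Trainer whose training loss adds the EWC penalty to the new task's loss."""

    def __init__(self, *args, old_params=None, importance=None, ewc_lambda=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.old_params = old_params
        self.importance = importance
        self.ewc_lambda = ewc_lambda

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        loss = outputs.loss
        if self.importance is not None:
            loss = ewc_loss(model, self.old_params, self.importance, loss, self.ewc_lambda)
        return (loss, outputs) if return_outputs else loss


def continual_fine_tune(model, tokenizer, old_dataset, new_dataset):
    # Snapshot of the old-task weights (theta*); cheaper than deep-copying the whole model
    old_params = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
    # old_dataset is assumed to be an iterable of collated batches (e.g. a torch DataLoader)
    importance = compute_importance(model, old_dataset)

    trainer = EWCTrainer(
        model=model,
        train_dataset=new_dataset["train"],
        eval_dataset=new_dataset["validation"],
        old_params=old_params,
        importance=importance,
        # ... other Trainer arguments
    )
    trainer.train()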

Info

  • old_dataset (The Past): Represents the old knowledge you want to preserve. It's used once, before training, to calculate the importance of each weight, i.e. to determine what to protect.
  • new_dataset (The Present): Represents the new skill you want to learn. It provides the main training objective, the loss for the current task, which defines what to learn.
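The two roles can be traced in a minimal pure-Python sketch on a toy problem (the data, ewc_lambda, learning rate, step counts, and helper names are illustrative assumptions). The model is a two-weight linear regression with a unit-variance Gaussian likelihood, for which the Fisher diagonal has the closed form F_i = mean(x_i²) over the old inputs. Phase 1 touches only OLD_DATA (to learn task A and record θ* and F); phase 2 touches only NEW_DATA plus the stored penalty.

```python
# Toy EWC workflow on linear regression y = w1*x1 + w2*x2 (all values illustrative).
OLD_DATA = [((1.0, 0.0), 2.0), ((-1.0, 0.0), -2.0)]  # task A: only w1 matters, best w1 = 2
NEW_DATA = [((1.0, 1.0), 0.0)]                       # task B: solved by any w with w1 + w2 = 0


def mse(w, data):
    """Mean of 0.5 * (w.x - y)^2 over the data."""
    return sum(0.5 * (w[0] * x[0] + w[1] * x[1] - y) ** 2 for x, y in data) / len(data)


def mse_grad(w, data):
    """Gradient of mse(w, data) with respect to w."""
    g = [0.0, 0.0]
    for x, y in data:
        residual = w[0] * x[0] + w[1] * x[1] - y
        for i in range(2):
            g[i] += residual * x[i] / len(data)
    return g


def fisher_diagonal(data):
    """Exact Fisher diagonal for a unit-variance Gaussian likelihood: mean of x_i^2."""
    return [sum(x[i] ** 2 for x, _ in data) / len(data) for i in range(2)]


def train(w, data, steps=2000, lr=0.05, old_w=None, fisher=None, ewc_lambda=0.0):
    w = list(w)
    for _ in range(steps):
        g = mse_grad(w, data)
        if fisher is not None:
            # Gradient of the EWC penalty (ewc_lambda / 2) * sum_i F_i * (w_i - old_w_i)^2
            g = [gi + ewc_lambda * f * (wi - oi) for gi, f, wi, oi in zip(g, fisher, w, old_w)]
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


# Phase 1 (the past): learn task A, then store its weights and per-weight importance.
w_a = train([0.0, 0.0], OLD_DATA)
fisher = fisher_diagonal(OLD_DATA)

# Phase 2 (the present): learn task B without and with the EWC penalty.
w_plain = train(w_a, NEW_DATA)
w_ewc = train(w_a, NEW_DATA, old_w=w_a, fisher=fisher, ewc_lambda=10.0)

for name, w in [("plain", w_plain), ("ewc", w_ewc)]:
    print(name, [round(v, 3) for v in w], "task A loss:", round(mse(w, OLD_DATA), 3))
```

Task B alone only constrains w1 + w2 = 0, so plain gradient descent from w_a = (2, 0) settles at about (1, −1) and task A's loss rises from 0 to 0.5, while the EWC run settles at about (2, −2), solving task B without disturbing w1. Squaring loss gradients computed with the true labels, as compute_importance does, would give F = 0 on this toy because task A's residuals vanish at its optimum; this is a known weakness of that empirical approximation.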
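
import torch
import torch.nn.functional as F

def distillation_loss(student_outputs, teacher_outputs, temperature=2.0):
    # KL(teacher || student) between temperature-softened distributions, averaged over the batch.
    # Scaling by T^2 keeps gradient magnitudes roughly comparable across temperatures.
    return F.kl_div(
        F.log_softmax(student_outputs / temperature, dim=1),
        F.softmax(teacher_outputs / temperature, dim=1),
        reduction="batchmean"
    ) * (temperature ** 2)

def lwf_loss(student_model, teacher_model, data, labels, temperature=2.0, alpha=0.5):
    # Assumes both models map `data` to logits of shape [batch, num_classes] over the same classes.
    teacher_model.eval()  # the teacher is frozen: disable dropout and batch-norm statistic updates
    with torch.no_grad():  # Teacher's predictions act as "soft labels". No gradients needed
        teacher_outputs = teacher_model(data)

    student_outputs = student_model(data)
    loss = F.cross_entropy(student_outputs, labels)  # learn the new task's hard labels
    dist_loss = distillation_loss(student_outputs, teacher_outputs, temperature)  # stay close to the teacher

    # alpha balances learning the new task against preserving the old model's behavior
    return (alpha * loss) + ((1 - alpha) * dist_loss)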