Layer-Wise Adaptive Regularization

Description

Layer-wise adaptive regularization applies a different regularization strength to each layer of the model.

This can be particularly effective for LLMs: lower layers often benefit from lighter regularization so they can capture fundamental patterns, while higher layers may need stronger regularization to prevent overfitting.

Example

import torch.nn as nn

class LayerwiseAdaptiveRegularization(nn.Module):
    """Wrap a transformer LM and scale dropout up with layer depth.

    Assumes a GPT-2-style layout where the transformer blocks live in
    `base_model.transformer.h`.
    """

    def __init__(self, base_model, base_dropout=0.1, dropout_step=0.02):
        super().__init__()
        self.base_model = base_model
        for i, layer in enumerate(base_model.transformer.h):
            # Layer i gets dropout base_dropout + i * dropout_step,
            # so regularization grows toward the top of the stack.
            p = base_dropout + i * dropout_step
            # Update every dropout module inside the block; this covers the
            # attention and MLP dropouts without relying on attribute names,
            # which vary between implementations.
            for module in layer.modules():
                if isinstance(module, nn.Dropout):
                    module.p = p

    def forward(self, *args, **kwargs):
        return self.base_model(*args, **kwargs)

base_model = create_lm_model()  # placeholder: any GPT-2-style causal LM
model = LayerwiseAdaptiveRegularization(base_model)
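
With the defaults above, layer i trains with dropout 0.1 + 0.02 * i, so a 12-layer model would range from 0.10 at the bottom block to 0.32 at the top.

Dropout is not the only knob the same idea applies to. A layer-wise weight decay schedule can be expressed through optimizer parameter groups. The sketch below is illustrative, not part of the original example: it assumes the same GPT-2-style transformer.h layout, and the helper name and default values are made up for this illustration.

import torch
import torch.nn as nn

def layerwise_weight_decay_groups(base_model, base_wd=0.01, wd_step=0.005):
    # One optimizer parameter group per transformer block, with weight
    # decay growing toward the top of the stack (mirroring the dropout
    # schedule above).
    groups = []
    for i, layer in enumerate(base_model.transformer.h):
        groups.append({
            "params": list(layer.parameters()),
            "weight_decay": base_wd + i * wd_step,
        })
    # Parameters outside the blocks (embeddings, final norm, LM head)
    # fall back to the base weight decay so nothing is left out of the
    # optimizer.
    block_param_ids = {id(p) for g in groups for p in g["params"]}
    rest = [p for p in base_model.parameters() if id(p) not in block_param_ids]
    groups.append({"params": rest, "weight_decay": base_wd})
    return groups

optimizer = torch.optim.AdamW(layerwise_weight_decay_groups(base_model), lr=3e-4)

Using parameter groups keeps the schedule in the optimizer rather than the model, so the two forms of layer-wise regularization can be combined or tuned independently.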