Skip to content

Mechanistic Interpretability (MI) [NLP]

Description

Mechanistic interpretability (MI) is an emerging field that aims to understand how neural networks process information at a detailed, component level—similar to how we might reverse-engineer a mechanical device. Rather than just observing inputs and outputs, MI seeks to trace how information flows through the network, identify specific computational patterns, and understand how different parts of the network (such as individual neurons or attention heads) contribute to the model's behavior.

Info

MI is important because it goes beyond surface-level explanations to uncover the internal mechanisms of how neural networks, particularly complex models like LLMs, actually work. By analyzing how specific components—like neurons, layers, or attention heads—process and transform information

Example

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class InterpretableTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def get_attention_patterns(self, x):
        """Extract attention weights from each layer"""
        x = self.embedding(x)
        attention_patterns = []

        for layer in self.transformer.layers:
            # Register a hook to capture attention weights
            attention_weights = None

            def hook(module, input, output):
                nonlocal attention_weights
                attention_weights = output[1]  # attention weights

            handle = layer.self_attn.register_forward_hook(hook)
            x = layer(x)
            attention_patterns.append(attention_weights)
            handle.remove()

        return attention_patterns

    def analyze_neuron_activations(self, x, layer_idx):
        """Analyze individual neuron activations in a specific layer"""
        activations = []

        def hook(module, input, output):
            activations.append(output.detach())

        # Register hook on specific layer
        handle = list(self.transformer.layers)[layer_idx].register_forward_hook(hook)

        # Forward pass
        with torch.no_grad():
            self(x)

        handle.remove()
        layer_activations = activations[0]

        # Find most active neurons
        mean_activation = layer_activations.mean(dim=(0, 1))  # Average across batch and sequence
        top_neurons = torch.topk(mean_activation, k=10)

        return top_neurons.indices, top_neurons.values

    def intervention_study(self, x, layer_idx, neuron_idx):
        """Study how zeroing out specific neurons affects the output"""
        original_output = None
        modified_output = None

        def hook_original(module, input, output):
            nonlocal original_output
            original_output = output.detach()

        def hook_modified(module, input, output):
            nonlocal modified_output
            modified = output.clone()
            modified[:, :, neuron_idx] = 0  # Zero out specific neuron
            modified_output = modified
            return modified

        layer = list(self.transformer.layers)[layer_idx]

        # Get original output
        handle = layer.register_forward_hook(hook_original)
        self(x)
        handle.remove()

        # Get modified output
        handle = layer.register_forward_hook(hook_modified)
        self(x)
        handle.remove()

        return original_output, modified_output

def visualize_attention(attention_weights, tokens=None):
    """Visualize attention patterns"""
    plt.figure(figsize=(10, 8))
    plt.imshow(attention_weights[0].cpu(), cmap="viridis")

    if tokens is not None:
        plt.xticks(range(len(tokens)), tokens, rotation=45)
        plt.yticks(range(len(tokens)), tokens)

    plt.colorbar()
    plt.title("Attention Pattern")
    plt.show()


# Initialize model
model = InterpretableTransformer(vocab_size=1000, d_model=256, nhead=8, num_layers=4)

# Sample input
input_ids = torch.randint(0, 1000, (1, 20))  # Batch size 1, sequence length 20

# Get attention patterns
attention_patterns = model.get_attention_patterns(input_ids)

# Analyze neuron activations
top_neurons, activation_values = model.analyze_neuron_activations(input_ids, layer_idx=0)

# Perform intervention study
original, modified = model.intervention_study(input_ids, layer_idx=0, neuron_idx=42)

# Visualize attention
visualize_attention(attention_patterns[0])  # Visualize first layer's attention

Info

Each component helps us understand different aspects of the model:

  • Attention patterns show how the model relates different tokens to each other
  • The neuron activation analysis reveals which neurons are most important for processing specific inputs
  • Causal intervention helps us understand the role of specific neurons by observing how the output changes when we modify them
  • Visualization tools help us interpret these patterns more intuitively