Skip to content

Probing [NLP]

Description

Probing involves training simple models on the internal representations of an LLM to assess what linguistic properties are captured at different layers.

Different layers in a transformer tend to specialize in different linguistic properties. Lower layers capture surface-level features such as token identity and local syntax; middle layers capture grammar and sentence structure; and higher layers focus on semantics, reasoning, and factual recall. This hierarchy emerges naturally during training, so probes trained on lower layers do best on syntactic tasks while probes trained on higher layers do best on semantic reasoning. Probing studies confirm this specialization, which aids interpretability, fine-tuning, and model compression for task-specific optimization.

Example

import torch
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def probe_bert_layers(model, tokenizer, texts, labels, layer_nums):
    """Train a logistic-regression probe on each requested hidden layer.

    Every text is encoded once; for each requested layer the token vectors
    are mean-pooled into a single feature vector per text. A separate
    ``LogisticRegression`` probe is then fitted per layer on an 80/20
    train/test split and scored on the held-out part.

    Args:
        model: Callable model that accepts the tokenizer output as keyword
            arguments together with ``output_hidden_states=True`` and returns
            an object exposing ``hidden_states`` (e.g. ``BertModel``).
        tokenizer: Callable that turns one text into model inputs when called
            with ``return_tensors="pt"``.
        texts: Iterable of input strings.
        labels: Sequence of class labels, one per text.
        layer_nums: Indices into ``hidden_states`` selecting the layers to
            probe.

    Returns:
        Dict mapping ``"layer_<n>"`` to the probe's test accuracy.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """
    layer_nums = list(layer_nums)
    if not layer_nums:
        # Nothing to probe: skip tokenization and the forward passes entirely.
        return {}

    texts = list(texts)
    if len(texts) != len(labels):
        # Fail before the expensive forward passes rather than inside sklearn.
        raise ValueError(
            f"texts and labels must have the same length "
            f"(got {len(texts)} and {len(labels)})"
        )

    def get_embeddings(text):  # Hidden states of every layer for one text
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            # The tokenizer returns a mapping (input_ids, attention_mask, ...);
            # it has to be unpacked into keyword arguments. Passing it
            # positionally would hand the whole mapping over as `input_ids`.
            outputs = model(**inputs, output_hidden_states=True)
        return outputs.hidden_states

    # One forward pass per text yields all layers at once, so collect the
    # features for every requested layer from that single pass.
    features = {layer_num: [] for layer_num in layer_nums}
    for text in texts:
        hidden_states = get_embeddings(text)
        for layer_num, layer_features in features.items():
            # squeeze(0) removes only the batch axis; a bare squeeze() would
            # also drop the sequence axis for a one-token input and make the
            # mean below average over the hidden dimension instead.
            pooled = hidden_states[layer_num].squeeze(0).mean(dim=0)
            layer_features.append(pooled.numpy())

    results = {}
    for layer_num, embeddings in features.items():
        X_train, X_test, y_train, y_test = train_test_split(
            embeddings, labels, test_size=0.2, random_state=42
        )

        # Train and evaluate probe
        probe = LogisticRegression(random_state=42)
        probe.fit(X_train, y_train)
        y_pred = probe.predict(X_test)
        results[f"layer_{layer_num}"] = accuracy_score(y_test, y_pred)

    return results

# Example driver: probe three layers of a pretrained BERT encoder.
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

# Small illustrative dataset (replace with real data for meaningful results).
# A literal `...` inside these lists would be passed to the tokenizer as an
# Ellipsis object and fail, so the example uses concrete sentences. Ten
# samples also guarantee that the 80% training split contains both classes.
texts = [
    "The cat sat on the mat.",
    "The dog chased the ball.",
    "Birds sing in the morning.",
    "She reads a book.",
    "He drinks coffee every day.",
    "Although it was raining, the children played outside until their mother called them in.",
    "The scientist who discovered the compound, which had eluded researchers for decades, won the prize.",
    "If the train had left on time, we would have arrived before the meeting started.",
    "The book that she recommended, despite its length, kept me awake because I could not stop reading.",
    "While the committee debated the proposal, the chairman, who seemed distracted, checked his phone.",
]
labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]  # Corresponding labels (0 for simple, 1 for complex sentences)

# Indices into the model's `hidden_states` output; see the Transformers docs
# for the layout (index 0 is the embedding output, later indices the encoder
# layers).
layer_nums = [1, 6, 12]  # Layers to probe
probe_results = probe_bert_layers(model, tokenizer, texts, labels, layer_nums)

for layer, accuracy in probe_results.items():
    print(f"{layer} accuracy: {accuracy:.2f}")