Uncovering Reasoning in LLMs with Sparse Autoencoders
Summary
Large Language Models (LLMs) like DeepSeek-R1 show remarkable reasoning abilities, but how these abilities are represented internally has remained a mystery. The paper we reproduce here explores the mechanistic interpretability of reasoning in LLMs using Sparse Autoencoders (SAEs), a tool that decomposes LLM activations into human-interpretable features. In this post, we'll:
- Explain the SAE architecture used
- Compute and visualize ReasonScore
- Explore feature steering with sample completions
- Provide live visualizations using Python + Streamlit
1. Sparse Autoencoders: Theory + Code
SAE Architecture
Sparse Autoencoders aim to reconstruct LLM activations using a sparse hidden representation:
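Concretely, for the minimal form implemented below (a ReLU encoder, a linear decoder, and an L1 sparsity penalty), an activation vector x is mapped to a sparse code z and reconstructed as x̂:

$$z = \mathrm{ReLU}(W_{\text{enc}}\,x + b_{\text{enc}}), \qquad \hat{x} = W_{\text{dec}}\,z + b_{\text{dec}}$$

$$\mathcal{L} = \lVert x - \hat{x} \rVert_2^2 + \lambda\,\lVert z \rVert_1$$

Here λ corresponds to `sparsity_penalty` in the code; the reconstruction (MSE) term is assumed to live in the training loop rather than in the class itself.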
We previously discussed Sparse Autoencoders in this post: Detecting AI-Generated Text: Challenges and Solutions.
```python
import torch
import torch.nn.functional
import yaml
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from torch import Tensor


class SparseAutoencoder(torch.nn.Module):
    """
    Sparse Autoencoder implementation in PyTorch.

    This class implements an autoencoder with sparsity constraints on the latent representation.
    The model includes methods for saving/loading, computing sparsity loss, and summarizing its configuration.

    Args:
        input_dim (int): Dimensionality of the input data.
        hidden_dim (int): Dimensionality of the latent space.
        activation (str): Activation function to use in the encoder (default: "relu").
        sparsity_target (float): Target sparsity level for latent activations.
        sparsity_penalty (float): Weight for the sparsity regularization term.
        layer_index (int, optional): Index of the layer in a larger model (used for context).
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 activation: str = "relu",
                 sparsity_target: float = 0.03,
                 sparsity_penalty: float = 0.0001,
                 layer_index: Optional[int] = None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.encoder = torch.nn.Linear(input_dim, hidden_dim)
        self.decoder = torch.nn.Linear(hidden_dim, input_dim)
        # Validate and set activation function
        if activation not in ["relu", "tanh", "sigmoid", "leaky_relu"]:
            raise ValueError(f"Unsupported activation function: {activation}")
        self.activation_fn = getattr(torch.nn.functional, activation)
        # Note: sparsity_target is stored for configuration/reporting;
        # the L1 penalty below does not use it directly.
        self.sparsity_target = sparsity_target
        self.sparsity_penalty = sparsity_penalty
        self.layer_index = layer_index

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Perform the forward pass through the sparse autoencoder.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - recon (torch.Tensor): Reconstructed input of shape (batch_size, input_dim).
                - z (torch.Tensor): Latent representation of shape (batch_size, hidden_dim).
        """
        z = self.activation_fn(self.encoder(x))
        recon = self.decoder(z)
        return recon, z

    def compute_sparsity_loss(self, z: Tensor) -> Tensor:
        """
        Compute the sparsity loss using L1 regularization.

        Args:
            z (torch.Tensor): Latent representation of shape (batch_size, hidden_dim).

        Returns:
            torch.Tensor: Sparsity loss value.
        """
        return self.sparsity_penalty * torch.mean(torch.abs(z))

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SparseAutoencoder":
        """
        Create a SparseAutoencoder instance from a configuration dictionary.

        Args:
            config (Dict[str, Any]): Dictionary containing model hyperparameters.

        Returns:
            SparseAutoencoder: Initialized model instance.
        """
        return cls(
            input_dim=config["input_dim"],
            hidden_dim=config["hidden_dim"],
            activation=config.get("activation", "relu"),
            sparsity_target=config.get("sparsity_target", 0.03),
            sparsity_penalty=config.get("sparsity_penalty", 0.0001),
            layer_index=config.get("layer_index", None)
        )

    def save(self, path: Union[str, Path]):
        """
        Save the model's state_dict and configuration to the specified directory.

        Args:
            path (Union[str, Path]): Directory path where the model will be saved.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), path / "pytorch_model.bin")
        config = {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "activation": self.activation_fn.__name__,
            "sparsity_target": self.sparsity_target,
            "sparsity_penalty": self.sparsity_penalty,
            "layer_index": self.layer_index
        }
        with open(path / "config.yaml", "w") as f:
            yaml.dump(config, f)

    @classmethod
    def load(cls, path: Union[str, Path], device: Optional[str] = None) -> "SparseAutoencoder":
        """
        Load a SparseAutoencoder instance from the specified directory.

        Args:
            path (Union[str, Path]): Directory path where the model is saved.
            device (str, optional): Device to load the model onto (e.g., "cpu" or "cuda").

        Returns:
            SparseAutoencoder: Loaded model instance.
        """
        path = Path(path)
        with open(path / "config.yaml", "r") as f:
            config = yaml.safe_load(f)
        model = cls.from_config(config)
        state_dict = torch.load(path / "pytorch_model.bin", map_location=device or "cpu")
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def compute_topk_activations_mean(hidden_states: Tensor, sae_model: "SparseAutoencoder", top_k: int = 20) -> float:
        """
        Compute the mean of the top-k activations in the latent space.

        Args:
            hidden_states (torch.Tensor): Hidden states from an intermediate layer of shape (1, seq_len, hidden_dim).
            sae_model (SparseAutoencoder): Trained SparseAutoencoder instance.
            top_k (int): Number of top activations to consider.

        Returns:
            float: Mean of the top-k activations.
        """
        _, z = sae_model(hidden_states.squeeze(0))    # shape: (seq_len, hidden_dim)
        feature_scores = torch.abs(z).sum(dim=0)      # L1 norm per feature, summed across tokens
        topk_vals, _ = torch.topk(feature_scores, k=top_k)
        return topk_vals.mean().item()

    def summary(self, print_output: bool = True) -> Dict[str, Any]:
        """
        Generate a summary of the model's configuration and parameters.

        Args:
            print_output (bool): Whether to print the summary to the console.

        Returns:
            Dict[str, Any]: Summary information about the model.
        """
        info = {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "activation": self.activation_fn.__name__,
            "sparsity_target": self.sparsity_target,
            "sparsity_penalty": self.sparsity_penalty,
            "layer_index": self.layer_index,
            "total_params": sum(p.numel() for p in self.parameters()),
            "trainable_params": sum(p.numel() for p in self.parameters() if p.requires_grad)
        }
        if print_output:
            print("🔍 Sparse Autoencoder Summary:")
            for k, v in info.items():
                print(f"  {k}: {v}")
        return info
```
2. Extracting Reasoning Features
We compute a ReasonScore based on how much a feature activates on reasoning-related tokens.

Python Code to Compute ReasonScore
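The paper's exact ReasonScore definition isn't reproduced verbatim here; below is a minimal sketch of the idea, scoring each SAE feature by how much more strongly it activates on reasoning-related tokens than on tokens overall. The keyword list, the layer choice, and the normalization are illustrative assumptions, and a loaded model, tokenizer, and trained sae are assumed. The highest-scoring features are the candidates used later for steering.

```python
import torch

# Keyword list, layer choice, and normalization are illustrative assumptions,
# not the paper's exact ReasonScore definition.
REASONING_KEYWORDS = {"therefore", "because", "so", "thus", "wait", "hence", "since"}

@torch.no_grad()
def compute_reason_scores(model, tokenizer, sae, prompts, layer_index, device="cpu"):
    """Score every SAE feature by how strongly it fires on reasoning-related tokens
    relative to its average activation over all tokens."""
    reasoning_sum = torch.zeros(sae.hidden_dim)
    total_sum = torch.zeros(sae.hidden_dim)
    reasoning_count, total_count = 0, 0

    for prompt in prompts:
        enc = tokenizer(prompt, return_tensors="pt").to(device)
        out = model(**enc, output_hidden_states=True)
        hidden = out.hidden_states[layer_index].squeeze(0)   # (seq_len, hidden_dim)
        _, z = sae(hidden)                                   # (seq_len, sae.hidden_dim)
        z = z.abs().cpu()

        tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
        is_reasoning = torch.tensor(
            [any(k in tok.lower() for k in REASONING_KEYWORDS) for tok in tokens]
        )

        total_sum += z.sum(dim=0)
        total_count += z.shape[0]
        if is_reasoning.any():
            reasoning_sum += z[is_reasoning].sum(dim=0)
            reasoning_count += int(is_reasoning.sum())

    mean_all = total_sum / max(total_count, 1)
    mean_reasoning = reasoning_sum / max(reasoning_count, 1)
    # High score: the feature fires much more on reasoning tokens than on average.
    return mean_reasoning / (mean_all + 1e-6)
```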
🔍 Building Explainable Reasoning in LLMs: A Hands-On Demo of Sparse Feature Steering
🧠 Overview
In this post, I share my experience implementing and extending the ideas from the paper:
“I Have Covered All the Bases Here: Interpreting Reasoning Features in Large Language Models via Sparse Autoencoders”
We not only reproduce the core idea of identifying reasoning features via Sparse Autoencoders (SAEs), but also build a full end-to-end reasoning pipeline with visualization, prompt optimization, and feature steering.
🚀 What We Built
- 🔬 Extracted hidden activations from TinyLlama using Hugging Face Transformers
- 🧬 Trained Sparse Autoencoders to compress and expose interpretable features
- 📊 Computed ReasonScore using top-k sparse activations
- 🧠 Steered generation by activating reasoning features
- 📈 Built a Gradio Explainability Dashboard to visualize features, tokens, and decisions
- ⚡ Integrated DSPy for future prompt optimization
✅ Feature Checklist
Here’s a snapshot of the system:
| Feature | Status |
|---------|--------|
| Activation extraction & caching | ✅ |
| SAE config & training | ✅ |
| ReasonScore keyword + token modes | ✅ |
| Feature steering via DSPy module | ✅ |
| Dashboard with visualization | ✅ |
| Prompt optimizer loop | ⏳ In Progress |
| DSPy Teleprompter support | ⏳ Next |
| Postgres + pgvector logging | ✅ |
| RAG-based filtering & prompt validation | ✅ |
| Prompt-triggered explainability from agents | ✅ |
📦 Dataset
We use a filtered subset of the OpenOrca dataset (openorca-1k-short) to ensure consistent prompt length and GPU-friendliness:
```yaml
prompts:
  dataset: ernanhughes/openorca-1k-short
  text_column: question
  max_seq_len: 64
  pad_to_max: true
```
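For reference, here is a minimal sketch of consuming this config with Hugging Face `datasets`. The "train" split and the TinyLlama tokenizer ID are assumptions; the actual pipeline reads these values from Hydra.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Assumptions: the "train" split and the TinyLlama chat tokenizer; the real
# pipeline reads these values from the Hydra config shown above.
dataset = load_dataset("ernanhughes/openorca-1k-short", split="train")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

def tokenize(example):
    return tokenizer(
        example["question"],      # text_column
        max_length=64,            # max_seq_len
        padding="max_length",     # pad_to_max
        truncation=True,
    )

tokenized = dataset.map(tokenize)
```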
🔧 Training the Sparse Autoencoder
Each LLM layer’s activation is reshaped from [B, T, H] to [B*T, H], so that each token’s hidden state becomes one training sample, and passed to a small autoencoder (a condensed sketch of this loop follows the list below). We log:
- Loss curves
- Layer index
- Hidden dimension
- Sparsity targets
✅ All configs are Hydra-powered and traceable via JSONL and PostgreSQL.
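Below is a condensed sketch of the extraction-plus-training loop. The layer index, expansion factor, optimizer, learning rate, and batch size are illustrative assumptions, `tokenized` is the padded dataset from the Dataset sketch above, and the real pipeline is Hydra-configured as described.

```python
import torch
from transformers import AutoModelForCausalLM

# Assumptions: layer 12, a 4x expansion factor, Adam, and plain MSE reconstruction.
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model.eval()

layer_index = 12
hidden_size = model.config.hidden_size
sae = SparseAutoencoder(input_dim=hidden_size, hidden_dim=4 * hidden_size,
                        layer_index=layer_index)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

for i in range(0, len(tokenized), 8):
    batch = tokenized[i:i + 8]                      # `tokenized` from the Dataset sketch
    input_ids = torch.tensor(batch["input_ids"])
    with torch.no_grad():
        out = model(input_ids=input_ids, output_hidden_states=True)
    acts = out.hidden_states[layer_index]           # [B, T, H]
    acts = acts.reshape(-1, hidden_size)            # [B*T, H]: one sample per token

    recon, z = sae(acts)
    loss = torch.nn.functional.mse_loss(recon, acts) + sae.compute_sparsity_loss(z)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```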
🧪 ReasonScore: Measuring Reasoning Activation
To rank which prompts engage reasoning, we call:

```python
compute_mean_topk_feature_score(hidden_states, sae, top_k=20)
```

(A usage sketch of this ranking step appears after the list below.)
We support:
- ReasonScore from keyword token matches
- ReasonScore from token ID sets
- Optional top-k normalization
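As a usage sketch of the ranking step, here is the same idea expressed with the `compute_topk_activations_mean` helper from Section 1 (the assumption being that the wrapper above maps onto it), reusing `model`, `tokenizer`, `sae`, and `layer_index` from the training sketch. The prompts are illustrative.

```python
# Rank prompts by their mean top-k SAE activation, reusing model, tokenizer, sae,
# and layer_index from the training sketch above (prompts here are illustrative).
scores = {}
for prompt in ["Why do people vote in elections?", "What is 2 + 2?"]:
    enc = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        out = model(**enc, output_hidden_states=True)
    hidden = out.hidden_states[layer_index]          # (1, seq_len, hidden_dim)
    scores[prompt] = SparseAutoencoder.compute_topk_activations_mean(hidden, sae, top_k=20)

for prompt, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(f"{score:8.3f}  {prompt}")
```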
🧠 Feature Steering with DSPy
We implemented a SAESteeringModule that activates reasoning features during prompt generation (a sketch of one way such steering can be implemented internally appears at the end of this section):
```python
steerer = SAESteeringModule(model, tokenizer, sae, top_features=[3377, 2101])
result = steerer({"instruction": "Why do people vote in elections?"})
```
This enables us to:
- Trigger stronger reasoning activations
- Log and compare steered vs baseline generations
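The module's internals aren't shown in this post; the sketch below illustrates one common way to implement this kind of steering, by adding the decoder directions of the chosen features into a decoder layer's output via a forward hook. The strength constant and the Llama-style `model.model.layers` path (valid for TinyLlama) are assumptions about how SAESteeringModule might work, not its actual code.

```python
import torch

def make_steering_hook(sae, feature_ids, strength=4.0):
    """Add the SAE decoder directions of the selected features to a layer's output.

    Assumption: this mirrors what SAESteeringModule does internally; the real
    module may scale, clamp, or apply the edit differently.
    """
    # Column j of the decoder weight is the residual-stream direction of feature j.
    directions = sae.decoder.weight.detach()[:, feature_ids]   # (hidden_dim, n_features)
    steering_vector = strength * directions.sum(dim=1)         # (hidden_dim,)

    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = hidden + steering_vector.to(device=hidden.device, dtype=hidden.dtype)
        return (hidden,) + output[1:] if isinstance(output, tuple) else hidden

    return hook

# Register on the same decoder layer the SAE was trained on (Llama-style layout),
# generate, then remove the hook.
handle = model.model.layers[layer_index].register_forward_hook(
    make_steering_hook(sae, feature_ids=[3377, 2101])
)
inputs = tokenizer("Why do people vote in elections?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
handle.remove()
```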
🖼️ Gradio Explainability Dashboard
This visual interface, sketched minimally after the list below, lets analysts:
- See which reasoning features were activated
- View the top evidence tokens and context
- Explore token-level scores as a heatmap
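A minimal stand-in for the token-level view can be built with gr.HighlightedText; the scoring function below (summed SAE activations per token) is an assumption, not the dashboard's real backend, and it reuses `model`, `tokenizer`, `sae`, and `layer_index` from earlier sketches.

```python
import gradio as gr
import torch

def explain(prompt):
    """Return (token, score) pairs; scores are summed SAE activations per token
    (an assumed stand-in for the dashboard's real scoring backend)."""
    enc = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        out = model(**enc, output_hidden_states=True)
        _, z = sae(out.hidden_states[layer_index].squeeze(0))   # (seq_len, n_features)
    token_scores = z.abs().sum(dim=-1)
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    return [(tok, float(score)) for tok, score in zip(tokens, token_scores)]

demo = gr.Interface(
    fn=explain,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.HighlightedText(label="Per-token reasoning activation"),
    title="SAE Explainability Dashboard (sketch)",
)
demo.launch()
```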
📍 Bonus: We plan to integrate a “1-click optimizer” using DSPy’s Teleprompter.
🧰 What’s Next
- Integrating multi-prompt optimization loops
- Visualizing feature clusters with UMAP
- Allowing analysts to flag, re-score, and steer prompts manually
- Pushing the system to explain financial decisions (our long-term goal)
📎 Repo & Demos
I