Uncovering Reasoning in LLMs with Sparse Autoencoders

Summary

Large Language Models (LLMs) like DeepSeek-R1 show remarkable reasoning abilities, but how these abilities are represented internally has remained a mystery. This paper explores the mechanistic interpretability of reasoning in LLMs using Sparse Autoencoders (SAEs), a tool that decomposes LLM activations into human-interpretable features. In this post, we’ll:

  • Explain the SAE architecture used
  • Compute and visualize ReasonScore
  • Explore feature steering with sample completions
  • Provide live visualizations using Python + Gradio

1. Sparse Autoencoders: Theory + Code

SAE Architecture

Sparse Autoencoders aim to reconstruct LLM activations using a sparse hidden representation:
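Concretely, writing the encoder and decoder as single linear maps with an L1 penalty on the latent code (this mirrors the PyTorch implementation below, with σ the chosen activation and λ the sparsity_penalty weight):

$$z = \sigma(W_{\text{enc}} x + b_{\text{enc}}), \qquad \hat{x} = W_{\text{dec}} z + b_{\text{dec}}$$

$$\mathcal{L} = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert z \rVert_1$$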

We previously covered sparse autoencoders in this post: Detecting AI-Generated Text: Challenges and Solutions


import torch
import torch.nn
import torch.nn.functional
from pathlib import Path
import yaml
from typing import Dict, Any, Tuple, Union, Optional
from torch import Tensor


class SparseAutoencoder(torch.nn.Module):
    """
    Sparse Autoencoder implementation in PyTorch.
    
    This class implements an autoencoder with sparsity constraints on the latent representation.
    The model includes methods for saving/loading, computing sparsity loss, and summarizing its configuration.
    
    Args:
        input_dim (int): Dimensionality of the input data.
        hidden_dim (int): Dimensionality of the latent space.
        activation (str): Activation function to use in the encoder (default: "relu").
        sparsity_target (float): Target sparsity level for latent activations.
        sparsity_penalty (float): Weight for the sparsity regularization term.
        layer_index (int, optional): Index of the layer in a larger model (used for context).
    """

    def __init__(self, 
                 input_dim: int, 
                 hidden_dim: int, 
                 activation: str = "relu", 
                 sparsity_target: float = 0.03, 
                 sparsity_penalty: float = 0.0001,
                 layer_index: Optional[int] = None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.encoder = torch.nn.Linear(input_dim, hidden_dim)
        self.decoder = torch.nn.Linear(hidden_dim, input_dim)
        
        # Validate and set activation function
        if activation not in ["relu", "tanh", "sigmoid", "leaky_relu"]:
            raise ValueError(f"Unsupported activation function: {activation}")
        self.activation_fn = getattr(torch.nn.functional, activation)
        
        self.sparsity_target = sparsity_target
        self.sparsity_penalty = sparsity_penalty
        self.layer_index = layer_index

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Perform the forward pass through the sparse autoencoder.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
        
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: 
                - recon (torch.Tensor): Reconstructed input of shape (batch_size, input_dim).
                - z (torch.Tensor): Latent representation of shape (batch_size, hidden_dim).
        """
        z = self.activation_fn(self.encoder(x))
        recon = self.decoder(z)
        return recon, z

    def compute_sparsity_loss(self, z: Tensor) -> Tensor:
        """
        Compute the sparsity loss using L1 regularization.
        
        Args:
            z (torch.Tensor): Latent representation of shape (batch_size, hidden_dim).
        
        Returns:
            torch.Tensor: Sparsity loss value.
        """
        return self.sparsity_penalty * torch.mean(torch.abs(z))

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SparseAutoencoder":
        """
        Create a SparseAutoencoder instance from a configuration dictionary.
        
        Args:
            config (Dict[str, Any]): Dictionary containing model hyperparameters.
        
        Returns:
            SparseAutoencoder: Initialized model instance.
        """
        return cls(
            input_dim=config["input_dim"],
            hidden_dim=config["hidden_dim"],
            activation=config.get("activation", "relu"),
            sparsity_target=config.get("sparsity_target", 0.03),
            sparsity_penalty=config.get("sparsity_penalty", 0.0001),
            layer_index=config.get("layer_index", None)
        )

    def save(self, path: Union[str, Path]):
        """
        Save the model's state_dict and configuration to the specified directory.
        
        Args:
            path (Union[str, Path]): Directory path where the model will be saved.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), path / "pytorch_model.bin")
        config = {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "activation": self.activation_fn.__name__,
            "sparsity_target": self.sparsity_target,
            "sparsity_penalty": self.sparsity_penalty,
            "layer_index": self.layer_index
        }
        with open(path / "config.yaml", "w") as f:
            yaml.dump(config, f)

    @classmethod
    def load(cls, path: Union[str, Path], device: Optional[str] = None) -> "SparseAutoencoder":
        """
        Load a SparseAutoencoder instance from the specified directory.
        
        Args:
            path (Union[str, Path]): Directory path where the model is saved.
            device (str, optional): Device to load the model onto (e.g., "cpu" or "cuda").
        
        Returns:
            SparseAutoencoder: Loaded model instance.
        """
        path = Path(path)
        with open(path / "config.yaml", "r") as f:
            config = yaml.safe_load(f)
        model = cls.from_config(config)
        state_dict = torch.load(path / "pytorch_model.bin", map_location=device or "cpu")
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def compute_topk_activations_mean(hidden_states: Tensor, sae_model: "SparseAutoencoder", top_k: int = 20) -> float:
        """
        Compute the mean of the top-k activations in the latent space.
        
        Args:
            hidden_states (torch.Tensor): Hidden states from an intermediate layer of shape (1, seq_len, hidden_dim).
            sae_model (SparseAutoencoder): Trained SparseAutoencoder instance.
            top_k (int): Number of top activations to consider.
        
        Returns:
            float: Mean of the top-k activations.
        """
        _, z = sae_model(hidden_states.squeeze(0))  # z shape: (seq_len, hidden_dim)
        feature_scores = torch.abs(z).sum(dim=0)  # total |activation| per feature across all tokens
        topk_vals, _ = torch.topk(feature_scores, k=top_k)
        return topk_vals.mean().item()

    def summary(self, print_output: bool = True) -> Dict[str, Any]:
        """
        Generate a summary of the model's configuration and parameters.
        
        Args:
            print_output (bool): Whether to print the summary to the console.
        
        Returns:
            Dict[str, Any]: Summary information about the model.
        """
        info = {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "activation": self.activation_fn.__name__,
            "sparsity_target": self.sparsity_target,
            "sparsity_penalty": self.sparsity_penalty,
            "layer_index": self.layer_index,
            "total_params": sum(p.numel() for p in self.parameters()),
            "trainable_params": sum(p.numel() for p in self.parameters() if p.requires_grad)
        }

        if print_output:
            print("🔍 Sparse Autoencoder Summary:")
            for k, v in info.items():
                print(f"  {k}: {v}")

        return info
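
A quick usage sketch (the dimensions and checkpoint path are illustrative; 2048 is TinyLlama-1.1B's hidden size):

# Illustrative usage; input_dim/hidden_dim and the path are example values.
sae = SparseAutoencoder(input_dim=2048, hidden_dim=8192, layer_index=12)
sae.summary()
sae.save("checkpoints/sae_layer12")
restored = SparseAutoencoder.load("checkpoints/sae_layer12", device="cpu")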

2. Extracting Reasoning Features

We compute a ReasonScore based on how strongly a feature activates on reasoning-related tokens.

Python Code to Compute ReasonScore
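
The sketch below implements the keyword-based variant; the keyword list and the normalization are illustrative choices, not the paper's exact definition.

import torch
from torch import Tensor
from typing import List

# Illustrative set of "reasoning" tokens; the real list is a design choice.
REASONING_KEYWORDS = {"therefore", "because", "thus", "so", "wait", "check", "first", "then"}

def compute_reason_score(z: Tensor, tokens: List[str]) -> Tensor:
    """Simplified ReasonScore: mean activation of each SAE feature on
    reasoning-related tokens, normalized by its mean activation overall.

    Args:
        z: SAE latent activations of shape (seq_len, hidden_dim).
        tokens: decoded token strings, one per position.

    Returns:
        Tensor of shape (hidden_dim,) with one score per feature.
    """
    mask = torch.tensor([t.strip().lower() in REASONING_KEYWORDS for t in tokens])
    if mask.sum() == 0:
        return torch.zeros(z.shape[1])
    mean_on_reasoning = z[mask].mean(dim=0)
    mean_overall = z.mean(dim=0) + 1e-8  # avoid division by zero
    return mean_on_reasoning / mean_overall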



🔍 Building Explainable Reasoning in LLMs: A Hands-On Demo of Sparse Feature Steering



🧠 Overview

In this post, I share my experience implementing and extending the ideas from the paper:

“I Have Covered All the Bases Here: Interpreting Reasoning Features in Large Language Models via Sparse Autoencoders”

We not only reproduce the core ideas, identifying reasoning features via Sparse Autoencoders (SAEs), but also build a full end-to-end reasoning pipeline with visualization, prompt optimization, and feature steering.


🚀 What We Built

  • 🔬 Extracted hidden activations from TinyLlama using Hugging Face Transformers
  • 🧬 Trained Sparse Autoencoders to compress and expose interpretable features
  • 📊 Computed ReasonScore using top-k sparse activations
  • 🧠 Steered generation by activating reasoning features
  • 📈 Built a Gradio Explainability Dashboard to visualize features, tokens, and decisions
  • ⚡ Integrated DSPy for future prompt optimization

✅ Feature Checklist

Here’s a snapshot of the system:

| Feature | Status |
|---------|--------|
| Activation extraction & caching | ✅ |
| SAE config & training | ✅ |
| ReasonScore keyword + token modes | ✅ |
| Feature steering via DSPy module | ✅ |
| Dashboard with visualization | ✅ |
| Prompt optimizer loop | ⏳ In Progress |
| DSPy Teleprompter support | ⏳ Next |
| Postgres + pgvector logging | ✅ |
| RAG-based filtering & prompt validation | ✅ |
| Prompt-triggered explainability from agents | ✅ |


📦 Dataset

We use a filtered subset of the OpenOrca dataset (openorca-1k-short) to ensure consistent prompt length and GPU-friendliness:

prompts:
  dataset: ernanhughes/openorca-1k-short
  text_column: question
  max_seq_len: 64
  pad_to_max: true
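
For reference, a minimal loading sketch using Hugging Face datasets and transformers; the TinyLlama checkpoint name is an assumption, and the other values mirror the config above.

from datasets import load_dataset
from transformers import AutoTokenizer

# Dataset id and column come from the config above; the checkpoint is illustrative.
ds = load_dataset("ernanhughes/openorca-1k-short", split="train")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

batch = tokenizer(
    ds["question"],        # text_column
    max_length=64,         # max_seq_len
    padding="max_length",  # pad_to_max: true
    truncation=True,
    return_tensors="pt",
)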

🔧 Training the Sparse Autoencoder

Each LLM layer’s activations are reshaped from [B, T, H] → [B*T, H], so that each token’s hidden state becomes one training example for a small autoencoder (a minimal training sketch follows the list below). We log:

  • Loss curves
  • Layer index
  • Hidden dimension
  • Sparsity targets

✅ All configs are Hydra-powered and traceable via JSONL and PostgreSQL.
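
A minimal training-loop sketch under the setup above; activation_loader, the optimizer choice, and the hyperparameters are illustrative (2048 is TinyLlama-1.1B's hidden size).

import torch

sae = SparseAutoencoder(input_dim=2048, hidden_dim=8192, layer_index=12)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

for epoch in range(10):
    for acts in activation_loader:  # assumed DataLoader yielding (B*T, hidden_dim) activations
        recon, z = sae(acts)
        recon_loss = torch.nn.functional.mse_loss(recon, acts)
        loss = recon_loss + sae.compute_sparsity_loss(z)  # reconstruction + L1 sparsity
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()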


🧪 ReasonScore: Measuring Reasoning Activation

To rank which prompts engage reasoning:

compute_mean_topk_feature_score(hidden_states, sae, top_k=20)

We support:

  • ReasonScore from keyword token matches
  • ReasonScore from token ID sets
  • Optional top-k normalization

🧠 Feature Steering with DSPy

We implemented a SAESteeringModule that activates reasoning features during prompt generation (a sketch of the underlying mechanism follows below):

steerer = SAESteeringModule(model, tokenizer, sae, top_features=[3377, 2101])
result = steerer({"instruction": "Why do people vote in elections?"})

This enables us to:

  • Trigger stronger reasoning activations
  • Log and compare steered vs baseline generations
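
The SAESteeringModule's internals aren't shown here; one common way to implement this kind of steering is a forward hook that adds the SAE decoder directions of the selected features to the hidden states. A rough sketch, with the scale, hook placement, and layer index as assumptions:

import torch

def make_steering_hook(sae: SparseAutoencoder, feature_ids, scale: float = 4.0):
    """Build a forward hook that nudges hidden states along the SAE decoder
    directions of the chosen features (scale and placement are illustrative)."""
    # decoder.weight has shape (model_hidden_dim, sae_hidden_dim); columns are feature directions.
    directions = sae.decoder.weight.detach()[:, feature_ids]
    steering_vec = scale * directions.sum(dim=1)  # (model_hidden_dim,)

    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = hidden + steering_vec.to(hidden.device, hidden.dtype)
        if isinstance(output, tuple):
            return (hidden,) + output[1:]
        return hidden

    return hook

# Attach to the transformer layer the SAE was trained on (layer index is an assumption):
# handle = model.model.layers[12].register_forward_hook(make_steering_hook(sae, [3377, 2101]))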

🖼️ Gradio Explainability Dashboard

(Dashboard preview screenshot)

This visual interface lets analysts:

  • See which reasoning features were activated
  • View the top evidence tokens and context
  • Explore token-level scores as a heatmap
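
A minimal Gradio skeleton for such a dashboard; the explain function body is a placeholder, and the real app wires in the model, the SAE, and the ReasonScore computation.

import gradio as gr

def explain(prompt: str) -> str:
    # Placeholder: the real dashboard runs the model, the SAE, and the
    # ReasonScore computation, then returns feature- and token-level views.
    return f"Top reasoning features for: {prompt}"

demo = gr.Interface(
    fn=explain,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Activated reasoning features"),
    title="SAE Reasoning Explainability Dashboard",
)

if __name__ == "__main__":
    demo.launch()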

📍 Bonus: We plan to integrate a “1-click optimizer” using DSPy’s Teleprompter.


🧰 What’s Next

  • Integrating multi-prompt optimization loops
  • Visualizing feature clusters with UMAP
  • Allowing analysts to flag, re-score, and steer prompts manually
  • Pushing the system to explain financial decisions (our long-term goal)

📎 Repo & Demos

