From 5ebe376dd7fa6ac9ed89f8d0c3f21b9c19fc3325 Mon Sep 17 00:00:00 2001 From: Noam Diamant Date: Wed, 20 Aug 2025 17:51:07 +0300 Subject: [PATCH 1/4] moved BatchTopKSAE to dictionary.py --- dictionary_learning/__init__.py | 15 ++- dictionary_learning/dictionary.py | 94 +++++++++++++++++++ dictionary_learning/trainers/__init__.py | 3 +- dictionary_learning/trainers/batch_top_k.py | 83 +--------------- dictionary_learning/trainers/jumprelu.py | 3 +- .../trainers/matryoshka_batch_top_k.py | 3 +- dictionary_learning/trainers/top_k.py | 3 +- dictionary_learning/trainers/trainer.py | 20 ---- 8 files changed, 112 insertions(+), 112 deletions(-) diff --git a/dictionary_learning/__init__.py b/dictionary_learning/__init__.py index 2067aaa..a485b1a 100644 --- a/dictionary_learning/__init__.py +++ b/dictionary_learning/__init__.py @@ -1,6 +1,17 @@ __version__ = "0.1.0" -from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder +from .dictionary import ( + AutoEncoder, + GatedAutoEncoder, + JumpReluAutoEncoder, + BatchTopKSAE, +) from .buffer import ActivationBuffer -__all__ = ["AutoEncoder", "GatedAutoEncoder", "JumpReluAutoEncoder", "ActivationBuffer"] +__all__ = [ + "AutoEncoder", + "GatedAutoEncoder", + "JumpReluAutoEncoder", + "BatchTopKSAE", + "ActivationBuffer", +] diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index 238a866..d51e392 100644 --- a/dictionary_learning/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -8,6 +8,23 @@ import torch.nn.init as init import einops +@t.no_grad() +def set_decoder_norm_to_unit_norm( + W_dec_DF: t.nn.Parameter, activation_dim: int, d_sae: int +) -> t.Tensor: + """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. + nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). 
So, we pass the dimensions in + to catch this error.""" + + D, F = W_dec_DF.shape + + assert D == activation_dim + assert F == d_sae + + eps = t.finfo(W_dec_DF.dtype).eps + norm = t.norm(W_dec_DF.data, dim=0, keepdim=True) + W_dec_DF.data /= norm + eps + return W_dec_DF.data class Dictionary(ABC, nn.Module): """ @@ -379,6 +396,83 @@ def from_pretrained( device = autoencoder.W_enc.device return autoencoder.to(dtype=dtype, device=device) +class BatchTopKSAE(Dictionary, nn.Module): + def __init__(self, activation_dim: int, dict_size: int, k: int): + super().__init__() + self.activation_dim = activation_dim + self.dict_size = dict_size + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) + + self.decoder = nn.Linear(dict_size, activation_dim, bias=False) + self.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.decoder.weight, activation_dim, dict_size + ) + + self.encoder = nn.Linear(activation_dim, dict_size) + self.encoder.weight.data = self.decoder.weight.T.clone() + self.encoder.bias.data.zero_() + self.b_dec = nn.Parameter(t.zeros(activation_dim)) + + def encode( + self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True + ): + post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * ( + post_relu_feat_acts_BF > self.threshold + ) + else: + # Flatten and perform batch top-k + flattened_acts = post_relu_feat_acts_BF.flatten() + post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) + + encoded_acts_BF = ( + t.zeros_like(post_relu_feat_acts_BF.flatten()) + .scatter_(-1, post_topk.indices, post_topk.values) + .reshape(post_relu_feat_acts_BF.shape) + ) + + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF + else: + return encoded_acts_BF + + def decode(self, x: t.Tensor) -> t.Tensor: + return self.decoder(x) + self.b_dec + + def forward(self, x: t.Tensor, output_features: bool = False): + encoded_acts_BF = self.encode(x) + x_hat_BD = self.decode(encoded_acts_BF) + + if not output_features: + return x_hat_BD + else: + return x_hat_BD, encoded_acts_BF + + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + @classmethod + def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": + state_dict = t.load(path) + dict_size, activation_dim = state_dict["encoder.weight"].shape + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + autoencoder = cls(activation_dim, dict_size, k) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder # TODO merge this with AutoEncoder class AutoEncoderNew(Dictionary, nn.Module): diff --git a/dictionary_learning/trainers/__init__.py b/dictionary_learning/trainers/__init__.py index 4135a82..f115379 100644 --- a/dictionary_learning/trainers/__init__.py +++ b/dictionary_learning/trainers/__init__.py @@ -4,7 +4,7 @@ from .gated_anneal import GatedAnnealTrainer from .top_k import TopKTrainer from .jumprelu import JumpReluTrainer -from .batch_top_k import BatchTopKTrainer, BatchTopKSAE +from .batch_top_k import BatchTopKTrainer __all__ = [ @@ -15,5 +15,4 @@ 
"TopKTrainer", "JumpReluTrainer", "BatchTopKTrainer", - "BatchTopKSAE", ] diff --git a/dictionary_learning/trainers/batch_top_k.py b/dictionary_learning/trainers/batch_top_k.py index 8cb2ecf..aded58f 100644 --- a/dictionary_learning/trainers/batch_top_k.py +++ b/dictionary_learning/trainers/batch_top_k.py @@ -5,94 +5,13 @@ from collections import namedtuple from typing import Optional -from ..dictionary import Dictionary +from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm, BatchTopKSAE from ..trainers.trainer import ( SAETrainer, get_lr_schedule, - set_decoder_norm_to_unit_norm, remove_gradient_parallel_to_decoder_directions, ) - -class BatchTopKSAE(Dictionary, nn.Module): - def __init__(self, activation_dim: int, dict_size: int, k: int): - super().__init__() - self.activation_dim = activation_dim - self.dict_size = dict_size - - assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" - self.register_buffer("k", t.tensor(k, dtype=t.int)) - self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) - - self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.decoder.weight.data = set_decoder_norm_to_unit_norm( - self.decoder.weight, activation_dim, dict_size - ) - - self.encoder = nn.Linear(activation_dim, dict_size) - self.encoder.weight.data = self.decoder.weight.T.clone() - self.encoder.bias.data.zero_() - self.b_dec = nn.Parameter(t.zeros(activation_dim)) - - def encode( - self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True - ): - post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) - - if use_threshold: - encoded_acts_BF = post_relu_feat_acts_BF * ( - post_relu_feat_acts_BF > self.threshold - ) - else: - # Flatten and perform batch top-k - flattened_acts = post_relu_feat_acts_BF.flatten() - post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) - - encoded_acts_BF = ( - t.zeros_like(post_relu_feat_acts_BF.flatten()) - .scatter_(-1, post_topk.indices, post_topk.values) - .reshape(post_relu_feat_acts_BF.shape) - ) - - if return_active: - return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF - else: - return encoded_acts_BF - - def decode(self, x: t.Tensor) -> t.Tensor: - return self.decoder(x) + self.b_dec - - def forward(self, x: t.Tensor, output_features: bool = False): - encoded_acts_BF = self.encode(x) - x_hat_BD = self.decode(encoded_acts_BF) - - if not output_features: - return x_hat_BD - else: - return x_hat_BD, encoded_acts_BF - - def scale_biases(self, scale: float): - self.encoder.bias.data *= scale - self.b_dec.data *= scale - if self.threshold >= 0: - self.threshold *= scale - - @classmethod - def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": - state_dict = t.load(path) - dict_size, activation_dim = state_dict["encoder.weight"].shape - if k is None: - k = state_dict["k"].item() - elif "k" in state_dict and k != state_dict["k"].item(): - raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") - - autoencoder = cls(activation_dim, dict_size, k) - autoencoder.load_state_dict(state_dict) - if device is not None: - autoencoder.to(device) - return autoencoder - - class BatchTopKTrainer(SAETrainer): def __init__( self, diff --git a/dictionary_learning/trainers/jumprelu.py b/dictionary_learning/trainers/jumprelu.py index 2ac717a..af4efdb 100644 --- a/dictionary_learning/trainers/jumprelu.py +++ b/dictionary_learning/trainers/jumprelu.py @@ -5,12 +5,11 @@ from torch import nn from typing import Optional 
-from ..dictionary import Dictionary, JumpReluAutoEncoder +from ..dictionary import Dictionary, JumpReluAutoEncoder, set_decoder_norm_to_unit_norm from ..trainers.trainer import ( SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, - set_decoder_norm_to_unit_norm, remove_gradient_parallel_to_decoder_directions, ) diff --git a/dictionary_learning/trainers/matryoshka_batch_top_k.py b/dictionary_learning/trainers/matryoshka_batch_top_k.py index 03c195b..a20c35f 100644 --- a/dictionary_learning/trainers/matryoshka_batch_top_k.py +++ b/dictionary_learning/trainers/matryoshka_batch_top_k.py @@ -6,11 +6,10 @@ from typing import Optional from math import isclose -from ..dictionary import Dictionary +from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm from ..trainers.trainer import ( SAETrainer, get_lr_schedule, - set_decoder_norm_to_unit_norm, remove_gradient_parallel_to_decoder_directions, ) diff --git a/dictionary_learning/trainers/top_k.py b/dictionary_learning/trainers/top_k.py index e81259f..9ca39cd 100644 --- a/dictionary_learning/trainers/top_k.py +++ b/dictionary_learning/trainers/top_k.py @@ -10,11 +10,10 @@ from typing import Optional from ..config import DEBUG -from ..dictionary import Dictionary +from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm from ..trainers.trainer import ( SAETrainer, get_lr_schedule, - set_decoder_norm_to_unit_norm, remove_gradient_parallel_to_decoder_directions, ) diff --git a/dictionary_learning/trainers/trainer.py b/dictionary_learning/trainers/trainer.py index 15eb4ed..ffa47df 100644 --- a/dictionary_learning/trainers/trainer.py +++ b/dictionary_learning/trainers/trainer.py @@ -61,26 +61,6 @@ def step(self, closure=None): p /= p.norm(dim=0, keepdim=True) -# The next two functions could be replaced with the ConstrainedAdam Optimizer -@torch.no_grad() -def set_decoder_norm_to_unit_norm( - W_dec_DF: torch.nn.Parameter, activation_dim: int, d_sae: int -) -> torch.Tensor: - """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. - nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). 
So, we pass the dimensions in - to catch this error.""" - - D, F = W_dec_DF.shape - - assert D == activation_dim - assert F == d_sae - - eps = torch.finfo(W_dec_DF.dtype).eps - norm = torch.norm(W_dec_DF.data, dim=0, keepdim=True) - W_dec_DF.data /= norm + eps - return W_dec_DF.data - - @torch.no_grad() def remove_gradient_parallel_to_decoder_directions( W_dec_DF: torch.Tensor, From c95673a8d8f5c41c2f5956d699a30e58e725520b Mon Sep 17 00:00:00 2001 From: Noam Diamant Date: Wed, 20 Aug 2025 18:01:53 +0300 Subject: [PATCH 2/4] moved MatryoshkaBatchTopKSAE to dictionary.py --- dictionary_learning/__init__.py | 2 + dictionary_learning/dictionary.py | 102 +++++++++++++++++ dictionary_learning/trainers/__init__.py | 2 + .../trainers/matryoshka_batch_top_k.py | 106 +----------------- 4 files changed, 107 insertions(+), 105 deletions(-) diff --git a/dictionary_learning/__init__.py b/dictionary_learning/__init__.py index a485b1a..a5d143c 100644 --- a/dictionary_learning/__init__.py +++ b/dictionary_learning/__init__.py @@ -5,6 +5,7 @@ GatedAutoEncoder, JumpReluAutoEncoder, BatchTopKSAE, + MatryoshkaBatchTopKSAE, ) from .buffer import ActivationBuffer @@ -13,5 +14,6 @@ "GatedAutoEncoder", "JumpReluAutoEncoder", "BatchTopKSAE", + "MatryoshkaBatchTopKSAE", "ActivationBuffer", ] diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index d51e392..dbf3b4b 100644 --- a/dictionary_learning/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -474,6 +474,108 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": autoencoder.to(device) return autoencoder +class MatryoshkaBatchTopKSAE(Dictionary, nn.Module): + def __init__( + self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int] + ): + super().__init__() + self.activation_dim = activation_dim + self.dict_size = dict_size + + assert sum(group_sizes) == dict_size, "group sizes must sum to dict_size" + assert all(s > 0 for s in group_sizes), "all group sizes must be positive" + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) + + self.active_groups = len(group_sizes) + group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0)) + self.group_indices = group_indices + + self.register_buffer("group_sizes", t.tensor(group_sizes)) + + self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size)) + self.b_enc = nn.Parameter(t.zeros(dict_size)) + self.W_dec = nn.Parameter( + t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim)) + ) + self.b_dec = nn.Parameter(t.zeros(activation_dim)) + + # We must transpose because we are using nn.Parameter, not nn.Linear + self.W_dec.data = set_decoder_norm_to_unit_norm( + self.W_dec.data.T, activation_dim, dict_size + ).T + self.W_enc.data = self.W_dec.data.clone().T + + def encode( + self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True + ): + post_relu_feat_acts_BF = nn.functional.relu( + (x - self.b_dec) @ self.W_enc + self.b_enc + ) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * ( + post_relu_feat_acts_BF > self.threshold + ) + else: + # Flatten and perform batch top-k + flattened_acts = post_relu_feat_acts_BF.flatten() + post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) + + encoded_acts_BF = ( + t.zeros_like(post_relu_feat_acts_BF.flatten()) + .scatter_(-1, post_topk.indices, post_topk.values) + 
.reshape(post_relu_feat_acts_BF.shape) + ) + + max_act_index = self.group_indices[self.active_groups] + encoded_acts_BF[:, max_act_index:] = 0 + + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF + else: + return encoded_acts_BF + + def decode(self, x: t.Tensor) -> t.Tensor: + return x @ self.W_dec + self.b_dec + + def forward(self, x: t.Tensor, output_features: bool = False): + encoded_acts_BF = self.encode(x) + x_hat_BD = self.decode(encoded_acts_BF) + + if not output_features: + return x_hat_BD + else: + return x_hat_BD, encoded_acts_BF + + @t.no_grad() + def scale_biases(self, scale: float): + self.b_enc.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + @classmethod + def from_pretrained( + cls, path, k=None, device=None, **kwargs + ) -> "MatryoshkaBatchTopKSAE": + state_dict = t.load(path) + activation_dim, dict_size = state_dict["W_enc"].shape + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + group_sizes = state_dict["group_sizes"].tolist() + + autoencoder = cls(activation_dim, dict_size, k=k, group_sizes=group_sizes) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder + # TODO merge this with AutoEncoder class AutoEncoderNew(Dictionary, nn.Module): """ diff --git a/dictionary_learning/trainers/__init__.py b/dictionary_learning/trainers/__init__.py index f115379..cd3e14e 100644 --- a/dictionary_learning/trainers/__init__.py +++ b/dictionary_learning/trainers/__init__.py @@ -5,6 +5,7 @@ from .top_k import TopKTrainer from .jumprelu import JumpReluTrainer from .batch_top_k import BatchTopKTrainer +from .matryoshka_batch_top_k import MatryoshkaBatchTopKTrainer __all__ = [ @@ -15,4 +16,5 @@ "TopKTrainer", "JumpReluTrainer", "BatchTopKTrainer", + "MatryoshkaBatchTopKTrainer", ] diff --git a/dictionary_learning/trainers/matryoshka_batch_top_k.py b/dictionary_learning/trainers/matryoshka_batch_top_k.py index a20c35f..ef691ac 100644 --- a/dictionary_learning/trainers/matryoshka_batch_top_k.py +++ b/dictionary_learning/trainers/matryoshka_batch_top_k.py @@ -6,7 +6,7 @@ from typing import Optional from math import isclose -from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm +from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm, MatryoshkaBatchTopKSAE from ..trainers.trainer import ( SAETrainer, get_lr_schedule, @@ -32,110 +32,6 @@ def apply_temperature(probabilities: list[float], temperature: float) -> list[fl return scaled_probs.tolist() - -class MatryoshkaBatchTopKSAE(Dictionary, nn.Module): - def __init__( - self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int] - ): - super().__init__() - self.activation_dim = activation_dim - self.dict_size = dict_size - - assert sum(group_sizes) == dict_size, "group sizes must sum to dict_size" - assert all(s > 0 for s in group_sizes), "all group sizes must be positive" - - assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" - self.register_buffer("k", t.tensor(k, dtype=t.int)) - self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) - - self.active_groups = len(group_sizes) - group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0)) - self.group_indices = group_indices - - self.register_buffer("group_sizes", t.tensor(group_sizes)) - - self.W_enc = nn.Parameter(t.empty(activation_dim, 
dict_size)) - self.b_enc = nn.Parameter(t.zeros(dict_size)) - self.W_dec = nn.Parameter( - t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim)) - ) - self.b_dec = nn.Parameter(t.zeros(activation_dim)) - - # We must transpose because we are using nn.Parameter, not nn.Linear - self.W_dec.data = set_decoder_norm_to_unit_norm( - self.W_dec.data.T, activation_dim, dict_size - ).T - self.W_enc.data = self.W_dec.data.clone().T - - def encode( - self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True - ): - post_relu_feat_acts_BF = nn.functional.relu( - (x - self.b_dec) @ self.W_enc + self.b_enc - ) - - if use_threshold: - encoded_acts_BF = post_relu_feat_acts_BF * ( - post_relu_feat_acts_BF > self.threshold - ) - else: - # Flatten and perform batch top-k - flattened_acts = post_relu_feat_acts_BF.flatten() - post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) - - encoded_acts_BF = ( - t.zeros_like(post_relu_feat_acts_BF.flatten()) - .scatter_(-1, post_topk.indices, post_topk.values) - .reshape(post_relu_feat_acts_BF.shape) - ) - - max_act_index = self.group_indices[self.active_groups] - encoded_acts_BF[:, max_act_index:] = 0 - - if return_active: - return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF - else: - return encoded_acts_BF - - def decode(self, x: t.Tensor) -> t.Tensor: - return x @ self.W_dec + self.b_dec - - def forward(self, x: t.Tensor, output_features: bool = False): - encoded_acts_BF = self.encode(x) - x_hat_BD = self.decode(encoded_acts_BF) - - if not output_features: - return x_hat_BD - else: - return x_hat_BD, encoded_acts_BF - - @t.no_grad() - def scale_biases(self, scale: float): - self.b_enc.data *= scale - self.b_dec.data *= scale - if self.threshold >= 0: - self.threshold *= scale - - @classmethod - def from_pretrained( - cls, path, k=None, device=None, **kwargs - ) -> "MatryoshkaBatchTopKSAE": - state_dict = t.load(path) - activation_dim, dict_size = state_dict["W_enc"].shape - if k is None: - k = state_dict["k"].item() - elif "k" in state_dict and k != state_dict["k"].item(): - raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") - - group_sizes = state_dict["group_sizes"].tolist() - - autoencoder = cls(activation_dim, dict_size, k=k, group_sizes=group_sizes) - autoencoder.load_state_dict(state_dict) - if device is not None: - autoencoder.to(device) - return autoencoder - - class MatryoshkaBatchTopKTrainer(SAETrainer): def __init__( self, From 27fa480341f7ff6e9df853d9864fb92f043131a6 Mon Sep 17 00:00:00 2001 From: Noam Diamant Date: Wed, 20 Aug 2025 18:07:30 +0300 Subject: [PATCH 3/4] moved TopKSAE to dictionary.py --- dictionary_learning/__init__.py | 2 + dictionary_learning/dictionary.py | 103 ++++++++++++++++++++++ dictionary_learning/trainers/__init__.py | 2 + dictionary_learning/trainers/top_k.py | 107 +---------------------- 4 files changed, 108 insertions(+), 106 deletions(-) diff --git a/dictionary_learning/__init__.py b/dictionary_learning/__init__.py index a5d143c..b516515 100644 --- a/dictionary_learning/__init__.py +++ b/dictionary_learning/__init__.py @@ -6,6 +6,7 @@ JumpReluAutoEncoder, BatchTopKSAE, MatryoshkaBatchTopKSAE, + AutoEncoderTopK, ) from .buffer import ActivationBuffer @@ -15,5 +16,6 @@ "JumpReluAutoEncoder", "BatchTopKSAE", "MatryoshkaBatchTopKSAE", + "AutoEncoderTopK", "ActivationBuffer", ] diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index dbf3b4b..aa4ec13 100644 --- a/dictionary_learning/dictionary.py +++ 
b/dictionary_learning/dictionary.py @@ -396,6 +396,109 @@ def from_pretrained( device = autoencoder.W_enc.device return autoencoder.to(dtype=dtype, device=device) +class AutoEncoderTopK(Dictionary, nn.Module): + """ + The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093 + NOTE: (From Adam Karvonen) There is an unmaintained implementation using Triton kernels in the topk-triton-implementation branch. + We abandoned it as we didn't notice a significant speedup and it added complications, which are noted + in the AutoEncoderTopK class docstring in that branch. + + With some additional effort, you can train a Top-K SAE with the Triton kernels and modify the state dict for compatibility with this class. + Notably, the Triton kernels currently have the decoder to be stored in nn.Parameter, not nn.Linear, and the decoder weights must also + be stored in the same shape as the encoder. + """ + + def __init__(self, activation_dim: int, dict_size: int, k: int): + super().__init__() + self.activation_dim = activation_dim + self.dict_size = dict_size + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) + + self.decoder = nn.Linear(dict_size, activation_dim, bias=False) + self.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.decoder.weight, activation_dim, dict_size + ) + + self.encoder = nn.Linear(activation_dim, dict_size) + self.encoder.weight.data = self.decoder.weight.T.clone() + self.encoder.bias.data.zero_() + + self.b_dec = nn.Parameter(t.zeros(activation_dim)) + + def encode( + self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False + ): + post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * ( + post_relu_feat_acts_BF > self.threshold + ) + if return_topk: + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) + return ( + encoded_acts_BF, + post_topk.values, + post_topk.indices, + post_relu_feat_acts_BF, + ) + else: + return encoded_acts_BF + + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) + + # We can't split immediately due to nnsight + tops_acts_BK = post_topk.values + top_indices_BK = post_topk.indices + + buffer_BF = t.zeros_like(post_relu_feat_acts_BF) + encoded_acts_BF = buffer_BF.scatter_( + dim=-1, index=top_indices_BK, src=tops_acts_BK + ) + + if return_topk: + return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF + else: + return encoded_acts_BF + + def decode(self, x: t.Tensor) -> t.Tensor: + return self.decoder(x) + self.b_dec + + def forward(self, x: t.Tensor, output_features: bool = False): + encoded_acts_BF = self.encode(x) + x_hat_BD = self.decode(encoded_acts_BF) + if not output_features: + return x_hat_BD + else: + return x_hat_BD, encoded_acts_BF + + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + def from_pretrained(path, k: Optional[int] = None, device=None): + """ + Load a pretrained autoencoder from a file. 
+ """ + state_dict = t.load(path) + dict_size, activation_dim = state_dict["encoder.weight"].shape + + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder + class BatchTopKSAE(Dictionary, nn.Module): def __init__(self, activation_dim: int, dict_size: int, k: int): super().__init__() diff --git a/dictionary_learning/trainers/__init__.py b/dictionary_learning/trainers/__init__.py index cd3e14e..0702343 100644 --- a/dictionary_learning/trainers/__init__.py +++ b/dictionary_learning/trainers/__init__.py @@ -6,6 +6,7 @@ from .jumprelu import JumpReluTrainer from .batch_top_k import BatchTopKTrainer from .matryoshka_batch_top_k import MatryoshkaBatchTopKTrainer +from .top_k import TopKTrainer __all__ = [ @@ -17,4 +18,5 @@ "JumpReluTrainer", "BatchTopKTrainer", "MatryoshkaBatchTopKTrainer", + "TopKTrainer", ] diff --git a/dictionary_learning/trainers/top_k.py b/dictionary_learning/trainers/top_k.py index 9ca39cd..b2ef990 100644 --- a/dictionary_learning/trainers/top_k.py +++ b/dictionary_learning/trainers/top_k.py @@ -10,7 +10,7 @@ from typing import Optional from ..config import DEBUG -from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm +from ..dictionary import Dictionary, set_decoder_norm_to_unit_norm, AutoEncoderTopK from ..trainers.trainer import ( SAETrainer, get_lr_schedule, @@ -46,111 +46,6 @@ def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): return guess - -class AutoEncoderTopK(Dictionary, nn.Module): - """ - The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093 - NOTE: (From Adam Karvonen) There is an unmaintained implementation using Triton kernels in the topk-triton-implementation branch. - We abandoned it as we didn't notice a significant speedup and it added complications, which are noted - in the AutoEncoderTopK class docstring in that branch. - - With some additional effort, you can train a Top-K SAE with the Triton kernels and modify the state dict for compatibility with this class. - Notably, the Triton kernels currently have the decoder to be stored in nn.Parameter, not nn.Linear, and the decoder weights must also - be stored in the same shape as the encoder. 
- """ - - def __init__(self, activation_dim: int, dict_size: int, k: int): - super().__init__() - self.activation_dim = activation_dim - self.dict_size = dict_size - - assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" - self.register_buffer("k", t.tensor(k, dtype=t.int)) - self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) - - self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.decoder.weight.data = set_decoder_norm_to_unit_norm( - self.decoder.weight, activation_dim, dict_size - ) - - self.encoder = nn.Linear(activation_dim, dict_size) - self.encoder.weight.data = self.decoder.weight.T.clone() - self.encoder.bias.data.zero_() - - self.b_dec = nn.Parameter(t.zeros(activation_dim)) - - def encode( - self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False - ): - post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) - - if use_threshold: - encoded_acts_BF = post_relu_feat_acts_BF * ( - post_relu_feat_acts_BF > self.threshold - ) - if return_topk: - post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) - return ( - encoded_acts_BF, - post_topk.values, - post_topk.indices, - post_relu_feat_acts_BF, - ) - else: - return encoded_acts_BF - - post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) - - # We can't split immediately due to nnsight - tops_acts_BK = post_topk.values - top_indices_BK = post_topk.indices - - buffer_BF = t.zeros_like(post_relu_feat_acts_BF) - encoded_acts_BF = buffer_BF.scatter_( - dim=-1, index=top_indices_BK, src=tops_acts_BK - ) - - if return_topk: - return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF - else: - return encoded_acts_BF - - def decode(self, x: t.Tensor) -> t.Tensor: - return self.decoder(x) + self.b_dec - - def forward(self, x: t.Tensor, output_features: bool = False): - encoded_acts_BF = self.encode(x) - x_hat_BD = self.decode(encoded_acts_BF) - if not output_features: - return x_hat_BD - else: - return x_hat_BD, encoded_acts_BF - - def scale_biases(self, scale: float): - self.encoder.bias.data *= scale - self.b_dec.data *= scale - if self.threshold >= 0: - self.threshold *= scale - - def from_pretrained(path, k: Optional[int] = None, device=None): - """ - Load a pretrained autoencoder from a file. - """ - state_dict = t.load(path) - dict_size, activation_dim = state_dict["encoder.weight"].shape - - if k is None: - k = state_dict["k"].item() - elif "k" in state_dict and k != state_dict["k"].item(): - raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") - - autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) - autoencoder.load_state_dict(state_dict) - if device is not None: - autoencoder.to(device) - return autoencoder - - class TopKTrainer(SAETrainer): """ Top-K SAE training scheme. From 5546fe30a61e4b89a3975cbeb04ac150999d209b Mon Sep 17 00:00:00 2001 From: Noam Diamant Date: Wed, 20 Aug 2025 18:10:30 +0300 Subject: [PATCH 4/4] fixed import mistake --- dictionary_learning/dictionary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index aa4ec13..ac93c08 100644 --- a/dictionary_learning/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -7,6 +7,7 @@ import torch.nn as nn import torch.nn.init as init import einops +from typing import Optional @t.no_grad() def set_decoder_norm_to_unit_norm(