diff --git a/buffer.py b/buffer.py index d997596..6cbf8e3 100644 --- a/buffer.py +++ b/buffer.py @@ -40,7 +40,7 @@ def __init__(self, d_submodule = submodule.out_features except: raise ValueError("d_submodule cannot be inferred and must be specified directly") - self.activations = t.empty(0, d_submodule, device=device) + self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype) self.read = t.zeros(0).bool() self.data = data @@ -105,7 +105,7 @@ def refresh(self): self.activations = self.activations[~self.read] current_idx = len(self.activations) - new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device) + new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device, dtype=self.model.dtype) new_activations[: len(self.activations)] = self.activations self.activations = new_activations diff --git a/dictionary.py b/dictionary.py index 7950cf7..09e80ab 100644 --- a/dictionary.py +++ b/dictionary.py @@ -47,12 +47,16 @@ def __init__(self, activation_dim, dict_size): self.dict_size = dict_size self.bias = nn.Parameter(t.zeros(activation_dim)) self.encoder = nn.Linear(activation_dim, dict_size, bias=True) - - # rows of decoder weight matrix are unit vectors self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - dec_weight = t.randn_like(self.decoder.weight) - dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True) - self.decoder.weight = nn.Parameter(dec_weight) + + # initialize encoder and decoder weights + w = t.randn(activation_dim, dict_size) + ## normalize columns of w + w = w / w.norm(dim=0, keepdim=True) * 0.1 + ## set encoder and decoder weights + self.encoder.weight = nn.Parameter(w.clone().T) + self.decoder.weight = nn.Parameter(w.clone()) + def encode(self, x): return nn.ReLU()(self.encoder(x - self.bias)) @@ -86,6 +90,10 @@ def forward(self, x, output_features=False, ghost_mask=None): return x_hat, x_ghost, f else: return x_hat, x_ghost + + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.bias.data *= scale @classmethod def from_pretrained(cls, path, dtype=t.float, device=None): @@ -204,6 +212,11 @@ def forward(self, x, output_features=False): else: return x_hat + def scale_biases(self, scale: float): + self.decoder_bias.data *= scale + self.mag_bias.data *= scale + self.gate_bias.data *= scale + def from_pretrained(path, device=None): """ Load a pretrained autoencoder from a file. 
@@ -215,6 +228,7 @@ def from_pretrained(path, device=None): if device is not None: autoencoder.to(device) return autoencoder + class JumpReluAutoEncoder(Dictionary, nn.Module): """ @@ -267,6 +281,11 @@ def forward(self, x, output_features=False): return x_hat, f else: return x_hat + + def scale_biases(self, scale: float): + self.b_dec.data *= scale + self.b_enc.data *= scale + self.threshold.data *= scale @classmethod def from_pretrained( @@ -284,9 +303,10 @@ def from_pretrained( """ if not load_from_sae_lens: state_dict = t.load(path) - dict_size, activation_dim = state_dict['W_enc'].shape + activation_dim, dict_size = state_dict['W_enc'].shape autoencoder = JumpReluAutoEncoder(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) + autoencoder = autoencoder.to(dtype=dtype, device=device) else: from sae_lens import SAE sae, cfg_dict, _ = SAE.from_pretrained(**kwargs) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 8e93cab..b2374ec 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -6,8 +6,8 @@ from dictionary_learning.training import trainSAE from dictionary_learning.trainers.standard import StandardTrainer -from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK -from dictionary_learning.utils import hf_dataset_to_generator +from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK +from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary from dictionary_learning.buffer import ActivationBuffer from dictionary_learning.dictionary import ( AutoEncoder, @@ -58,50 +58,11 @@ EVAL_TOLERANCE = 0.01 -def get_nested_folders(path: str) -> list[str]: - """ - Recursively get a list of folders that contain an ae.pt file, starting the search from the given path - """ - folder_names = [] - - for root, dirs, files in os.walk(path): - if "ae.pt" in files: - folder_names.append(root) - - return folder_names - - -def load_dictionary(base_path: str, device: str) -> tuple: - ae_path = f"{base_path}/ae.pt" - config_path = f"{base_path}/config.json" - - with open(config_path, "r") as f: - config = json.load(f) - - # TODO: Save the submodule name in the config? - # submodule_str = config["trainer"]["submodule_name"] - dict_class = config["trainer"]["dict_class"] - - if dict_class == "AutoEncoder": - dictionary = AutoEncoder.from_pretrained(ae_path, device=device) - elif dict_class == "GatedAutoEncoder": - dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) - elif dict_class == "AutoEncoderNew": - dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) - elif dict_class == "AutoEncoderTopK": - k = config["trainer"]["k"] - dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) - elif dict_class == "JumpReluAutoEncoder": - dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) - else: - raise ValueError(f"Dictionary class {dict_class} not supported") - - return dictionary, config - - def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. - This isn't a nice suite of unit tests, but it's better than nothing.""" + This isn't a nice suite of unit tests, but it's better than nothing. + I have observed that results can slightly vary with library versions. 
For full determinism, + use pytorch 2.2.0 and nnsight 0.3.3.""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) @@ -158,7 +119,7 @@ def test_sae_training(): trainer_configs.extend( [ { - "trainer": TrainerTopK, + "trainer": TopKTrainer, "dict_class": AutoEncoderTopK, "activation_dim": activation_dim, "dict_size": expansion_factor * activation_dim, @@ -278,5 +239,12 @@ def test_evaluation(): dict_class = config["trainer"]["dict_class"] expected_results = EXPECTED_RESULTS[dict_class] + max_diff = 0 + max_diff_percent = 0 for key, value in expected_results.items(): - assert abs(eval_results[key] - value) < EVAL_TOLERANCE + diff = abs(eval_results[key] - value) + max_diff = max(max_diff, diff) + max_diff_percent = max(max_diff_percent, diff / value) + + print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") + assert max_diff < EVAL_TOLERANCE diff --git a/trainers/__init__.py b/trainers/__init__.py index 461af62..81998af 100644 --- a/trainers/__init__.py +++ b/trainers/__init__.py @@ -2,6 +2,6 @@ from .gdm import GatedSAETrainer from .p_anneal import PAnnealTrainer from .gated_anneal import GatedAnnealTrainer -from .top_k import TrainerTopK -from .jumprelu import TrainerJumpRelu -from .batch_top_k import TrainerBatchTopK, BatchTopKSAE +from .top_k import TopKTrainer +from .jumprelu import JumpReluTrainer +from .batch_top_k import BatchTopKTrainer, BatchTopKSAE diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index c65195f..a7fbdc8 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -16,6 +16,7 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" self.register_buffer("k", t.tensor(k)) + self.register_buffer("threshold", t.tensor(-1.0)) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.bias.data.zero_() @@ -24,9 +25,16 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.set_decoder_norm_to_unit_norm() self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, return_active: bool = False): + def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 + else: + return encoded_acts_BF + # Flatten and perform batch top-k flattened_acts = post_relu_feat_acts_BF.flatten() post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) @@ -75,13 +83,19 @@ def remove_gradient_parallel_to_decoder_directions(self): "d_sae, d_in d_sae -> d_in d_sae", ) + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + @classmethod def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape if k is None: - k = state_dict['k'].item() - elif 'k' in state_dict and k != state_dict['k'].item(): + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") autoencoder = cls(activation_dim, dict_size, k) @@ -91,7 +105,7 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> 
"BatchTopKSAE": return autoencoder -class TrainerBatchTopK(SAETrainer): +class BatchTopKTrainer(SAETrainer): def __init__( self, dict_class=BatchTopKSAE, @@ -100,6 +114,8 @@ def __init__( k=8, auxk_alpha=1 / 32, decay_start=24000, + threshold_beta=0.999, + threshold_start_step=1000, steps=30000, top_k_aux=512, seed=None, @@ -117,6 +133,8 @@ def __init__( self.wandb_name = wandb_name self.steps = steps self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step if seed is not None: t.manual_seed(seed) @@ -136,9 +154,7 @@ def __init__( self.dead_feature_threshold = 10_000_000 self.top_k_aux = top_k_aux - self.optimizer = t.optim.Adam( - self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) - ) + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) def lr_fn(step): if step < decay_start: @@ -165,20 +181,34 @@ def get_auxiliary_loss(self, x, x_reconstruct, acts): acts_aux = t.zeros_like(acts[:, dead_features]).scatter( -1, acts_topk_aux.indices, acts_topk_aux.values ) - x_reconstruct_aux = F.linear( - acts_aux, self.ae.decoder.weight[:, dead_features] - ) + x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features]) l2_loss_aux = ( - self.auxk_alpha - * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() + self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() ) return l2_loss_aux else: return t.tensor(0, dtype=x.dtype, device=x.device) def loss(self, x, step=None, logging=False): - f, active_indices = self.ae.encode(x, return_active=True) - l0 = (f != 0).float().sum(dim=-1).mean().item() + f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False) + # l0 = (f != 0).float().sum(dim=-1).mean().item() + + if step > self.threshold_start_step: + with t.no_grad(): + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach() + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + x_hat = self.ae.decode(f) e = x_hat - x @@ -230,14 +260,14 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "TrainerBatchTopK", + "trainer_class": "BatchTopKTrainer", "dict_class": "BatchTopKSAE", "lr": self.lr, "steps": self.steps, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, - "k": self.ae.k, + "k": self.ae.k.item(), "device": self.device, "layer": self.layer, "lm_name": self.lm_name, diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index f87785a..e27e1b3 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -60,7 +60,7 @@ def backward(ctx, grad_output): return x_grad, threshold_grad, None # None for bandwidth -class TrainerJumpRelu(nn.Module, SAETrainer): +class JumpReluTrainer(nn.Module, SAETrainer): """ Trains a JumpReLU autoencoder. @@ -77,7 +77,8 @@ def __init__( # TODO: What's the default lr use in the paper? lr=7e-5, bandwidth=0.001, - sparsity_penalty=0.1, + sparsity_penalty=1.0, + target_l0=20.0, device="cpu", layer=None, lm_name=None, @@ -99,6 +100,7 @@ def __init__( self.bandwidth = bandwidth self.sparsity_coefficient = sparsity_penalty + self.target_l0 = target_l0 # TODO: Better auto-naming (e.g. 
in BatchTopK package) self.wandb_name = wandb_name @@ -123,7 +125,8 @@ def loss(self, x, logging=False, **_): recon_loss = (x - recon).pow(2).sum(dim=-1).mean() l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() - sparsity_loss = self.sparsity_coefficient * l0 + + sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) loss = recon_loss + sparsity_loss if not logging: @@ -153,7 +156,7 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "TrainerJumpRelu", + "trainer_class": "JumpReluTrainer", "dict_class": "JumpReluAutoEncoder", "lr": self.lr, "steps": self.steps, @@ -165,4 +168,7 @@ def config(self): "lm_name": self.lm_name, "wandb_name": self.wandb_name, "submodule_name": self.submodule_name, + "bandwidth": self.bandwidth, + "sparsity_penalty": self.sparsity_coefficient, + "target_l0": self.target_l0, } diff --git a/trainers/p_anneal.py b/trainers/p_anneal.py index 4a157b9..0138547 100644 --- a/trainers/p_anneal.py +++ b/trainers/p_anneal.py @@ -166,7 +166,7 @@ def lp_norm(self, f, p): def loss(self, x, step, logging=False): # Compute loss terms x_hat, f = self.ae(x, output_features=True) - l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() lp_loss = self.lp_norm(f, self.p) scaled_lp_loss = lp_loss * self.sparsity_coeff self.lp_loss = lp_loss @@ -201,7 +201,7 @@ def loss(self, x, step, logging=False): self.steps_since_active[~deads] = 0 if logging is False: - return l2_loss + scaled_lp_loss + return recon_loss + scaled_lp_loss else: loss_log = { 'p' : self.p, diff --git a/trainers/standard.py b/trainers/standard.py index 2cfbb6a..07b9b67 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -127,6 +127,7 @@ def resample_neurons(self, deads, activations): def loss(self, x, logging=False, **kwargs): x_hat, f = self.ae(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() l1_loss = f.norm(p=1, dim=-1).mean() if self.steps_since_active is not None: @@ -135,7 +136,7 @@ def loss(self, x, logging=False, **kwargs): self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - loss = l2_loss + self.l1_penalty * l1_loss + loss = recon_loss + self.l1_penalty * l1_loss if not logging: return loss @@ -144,7 +145,7 @@ def loss(self, x, logging=False, **kwargs): x, x_hat, f, { 'l2_loss' : l2_loss.item(), - 'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(), + 'mse_loss' : recon_loss.item(), 'sparsity_loss' : l1_loss.item(), 'loss' : loss.item() } diff --git a/trainers/top_k.py b/trainers/top_k.py index 33046f5..12c549a 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -7,6 +7,7 @@ import torch as t import torch.nn as nn from collections import namedtuple +from typing import Optional from ..config import DEBUG from ..dictionary import Dictionary @@ -58,7 +59,10 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size - self.k = k + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k)) + self.register_buffer("threshold", t.tensor(-1.0)) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.bias.data.zero_() @@ -69,8 +73,17 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, 
return_topk: bool = False): + def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + if return_topk: + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) + return encoded_acts_BF, post_topk.values, post_topk.indices + else: + return encoded_acts_BF + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) # We can't split immediately due to nnsight @@ -117,12 +130,24 @@ def remove_gradient_parallel_to_decoder_directions(self): "d_sae, d_in d_sae -> d_in d_sae", ) - def from_pretrained(path, k: int, device=None): + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + def from_pretrained(path, k: Optional[int] = None, device=None): """ Load a pretrained autoencoder from a file. """ state_dict = t.load(path) dict_size, activation_dim = state_dict["encoder.weight"].shape + + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) autoencoder.load_state_dict(state_dict) if device is not None: @@ -130,7 +155,7 @@ def from_pretrained(path, k: int, device=None): return autoencoder -class TrainerTopK(SAETrainer): +class TopKTrainer(SAETrainer): """ Top-K SAE training scheme. """ @@ -143,6 +168,8 @@ def __init__( k=100, auxk_alpha=1 / 32, # see Appendix A.2 decay_start=24000, # when does the lr decay start + threshold_beta=0.999, + threshold_start_step=1000, steps=30000, # when does training end seed=None, device=None, @@ -161,6 +188,9 @@ def __init__( self.wandb_name = wandb_name self.steps = steps self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step + if seed is not None: t.manual_seed(seed) t.cuda.manual_seed_all(seed) @@ -201,7 +231,26 @@ def lr_fn(step): def loss(self, x, step=None, logging=False): # Run the SAE - f, top_acts, top_indices = self.ae.encode(x, return_topk=True) + f, top_acts, top_indices = self.ae.encode(x, return_topk=True, use_threshold=False) + + if step > self.threshold_start_step: + with t.no_grad(): + active = top_acts.clone().detach() + active[active <= 0] = float("inf") + min_activations = active.min(dim=1).values + min_activation = min_activations.mean() + + B, K = active.shape + assert len(active.shape) == 2 + assert min_activations.shape == (B,) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + x_hat = self.ae.decode(f) # Measure goodness of reconstruction @@ -270,6 +319,7 @@ def update(self, step, x): # Initialise the decoder bias if step == 0: median = geometric_median(x) + median = median.to(self.ae.b_dec.dtype) self.ae.b_dec.data = median # Make sure the decoder is still unit-norm @@ -293,14 +343,14 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "TrainerTopK", + "trainer_class": "TopKTrainer", "dict_class": "AutoEncoderTopK", "lr": self.lr, "steps": self.steps, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, - "k": self.ae.k, + "k": self.ae.k.item(), "device": 
self.device, "layer": self.layer, "lm_name": self.lm_name, diff --git a/training.py b/training.py index 13fd4b3..d4c6a38 100644 --- a/training.py +++ b/training.py @@ -73,6 +73,34 @@ def log_stats( if log_queues: log_queues[i].put(log) +def get_norm_factor(data, steps: int) -> float: + """Per Section 3.1, find a fixed scalar factor so activation vectors have unit mean squared norm. + This is very helpful for hyperparameter transfer between different layers and models. + Use more steps for more accurate results. + https://arxiv.org/pdf/2408.05147 + + If you experience trouble with hyperparameter transfer between models, it may be worth normalizing to the square root of d_model instead. + https://transformer-circuits.pub/2024/april-update/index.html#training-saes""" + total_mean_squared_norm = 0 + count = 0 + + for step, act_BD in enumerate(tqdm(data, total=steps, desc="Calculating norm factor")): + if step > steps: + break + + count += 1 + mean_squared_norm = t.mean(t.sum(act_BD ** 2, dim=1)) + total_mean_squared_norm += mean_squared_norm + + average_mean_squared_norm = total_mean_squared_norm / count + norm_factor = t.sqrt(average_mean_squared_norm).item() + + print(f"Average mean squared norm: {average_mean_squared_norm}") + print(f"Norm factor: {norm_factor}") + + return norm_factor + + def trainSAE( data, @@ -87,10 +115,16 @@ def trainSAE( activations_split_by_head:bool=False, transcoder:bool=False, run_cfg:dict={}, + normalize_activations:bool=False, ): """ Train SAEs using the given trainers + + If normalize_activations is True, the activations will be normalized to have unit mean squared norm. + The autoencoder weights will be scaled before saving, so the activations don't need to be scaled during inference. + This is very helpful for hyperparameter transfer between different layers and models. 
""" + trainers = [] for config in trainer_configs: trainer_class = config["trainer"] @@ -130,7 +164,21 @@ def trainSAE( else: save_dirs = [None for _ in trainer_configs] + if normalize_activations: + norm_factor = get_norm_factor(data, steps=100) + + for trainer in trainers: + trainer.config["norm_factor"] = norm_factor + # Verify that all autoencoders have a scale_biases method + trainer.ae.scale_biases(1.0) + for step, act in enumerate(tqdm(data, total=steps)): + + act = act.to(dtype=t.float32) + + if normalize_activations: + act /= norm_factor + if steps is not None and step >= steps: break @@ -144,6 +192,11 @@ def trainSAE( if save_steps is not None and step in save_steps: for dir, trainer in zip(save_dirs, trainers): if dir is not None: + + if normalize_activations: + # Temporarily scale up biases for checkpoint saving + trainer.ae.scale_biases(norm_factor) + if not os.path.exists(os.path.join(dir, "checkpoints")): os.mkdir(os.path.join(dir, "checkpoints")) t.save( @@ -151,12 +204,17 @@ def trainSAE( os.path.join(dir, "checkpoints", f"ae_{step}.pt"), ) + if normalize_activations: + trainer.ae.scale_biases(1 / norm_factor) + # training for trainer in trainers: trainer.update(step, act) # save final SAEs for save_dir, trainer in zip(save_dirs, trainers): + if normalize_activations: + trainer.ae.scale_biases(norm_factor) if save_dir is not None: t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt")) diff --git a/utils.py b/utils.py index 8641f05..4f34a4e 100644 --- a/utils.py +++ b/utils.py @@ -2,26 +2,94 @@ import zstandard as zstd import io import json +import os +from nnsight import LanguageModel -def hf_dataset_to_generator(dataset_name, split='train', streaming=True): +from dictionary_learning.trainers.top_k import AutoEncoderTopK +from dictionary_learning.trainers.batch_top_k import BatchTopKSAE +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) + + +def hf_dataset_to_generator(dataset_name, split="train", streaming=True): dataset = load_dataset(dataset_name, split=split, streaming=streaming) - + def gen(): for x in iter(dataset): - yield x['text'] - + yield x["text"] + return gen() + def zst_to_generator(data_path): """ Load a dataset from a .jsonl.zst file. 
The jsonl entries are assumed to have a 'text' field """ - compressed_file = open(data_path, 'rb') + compressed_file = open(data_path, "rb") dctx = zstd.ZstdDecompressor() reader = dctx.stream_reader(compressed_file) - text_stream = io.TextIOWrapper(reader, encoding='utf-8') + text_stream = io.TextIOWrapper(reader, encoding="utf-8") + def generator(): for line in text_stream: - yield json.loads(line)['text'] - return generator() \ No newline at end of file + yield json.loads(line)["text"] + + return generator() + + +def get_nested_folders(path: str) -> list[str]: + """ + Recursively get a list of folders that contain an ae.pt file, starting the search from the given path + """ + folder_names = [] + + for root, dirs, files in os.walk(path): + if "ae.pt" in files: + folder_names.append(root) + + return folder_names + + +def load_dictionary(base_path: str, device: str) -> tuple: + ae_path = f"{base_path}/ae.pt" + config_path = f"{base_path}/config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + dict_class = config["trainer"]["dict_class"] + + if dict_class == "AutoEncoder": + dictionary = AutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "GatedAutoEncoder": + dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderNew": + dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderTopK": + k = config["trainer"]["k"] + dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "BatchTopKSAE": + dictionary = BatchTopKSAE.from_pretrained(ae_path, device=device) + elif dict_class == "JumpReluAutoEncoder": + dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) + else: + raise ValueError(f"Dictionary class {dict_class} not supported") + + return dictionary, config + + +def get_submodule(model: LanguageModel, layer: int): + """Gets the residual stream submodule""" + model_name = model._model_key + + if "pythia" in model_name: + return model.gpt_neox.layers[layer] + elif "gemma" in model_name: + return model.model.layers[layer] + else: + raise ValueError(f"Please add submodule for model {model_name}")
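Usage note: with normalize_activations=True, trainSAE trains each dictionary on activations divided by norm_factor and calls scale_biases(norm_factor) before every save, so the saved weights run directly on unscaled activations. The sketch below is illustrative only and not part of the patch; the dimensions, k, and scale value are arbitrary. It checks that invariant for AutoEncoderTopK: after scale_biases(s) with a positive s, encode(x) equals s times the encode(x / s) the trainer saw under the original biases.

import torch as t

from dictionary_learning.trainers.top_k import AutoEncoderTopK

# Arbitrary sizes and scale, chosen only for the demonstration.
s = 7.3
ae = AutoEncoderTopK(activation_dim=16, dict_size=64, k=4)
ae.encoder.bias.data.normal_()  # nonzero biases so the check is not vacuous
ae.b_dec.data.normal_()

x = t.randn(8, 16)
f_scaled = ae.encode(x / s)  # what the trainer sees when activations are normalized
ae.scale_biases(s)           # what trainSAE does before saving a checkpoint
f_raw = ae.encode(x)         # inference on raw, unscaled activations

assert t.allclose(f_raw, s * f_scaled, atol=1e-5)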
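Similarly, the threshold buffer added to AutoEncoderTopK and BatchTopKSAE is maintained in the trainers' loss() as an exponential moving average tracking the smallest positive activations, and encode(..., use_threshold=True) replaces the top-k selection with that cutoff at inference time. A minimal sketch, not part of the patch, with a hand-set threshold standing in for the EMA a trained model would carry:

import torch as t

from dictionary_learning.trainers.batch_top_k import BatchTopKSAE

ae = BatchTopKSAE(activation_dim=16, dict_size=64, k=4)
x = t.randn(8, 16)

# BatchTopKTrainer normally sets this during training; the value here is illustrative.
ae.threshold = t.tensor(0.05)

f_topk = ae.encode(x, use_threshold=False)   # exact batch top-k, as used in training
f_thresh = ae.encode(x, use_threshold=True)  # thresholded path, as used at inference
print(f_topk.count_nonzero().item(), f_thresh.count_nonzero().item())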