From fe54b001cba976ca96d46add8539580268dc5cb6 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 17 Dec 2024 23:33:01 +0000 Subject: [PATCH 1/8] Add a simple end to end test --- requirements.txt | 3 +- tests/test_end_to_end.py | 282 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 tests/test_end_to_end.py diff --git a/requirements.txt b/requirements.txt index 7366e63..5b9f3c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ torch>=2.1.2 tqdm>=4.66.1 umap-learn>=0.5.6 zstandard>=0.22.0 -wandb +wandb>=0.12.0 +pytest>=6.2.4 \ No newline at end of file diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py new file mode 100644 index 0000000..8cb4b95 --- /dev/null +++ b/tests/test_end_to_end.py @@ -0,0 +1,282 @@ +import torch as t +from nnsight import LanguageModel +import os +import json +import random + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK +from dictionary_learning.utils import hf_dataset_to_generator +from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) +from dictionary_learning.evaluation import evaluate + +EXPECTED_RESULTS = { + "AutoEncoderTopK": { + "l2_loss": 4.470372676849365, + "l1_loss": 44.47749710083008, + "l0": 40.0, + "frac_alive": 0.000244140625, + "frac_variance_explained": 0.9372208118438721, + "cossim": 0.9471381902694702, + "l2_ratio": 0.9523985981941223, + "relative_reconstruction_bias": 0.9996458888053894, + "loss_original": 3.186223268508911, + "loss_reconstructed": 3.690929412841797, + "loss_zero": 12.936649322509766, + "frac_recovered": 0.9482374787330627, + }, + "AutoEncoder": { + "l2_loss": 6.72230863571167, + "l1_loss": 28.893749237060547, + "l0": 61.12999725341797, + "frac_alive": 0.000244140625, + "frac_variance_explained": 0.6076533794403076, + "cossim": 0.869738757610321, + "l2_ratio": 0.8005934953689575, + "relative_reconstruction_bias": 0.9304398894309998, + "loss_original": 3.186223268508911, + "loss_reconstructed": 5.501500129699707, + "loss_zero": 12.936649322509766, + "frac_recovered": 0.7625460624694824, + }, +} + +DEVICE = "cuda:0" +SAVE_DIR = "./test_data" +MODEL_NAME = "EleutherAI/pythia-70m-deduped" +RANDOM_SEED = 42 +LAYER = 3 +DATASET_NAME = "monology/pile-uncopyrighted" + +EVAL_TOLERANCE = 0.01 + + +def get_nested_folders(path: str) -> list[str]: + """ + Recursively get a list of folders that contain an ae.pt file, starting the search from the given path + """ + folder_names = [] + + for root, dirs, files in os.walk(path): + if "ae.pt" in files: + folder_names.append(root) + + return folder_names + + +def load_dictionary(base_path: str, device: str) -> tuple: + ae_path = f"{base_path}/ae.pt" + config_path = f"{base_path}/config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + # TODO: Save the submodule name in the config? 
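+    # (Each trainer config in test_sae_training() below already includes
+    # "submodule_name", so the commented-out lookup should resolve as-is.)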
+    # submodule_str = config["trainer"]["submodule_name"]
+    dict_class = config["trainer"]["dict_class"]
+
+    if dict_class == "AutoEncoder":
+        dictionary = AutoEncoder.from_pretrained(ae_path, device=device)
+    elif dict_class == "GatedAutoEncoder":
+        dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device)
+    elif dict_class == "AutoEncoderNew":
+        dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device)
+    elif dict_class == "AutoEncoderTopK":
+        k = config["trainer"]["k"]
+        dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device)
+    elif dict_class == "JumpReluAutoEncoder":
+        dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device)
+    else:
+        raise ValueError(f"Dictionary class {dict_class} not supported")
+
+    return dictionary, config
+
+
+def test_sae_training():
+    """End to end test for training an SAE. Takes ~3 minutes on an RTX 3090.
+    This isn't a nice suite of unit tests, but it's better than nothing."""
+    random.seed(RANDOM_SEED)
+    t.manual_seed(RANDOM_SEED)
+
+    MODEL_NAME = "EleutherAI/pythia-70m-deduped"
+    model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE)
+    layer = 3
+
+    context_length = 128
+    llm_batch_size = 512  # Fits on a 24GB GPU
+    sae_batch_size = 8192
+    num_contexts_per_sae_batch = sae_batch_size // context_length
+
+    num_inputs_in_buffer = num_contexts_per_sae_batch * 20
+
+    num_tokens = 10_000_000
+
+    # sae training parameters
+    random_seed = 42
+    k = 40
+    sparsity_penalty = 0.05
+    expansion_factor = 8
+
+    steps = int(num_tokens / sae_batch_size)  # Total number of batches to train
+    save_steps = None
+    warmup_steps = 1000  # Warmup period at start of training and after each resample
+    resample_steps = None
+
+    # standard sae training parameters
+    learning_rate = 3e-4
+
+    # topk sae training parameters
+    decay_start = 24000
+    auxk_alpha = 1 / 32
+
+    submodule = model.gpt_neox.layers[LAYER]
+    submodule_name = f"resid_post_layer_{LAYER}"
+    io = "out"
+    activation_dim = model.config.hidden_size
+
+    generator = hf_dataset_to_generator(DATASET_NAME)
+
+    activation_buffer = ActivationBuffer(
+        generator,
+        model,
+        submodule,
+        n_ctxs=num_inputs_in_buffer,
+        ctx_len=context_length,
+        refresh_batch_size=llm_batch_size,
+        out_batch_size=sae_batch_size,
+        io=io,
+        d_submodule=activation_dim,
+        device=DEVICE,
+    )
+
+    # create the list of configs
+    trainer_configs = []
+    trainer_configs.extend(
+        [
+            {
+                "trainer": TrainerTopK,
+                "dict_class": AutoEncoderTopK,
+                "activation_dim": activation_dim,
+                "dict_size": expansion_factor * activation_dim,
+                "k": k,
+                "auxk_alpha": auxk_alpha,  # see Appendix A.2
+                "decay_start": decay_start,  # when does the lr decay start
+                "steps": steps,  # when does training end
+                "seed": random_seed,
+                "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}",
+                "device": DEVICE,
+                "layer": layer,
+                "lm_name": MODEL_NAME,
+                "submodule_name": submodule_name,
+            },
+        ]
+    )
+    trainer_configs.extend(
+        [
+            {
+                "trainer": StandardTrainer,
+                "dict_class": AutoEncoder,
+                "activation_dim": activation_dim,
+                "dict_size": expansion_factor * activation_dim,
+                "lr": learning_rate,
+                "l1_penalty": sparsity_penalty,
+                "warmup_steps": warmup_steps,
+                "resample_steps": resample_steps,
+                "seed": random_seed,
+                "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}",
+                "layer": layer,
+                "lm_name": MODEL_NAME,
+                "device": DEVICE,
+                "submodule_name": submodule_name,
+            },
+        ]
+    )
+
+    print(f"len trainer configs: {len(trainer_configs)}")
+    output_dir = f"{SAVE_DIR}/{submodule_name}"
+
+    trainSAE(
+        
data=activation_buffer, + trainer_configs=trainer_configs, + steps=steps, + save_steps=save_steps, + save_dir=output_dir, + ) + + folders = get_nested_folders(output_dir) + + assert len(folders) == 2 + + for folder in folders: + dictionary, config = load_dictionary(folder, DEVICE) + + assert dictionary is not None + assert config is not None + + +def test_evaluation(): + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + ae_paths = get_nested_folders(SAVE_DIR) + + context_length = 128 + llm_batch_size = 100 + buffer_size = 256 + io = "out" + + generator = hf_dataset_to_generator(DATASET_NAME) + submodule = model.gpt_neox.layers[LAYER] + + input_strings = [] + for i, example in enumerate(generator): + input_strings.append(example) + if i > buffer_size * 2: + break + + for ae_path in ae_paths: + dictionary, config = load_dictionary(ae_path, DEVICE) + dictionary = dictionary.to(dtype=model.dtype) + + activation_dim = config["trainer"]["activation_dim"] + context_length = config["buffer"]["ctx_len"] + + activation_buffer_data = iter(input_strings) + + activation_buffer = ActivationBuffer( + activation_buffer_data, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=llm_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + llm_batch_size, + io=io, + device=DEVICE, + ) + + print(eval_results) + + dict_class = config["trainer"]["dict_class"] + expected_results = EXPECTED_RESULTS[dict_class] + + for key, value in expected_results.items(): + assert abs(eval_results[key] - value) < EVAL_TOLERANCE From 9ed4af245a22e095e932d6065d368c58947d9a3d Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 17 Dec 2024 23:41:15 +0000 Subject: [PATCH 2/8] Rename input to inputs per nnsight 0.3.0 --- buffer.py | 4 ++-- evaluation.py | 2 +- requirements.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/buffer.py b/buffer.py index 86f24f9..178ea3e 100644 --- a/buffer.py +++ b/buffer.py @@ -124,7 +124,7 @@ def refresh(self): hidden_states = self.submodule.input[0].save() else: hidden_states = self.submodule.output.save() - input = self.model.input.save() + input = self.model.inputs.save() attn_mask = input.value[1]["attention_mask"] hidden_states = hidden_states.value if isinstance(hidden_states, tuple): @@ -251,7 +251,7 @@ def refresh(self): while len(self.activations) < self.n_ctxs * self.ctx_len: with t.no_grad(): with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote): - input = self.model.input.save() + input = self.model.inputs.save() hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save() if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] diff --git a/evaluation.py b/evaluation.py index 6b3b0e5..13bf4fa 100644 --- a/evaluation.py +++ b/evaluation.py @@ -110,7 +110,7 @@ def loss_recovered( else: raise ValueError(f"Invalid value for io: {io}") - input = model.input.save() + input = model.inputs.save() logits_zero = model.output.save() logits_zero = logits_zero.value diff --git a/requirements.txt b/requirements.txt index 5b9f3c6..bda16d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ circuitsvis>=1.43.2 datasets>=2.18.0 einops>=0.7.0 matplotlib>=3.8.3 -nnsight>=0.2.11 +nnsight>=0.3.0 
pandas>=2.2.1 plotly>=5.18.0 torch>=2.1.2 From 807f6ef735872a5cab68773a315f15bc920c3d72 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 02:18:19 +0000 Subject: [PATCH 3/8] Complete nnsight 0.2 to 0.3 changes --- buffer.py | 6 +++--- evaluation.py | 46 +++++++++++++--------------------------- interp.py | 2 +- tests/test_end_to_end.py | 42 ++++++++++++++++++------------------ 4 files changed, 40 insertions(+), 56 deletions(-) diff --git a/buffer.py b/buffer.py index 178ea3e..be3a745 100644 --- a/buffer.py +++ b/buffer.py @@ -121,7 +121,7 @@ def refresh(self): invoker_args={"truncation": True, "max_length": self.ctx_len}, ): if self.io == "in": - hidden_states = self.submodule.input[0].save() + hidden_states = self.submodule.inputs[0].save() else: hidden_states = self.submodule.output.save() input = self.model.inputs.save() @@ -252,7 +252,7 @@ def refresh(self): with t.no_grad(): with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote): input = self.model.inputs.save() - hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save() + hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.inputs[0][0]#.save() if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] @@ -416,7 +416,7 @@ def refresh(self): invoker_args={"truncation": True, "max_length": self.ctx_len}, ): if self.io in ["in", "in_and_out"]: - hidden_states_in = self.submodule.input[0].save() + hidden_states_in = self.submodule.inputs[0].save() if self.io in ["out", "in_and_out"]: hidden_states_out = self.submodule.output.save() diff --git a/evaluation.py b/evaluation.py index 13bf4fa..558d6c8 100644 --- a/evaluation.py +++ b/evaluation.py @@ -36,21 +36,17 @@ def loss_recovered( # logits when replacing component activations with reconstruction by autoencoder with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'out': - x = submodule.output - if type(submodule.output.shape) == tuple: x = x[0] + x = submodule.output[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'in_and_out': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] - print(f'x.shape: {x.shape}') + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale @@ -58,35 +54,28 @@ def loss_recovered( raise ValueError(f"Invalid value for io: {io}") x = x.save() - # pull this out so dictionary can be written without FakeTensor (top_k needs this) - x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype) + x_hat = dictionary(x).to(model.dtype) # intervene with `x_hat` with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - if type(submodule.input.shape) == tuple: - submodule.input[0][:] = x_hat - else: - submodule.input = x_hat + submodule.input[:] = x_hat elif io == 'out': - x = submodule.output + x = submodule.output[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - if 
type(submodule.output.shape) == tuple: - submodule.output = (x_hat,) - else: - submodule.output = x_hat + submodule.output[0][:] = x_hat elif io == 'in_and_out': - x = submodule.input[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - submodule.output = x_hat + submodule.output[0][:] = x_hat else: raise ValueError(f"Invalid value for io: {io}") @@ -96,22 +85,17 @@ def loss_recovered( # logits when replacing component activations with zeros with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: - submodule.input[0][:] = t.zeros_like(x[0]) - else: - submodule.input = t.zeros_like(x) + x = submodule.input + submodule.input[:] = t.zeros_like(x) elif io in ['out', 'in_and_out']: - x = submodule.output - if type(submodule.output.shape) == tuple: - submodule.output[0][:] = t.zeros_like(x[0]) - else: - submodule.output = t.zeros_like(x) + x = submodule.output[0] + submodule.output[0][:] = t.zeros_like(x) else: raise ValueError(f"Invalid value for io: {io}") input = model.inputs.save() logits_zero = model.output.save() + logits_zero = logits_zero.value # get everything into the right format diff --git a/interp.py b/interp.py index 283965b..e721eb9 100644 --- a/interp.py +++ b/interp.py @@ -101,7 +101,7 @@ def _list_decode(x): inputs = buffer.tokenized_batch(batch_size=n_inputs) with t.no_grad(), model.trace(inputs, **tracer_kwargs): - tokens = model.input[1][ + tokens = model.inputs[1][ "input_ids" ].save() # if you're getting errors, check here; might only work for pythia models activations = submodule.output diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 8cb4b95..c6eb0a3 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -19,32 +19,32 @@ EXPECTED_RESULTS = { "AutoEncoderTopK": { - "l2_loss": 4.470372676849365, - "l1_loss": 44.47749710083008, + "l2_loss": 4.462644577026367, + "l1_loss": 44.446834564208984, "l0": 40.0, "frac_alive": 0.000244140625, - "frac_variance_explained": 0.9372208118438721, - "cossim": 0.9471381902694702, - "l2_ratio": 0.9523985981941223, - "relative_reconstruction_bias": 0.9996458888053894, - "loss_original": 3.186223268508911, - "loss_reconstructed": 3.690929412841797, - "loss_zero": 12.936649322509766, - "frac_recovered": 0.9482374787330627, + "frac_variance_explained": 0.9372867941856384, + "cossim": 0.9471449851989746, + "l2_ratio": 0.9524278044700623, + "relative_reconstruction_bias": 0.9986423254013062, + "loss_original": 3.1832079887390137, + "loss_reconstructed": 3.713366985321045, + "loss_zero": 12.936450958251953, + "frac_recovered": 0.9456427693367004, }, "AutoEncoder": { - "l2_loss": 6.72230863571167, - "l1_loss": 28.893749237060547, - "l0": 61.12999725341797, + "l2_loss": 6.721538066864014, + "l1_loss": 28.914989471435547, + "l0": 61.29999923706055, "frac_alive": 0.000244140625, - "frac_variance_explained": 0.6076533794403076, - "cossim": 0.869738757610321, - "l2_ratio": 0.8005934953689575, - "relative_reconstruction_bias": 0.9304398894309998, - "loss_original": 3.186223268508911, - "loss_reconstructed": 5.501500129699707, - "loss_zero": 12.936649322509766, - "frac_recovered": 0.7625460624694824, + "frac_variance_explained": 0.6077123880386353, + "cossim": 0.869745135307312, + "l2_ratio": 0.801030695438385, + "relative_reconstruction_bias": 0.9309902191162109, + "loss_original": 3.1832079887390137, + "loss_reconstructed": 5.499264717102051, + 
"loss_zero": 12.936450958251953, + "frac_recovered": 0.7625347375869751, }, } From dc3072089c24ce1eb8bc40e9f5248c69a92f5174 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 02:53:40 +0000 Subject: [PATCH 4/8] Fix frac_alive calculation, perform evaluation over multiple batches --- evaluation.py | 51 +++++++++++++++++++++++----------------- tests/test_end_to_end.py | 4 ++-- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/evaluation.py b/evaluation.py index 558d6c8..9097713 100644 --- a/evaluation.py +++ b/evaluation.py @@ -3,6 +3,8 @@ """ import torch as t +from collections import defaultdict + from .buffer import ActivationBuffer, NNsightActivationBuffer from nnsight import LanguageModel from .config import DEBUG @@ -128,7 +130,7 @@ def loss_recovered( return tuple(losses) - +@t.no_grad() def evaluate( dictionary, # a dictionary activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered @@ -138,26 +140,28 @@ def evaluate( normalize_batch=False, # normalize batch before passing through dictionary tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. device="cpu", + n_batches: int = 1, ): - with t.no_grad(): - - out = {} # dict of results + assert n_batches > 0 + out = defaultdict(float) + active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device) + for _ in range(n_batches): try: x = next(activations).to(device) if normalize_batch: x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) - except StopIteration: raise StopIteration( "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data." ) - x_hat, f = dictionary(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() l1_loss = f.norm(p=1, dim=-1).mean() l0 = (f != 0).float().sum(dim=-1).mean() - frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size + + features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D) + active_features += features_BF.sum(dim=0) # cosine similarity between x and x_hat x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) @@ -177,18 +181,17 @@ def evaluate( x_dot_x_hat = (x * x_hat).sum(dim=-1) relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean() - out["l2_loss"] = l2_loss.item() - out["l1_loss"] = l1_loss.item() - out["l0"] = l0.item() - out["frac_alive"] = frac_alive.item() - out["frac_variance_explained"] = frac_variance_explained.item() - out["cossim"] = cossim.item() - out["l2_ratio"] = l2_ratio.item() - out['relative_reconstruction_bias'] = relative_reconstruction_bias.item() + out["l2_loss"] += l2_loss.item() + out["l1_loss"] += l1_loss.item() + out["l0"] += l0.item() + out["frac_variance_explained"] += frac_variance_explained.item() + out["cossim"] += cossim.item() + out["l2_ratio"] += l2_ratio.item() + out['relative_reconstruction_bias'] += relative_reconstruction_bias.item() if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): - return out - + continue + # compute loss recovered loss_original, loss_reconstructed, loss_zero = loss_recovered( activations.text_batch(batch_size=batch_size), @@ -202,9 +205,13 @@ def evaluate( ) frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) - out["loss_original"] = loss_original.item() - out["loss_reconstructed"] = loss_reconstructed.item() - out["loss_zero"] = loss_zero.item() - out["frac_recovered"] = 
frac_recovered.item() + out["loss_original"] += loss_original.item() + out["loss_reconstructed"] += loss_reconstructed.item() + out["loss_zero"] += loss_zero.item() + out["frac_recovered"] += frac_recovered.item() + + out = {key: value / n_batches for key, value in out.items()} + frac_alive = (active_features != 0).float().sum() / dictionary.dict_size + out["frac_alive"] = frac_alive.item() - return out + return out \ No newline at end of file diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index c6eb0a3..c41a433 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -22,7 +22,7 @@ "l2_loss": 4.462644577026367, "l1_loss": 44.446834564208984, "l0": 40.0, - "frac_alive": 0.000244140625, + "frac_alive": 0.45458984375, "frac_variance_explained": 0.9372867941856384, "cossim": 0.9471449851989746, "l2_ratio": 0.9524278044700623, @@ -36,7 +36,7 @@ "l2_loss": 6.721538066864014, "l1_loss": 28.914989471435547, "l0": 61.29999923706055, - "frac_alive": 0.000244140625, + "frac_alive": 0.14404296875, "frac_variance_explained": 0.6077123880386353, "cossim": 0.869745135307312, "l2_ratio": 0.801030695438385, From 067bf7b05470f61b9ed4f38b95be55c5ac45fb8f Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 03:55:05 +0000 Subject: [PATCH 5/8] Obtain better test results using multiple batches --- evaluation.py | 5 +++- tests/test_end_to_end.py | 59 ++++++++++++++++++++-------------------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/evaluation.py b/evaluation.py index 9097713..99fddef 100644 --- a/evaluation.py +++ b/evaluation.py @@ -161,6 +161,9 @@ def evaluate( l0 = (f != 0).float().sum(dim=-1).mean() features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D) + assert features_BF.shape[-1] == dictionary.dict_size + assert len(features_BF.shape) == 2 + active_features += features_BF.sum(dim=0) # cosine similarity between x and x_hat @@ -191,7 +194,7 @@ def evaluate( if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): continue - + # compute loss recovered loss_original, loss_reconstructed, loss_zero = loss_recovered( activations.text_batch(batch_size=batch_size), diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index c41a433..ce5a1cf 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -19,32 +19,32 @@ EXPECTED_RESULTS = { "AutoEncoderTopK": { - "l2_loss": 4.462644577026367, - "l1_loss": 44.446834564208984, + "l2_loss": 4.325331306457519, + "l1_loss": 47.92763671875, "l0": 40.0, - "frac_alive": 0.45458984375, - "frac_variance_explained": 0.9372867941856384, - "cossim": 0.9471449851989746, - "l2_ratio": 0.9524278044700623, - "relative_reconstruction_bias": 0.9986423254013062, - "loss_original": 3.1832079887390137, - "loss_reconstructed": 3.713366985321045, - "loss_zero": 12.936450958251953, - "frac_recovered": 0.9456427693367004, + "frac_variance_explained": 0.9584966480731965, + "cossim": 0.948570293188095, + "l2_ratio": 0.94872345328331, + "relative_reconstruction_bias": 0.9998040139675141, + "loss_original": 3.328495955467224, + "loss_reconstructed": 3.819682216644287, + "loss_zero": 13.250199031829833, + "frac_recovered": 0.9503251194953919, + "frac_alive": 0.99951171875, }, "AutoEncoder": { - "l2_loss": 6.721538066864014, - "l1_loss": 28.914989471435547, - "l0": 61.29999923706055, - "frac_alive": 0.14404296875, - "frac_variance_explained": 0.6077123880386353, - "cossim": 0.869745135307312, - "l2_ratio": 
0.801030695438385, - "relative_reconstruction_bias": 0.9309902191162109, - "loss_original": 3.1832079887390137, - "loss_reconstructed": 5.499264717102051, - "loss_zero": 12.936450958251953, - "frac_recovered": 0.7625347375869751, + "l2_loss": 6.5741173267364506, + "l1_loss": 32.06615734100342, + "l0": 60.9147216796875, + "frac_variance_explained": 0.9042629599571228, + "cossim": 0.8782194256782532, + "l2_ratio": 0.814234834909439, + "relative_reconstruction_bias": 0.9813631415367127, + "loss_original": 3.328495955467224, + "loss_reconstructed": 5.7899915218353275, + "loss_zero": 13.250199031829833, + "frac_recovered": 0.754741370677948, + "frac_alive": 0.9921875, }, } @@ -105,9 +105,7 @@ def test_sae_training(): random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) - MODEL_NAME = "EleutherAI/pythia-70m-deduped" model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) - layer = 3 context_length = 128 llm_batch_size = 512 # Fits on a 24GB GPU @@ -172,7 +170,7 @@ def test_sae_training(): "seed": random_seed, "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", "device": DEVICE, - "layer": layer, + "layer": LAYER, "lm_name": MODEL_NAME, "submodule_name": submodule_name, }, @@ -191,7 +189,7 @@ def test_sae_training(): "resample_steps": resample_steps, "seed": random_seed, "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", - "layer": layer, + "layer": LAYER, "lm_name": MODEL_NAME, "device": DEVICE, "submodule_name": submodule_name, @@ -230,6 +228,8 @@ def test_evaluation(): context_length = 128 llm_batch_size = 100 + sae_batch_size = 4096 + n_batches = 10 buffer_size = 256 io = "out" @@ -239,7 +239,7 @@ def test_evaluation(): input_strings = [] for i, example in enumerate(generator): input_strings.append(example) - if i > buffer_size * 2: + if i > buffer_size * n_batches: break for ae_path in ae_paths: @@ -258,7 +258,7 @@ def test_evaluation(): n_ctxs=buffer_size, ctx_len=context_length, refresh_batch_size=llm_batch_size, - out_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, io=io, d_submodule=activation_dim, device=DEVICE, @@ -271,6 +271,7 @@ def test_evaluation(): llm_batch_size, io=io, device=DEVICE, + n_batches=n_batches, ) print(eval_results) From 05fe179f5b0616310253deaf758c370071f534fa Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 03:55:27 +0000 Subject: [PATCH 6/8] Add early stopping in forward pass --- buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/buffer.py b/buffer.py index be3a745..d997596 100644 --- a/buffer.py +++ b/buffer.py @@ -125,6 +125,8 @@ def refresh(self): else: hidden_states = self.submodule.output.save() input = self.model.inputs.save() + + self.submodule.output.stop() attn_mask = input.value[1]["attention_mask"] hidden_states = hidden_states.value if isinstance(hidden_states, tuple): From f1b9b800bc8e2cc308d4d14690df71f854b30fce Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 04:00:52 +0000 Subject: [PATCH 7/8] Change save_steps to a list of ints --- training.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/training.py b/training.py index f100fee..13fd4b3 100644 --- a/training.py +++ b/training.py @@ -6,6 +6,7 @@ import multiprocessing as mp import os from queue import Empty +from typing import Optional import torch as t from tqdm import tqdm @@ -75,17 +76,17 @@ def log_stats( def trainSAE( data, - trainer_configs, - use_wandb=False, - wandb_entity="", - wandb_project="", - steps=None, - save_steps=None, - save_dir=None, - 
log_steps=None, - activations_split_by_head=False, - transcoder=False, - run_cfg={}, + trainer_configs: list[dict], + use_wandb:bool=False, + wandb_entity:str="", + wandb_project:str="", + steps:Optional[int]=None, + save_steps:Optional[list[int]]=None, + save_dir:Optional[str]=None, + log_steps:Optional[int]=None, + activations_split_by_head:bool=False, + transcoder:bool=False, + run_cfg:dict={}, ): """ Train SAEs using the given trainers @@ -140,7 +141,7 @@ def trainSAE( ) # saving - if save_steps is not None and step % save_steps == 0: + if save_steps is not None and step in save_steps: for dir, trainer in zip(save_dirs, trainers): if dir is not None: if not os.path.exists(os.path.join(dir, "checkpoints")): From d350415e119cacb6547703eb9733daf8ef57075b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 16:30:15 +0000 Subject: [PATCH 8/8] Check for is_tuple to support mlp / attn submodules --- evaluation.py | 37 ++++++++++++++++++++++++++++++------- tests/test_end_to_end.py | 7 +++---- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/evaluation.py b/evaluation.py index 99fddef..ba56437 100644 --- a/evaluation.py +++ b/evaluation.py @@ -24,12 +24,21 @@ def loss_recovered( How much of the model's loss is recovered by replacing the component output with the reconstruction by the autoencoder? """ - + if max_len is None: invoker_args = {} else: invoker_args = {"truncation": True, "max_length": max_len } + with model.trace("_"): + temp_output = submodule.output.save() + + output_is_tuple = False + # Note: isinstance() won't work here as torch.Size is a subclass of tuple, + # so isinstance(temp_output.shape, tuple) would return True even for torch.Size. + if type(temp_output.shape) == tuple: + output_is_tuple = True + # unmodified logits with model.trace(text, invoker_args=invoker_args): logits_original = model.output.save() @@ -43,7 +52,8 @@ def loss_recovered( scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'out': - x = submodule.output[0] + x = submodule.output + if output_is_tuple: x = x[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale @@ -56,6 +66,9 @@ def loss_recovered( raise ValueError(f"Invalid value for io: {io}") x = x.save() + # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here. 
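+    # (On this model, residual-stream layers return a tuple whose first element is
+    # the (batch, seq, d_model) hidden state, while mlp / attn submodules return the
+    # tensor directly; a mishandled tuple would fail the shape check below.)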
+    assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}"
+
     x_hat = dictionary(x).to(model.dtype)
 
     # intervene with `x_hat`
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
             x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
             submodule.input[:] = x_hat
         elif io == 'out':
-            x = submodule.output[0]
+            x = submodule.output
+            if output_is_tuple: x = x[0]
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            submodule.output[0][:] = x_hat
+            if output_is_tuple:
+                submodule.output[0][:] = x_hat
+            else:
+                submodule.output[:] = x_hat
         elif io == 'in_and_out':
             x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            submodule.output[0][:] = x_hat
+            if output_is_tuple:
+                submodule.output[0][:] = x_hat
+            else:
+                submodule.output[:] = x_hat
         else:
             raise ValueError(f"Invalid value for io: {io}")
 
     # logits when replacing component activations with zeros
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
             x = submodule.input
             submodule.input[:] = t.zeros_like(x)
         elif io in ['out', 'in_and_out']:
-            x = submodule.output[0]
-            submodule.output[0][:] = t.zeros_like(x)
+            x = submodule.output
+            if output_is_tuple:
+                submodule.output[0][:] = t.zeros_like(x[0])
+            else:
+                submodule.output[:] = t.zeros_like(x)
         else:
             raise ValueError(f"Invalid value for io: {io}")
 
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
index ce5a1cf..8e93cab 100644
--- a/tests/test_end_to_end.py
+++ b/tests/test_end_to_end.py
@@ -100,7 +100,7 @@ def load_dictionary(base_path: str, device: str) -> tuple:
 
 
 def test_sae_training():
-    """End to end test for training an SAE. Takes ~3 minutes on an RTX 3090.
+    """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090.
     This isn't a nice suite of unit tests, but it's better than nothing."""
     random.seed(RANDOM_SEED)
     t.manual_seed(RANDOM_SEED)
@@ -117,7 +117,6 @@ def test_sae_training():
     num_tokens = 10_000_000
 
     # sae training parameters
-    random_seed = 42
     k = 40
     sparsity_penalty = 0.05
     expansion_factor = 8
@@ -167,7 +166,7 @@ def test_sae_training():
                 "auxk_alpha": auxk_alpha,  # see Appendix A.2
                 "decay_start": decay_start,  # when does the lr decay start
                 "steps": steps,  # when does training end
-                "seed": random_seed,
+                "seed": RANDOM_SEED,
                 "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}",
                 "device": DEVICE,
                 "layer": LAYER,
@@ -187,7 +186,7 @@ def test_sae_training():
                 "l1_penalty": sparsity_penalty,
                 "warmup_steps": warmup_steps,
                 "resample_steps": resample_steps,
-                "seed": random_seed,
+                "seed": RANDOM_SEED,
                 "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}",
                 "layer": LAYER,
                 "lm_name": MODEL_NAME,
                 "device": DEVICE,
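
A note on the type(temp_output.shape) == tuple check introduced in PATCH 8/8: torch.Size
subclasses tuple, so isinstance() alone cannot distinguish a bare tensor's .shape from a
genuine tuple of submodule outputs. The minimal sketch below illustrates the distinction in
plain PyTorch; the tuple value is a stand-in for a residual-stream layer's tuple output (an
assumption for illustration, not nnsight's exact return type):

import torch

tensor_shape = torch.zeros(2, 128, 512).shape  # torch.Size: an mlp/attn submodule returns a bare tensor
tuple_shape = (torch.Size([2, 128, 512]),)     # plain tuple: stand-in for a residual-stream layer output

assert isinstance(tensor_shape, tuple)  # True -- torch.Size is a tuple subclass, so isinstance() is ambiguous
assert type(tensor_shape) != tuple      # the patch's check correctly treats this as a single tensor
assert type(tuple_shape) == tuple       # only a genuine tuple sets output_is_tuple = True

With pytest added to requirements.txt in PATCH 1/8, the series can be exercised end to end via
pytest tests/test_end_to_end.py (a CUDA device is required, as DEVICE is hard-coded to "cuda:0").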