diff --git a/buffer.py b/buffer.py
index 86f24f9..d997596 100644
--- a/buffer.py
+++ b/buffer.py
@@ -121,10 +121,12 @@ def refresh(self):
                     invoker_args={"truncation": True, "max_length": self.ctx_len},
                 ):
                     if self.io == "in":
-                        hidden_states = self.submodule.input[0].save()
+                        hidden_states = self.submodule.inputs[0].save()
                     else:
                         hidden_states = self.submodule.output.save()
-                    input = self.model.input.save()
+                    input = self.model.inputs.save()
+
+                    self.submodule.output.stop()
             attn_mask = input.value[1]["attention_mask"]
             hidden_states = hidden_states.value
             if isinstance(hidden_states, tuple):
@@ -251,8 +253,8 @@ def refresh(self):
         while len(self.activations) < self.n_ctxs * self.ctx_len:
             with t.no_grad():
                 with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote):
-                    input = self.model.input.save()
-                    hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save()
+                    input = self.model.inputs.save()
+                    hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.inputs[0][0]#.save()
 
             if isinstance(hidden_states, tuple):
                 hidden_states = hidden_states[0]
@@ -416,7 +418,7 @@ def refresh(self):
                     invoker_args={"truncation": True, "max_length": self.ctx_len},
                 ):
                     if self.io in ["in", "in_and_out"]:
-                        hidden_states_in = self.submodule.input[0].save()
+                        hidden_states_in = self.submodule.inputs[0].save()
                     if self.io in ["out", "in_and_out"]:
                         hidden_states_out = self.submodule.output.save()
 
diff --git a/evaluation.py b/evaluation.py
index 6b3b0e5..ba56437 100644
--- a/evaluation.py
+++ b/evaluation.py
@@ -3,6 +3,8 @@
 """
 
 import torch as t
+from collections import defaultdict
+
 from .buffer import ActivationBuffer, NNsightActivationBuffer
 from nnsight import LanguageModel
 from .config import DEBUG
@@ -22,12 +24,21 @@ def loss_recovered(
     How much of the model's loss is recovered by replacing the component output
     with the reconstruction by the autoencoder?
     """
-    
+
    if max_len is None:
        invoker_args = {}
    else:
        invoker_args = {"truncation": True, "max_length": max_len }
 
+    with model.trace("_"):
+        temp_output = submodule.output.save()
+
+    output_is_tuple = False
+    # Note: isinstance() won't work here as torch.Size is a subclass of tuple,
+    # so isinstance(temp_output.shape, tuple) would return True even for torch.Size.
+    if type(temp_output.shape) == tuple:
+        output_is_tuple = True
+
     # unmodified logits
     with model.trace(text, invoker_args=invoker_args):
         logits_original = model.output.save()
@@ -36,21 +47,18 @@ def loss_recovered(
     # logits when replacing component activations with reconstruction by autoencoder
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
-            x = submodule.input[0]
-            if type(submodule.input.shape) == tuple: x = x[0]
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x = x * scale
         elif io == 'out':
             x = submodule.output
-            if type(submodule.output.shape) == tuple: x = x[0]
+            if output_is_tuple: x = x[0]
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x = x * scale
         elif io == 'in_and_out':
-            x = submodule.input[0]
-            if type(submodule.input.shape) == tuple: x = x[0]
-            print(f'x.shape: {x.shape}')
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x = x * scale
@@ -58,35 +66,38 @@ def loss_recovered(
             raise ValueError(f"Invalid value for io: {io}")
         x = x.save()
 
-    # pull this out so dictionary can be written without FakeTensor (top_k needs this)
-    x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype)
+    # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here.
+    assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}"
+
+    x_hat = dictionary(x).to(model.dtype)
 
     # intervene with `x_hat`
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
-            x = submodule.input[0]
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            if type(submodule.input.shape) == tuple:
-                submodule.input[0][:] = x_hat
-            else:
-                submodule.input = x_hat
+            submodule.input[:] = x_hat
         elif io == 'out':
             x = submodule.output
+            if output_is_tuple: x = x[0]
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            if type(submodule.output.shape) == tuple:
-                submodule.output = (x_hat,)
+            if output_is_tuple:
+                submodule.output[0][:] = x_hat
             else:
-                submodule.output = x_hat
+                submodule.output[:] = x_hat
         elif io == 'in_and_out':
-            x = submodule.input[0]
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            submodule.output = x_hat
+            if output_is_tuple:
+                submodule.output[0][:] = x_hat
+            else:
+                submodule.output[:] = x_hat
         else:
             raise ValueError(f"Invalid value for io: {io}")
 
@@ -96,22 +107,20 @@ def loss_recovered(
     # logits when replacing component activations with zeros
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
-            x = submodule.input[0]
-            if type(submodule.input.shape) == tuple:
-                submodule.input[0][:] = t.zeros_like(x[0])
-            else:
-                submodule.input = t.zeros_like(x)
+            x = submodule.input
+            submodule.input[:] = t.zeros_like(x)
         elif io in ['out', 'in_and_out']:
             x = submodule.output
-            if type(submodule.output.shape) == tuple:
+            if output_is_tuple:
                 submodule.output[0][:] = t.zeros_like(x[0])
             else:
-                submodule.output = t.zeros_like(x)
+                submodule.output[:] = t.zeros_like(x)
         else:
             raise ValueError(f"Invalid value for io: {io}")
 
-        input = model.input.save()
+        input = model.inputs.save()
         logits_zero = model.output.save()
+    logits_zero = logits_zero.value
 
     # get everything into the right format
 
@@ -144,7 +153,7 @@ def loss_recovered(
 
     return tuple(losses)
 
-
+@t.no_grad()
 def evaluate(
     dictionary, # a dictionary
     activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered
@@ -154,26 +163,31 @@ def evaluate(
     normalize_batch=False, # normalize batch before passing through dictionary
     tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace.
     device="cpu",
+    n_batches: int = 1,
 ):
-    with t.no_grad():
-
-        out = {} # dict of results
+    assert n_batches > 0
+    out = defaultdict(float)
+    active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device)
 
+    for _ in range(n_batches):
         try:
             x = next(activations).to(device)
             if normalize_batch:
                 x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)
-
         except StopIteration:
             raise StopIteration(
                 "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
             )
-
         x_hat, f = dictionary(x, output_features=True)
         l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
         l1_loss = f.norm(p=1, dim=-1).mean()
         l0 = (f != 0).float().sum(dim=-1).mean()
-        frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size
+
+        features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D)
+        assert features_BF.shape[-1] == dictionary.dict_size
+        assert len(features_BF.shape) == 2
+
+        active_features += features_BF.sum(dim=0)
 
         # cosine similarity between x and x_hat
         x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
@@ -193,17 +207,16 @@ def evaluate(
         x_dot_x_hat = (x * x_hat).sum(dim=-1)
         relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean()
 
-        out["l2_loss"] = l2_loss.item()
-        out["l1_loss"] = l1_loss.item()
-        out["l0"] = l0.item()
-        out["frac_alive"] = frac_alive.item()
-        out["frac_variance_explained"] = frac_variance_explained.item()
-        out["cossim"] = cossim.item()
-        out["l2_ratio"] = l2_ratio.item()
-        out['relative_reconstruction_bias'] = relative_reconstruction_bias.item()
+        out["l2_loss"] += l2_loss.item()
+        out["l1_loss"] += l1_loss.item()
+        out["l0"] += l0.item()
+        out["frac_variance_explained"] += frac_variance_explained.item()
+        out["cossim"] += cossim.item()
+        out["l2_ratio"] += l2_ratio.item()
+        out['relative_reconstruction_bias'] += relative_reconstruction_bias.item()
 
         if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)):
-            return out
+            continue
 
         # compute loss recovered
         loss_original, loss_reconstructed, loss_zero = loss_recovered(
@@ -218,9 +231,13 @@ def evaluate(
         )
         frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)
 
-        out["loss_original"] = loss_original.item()
-        out["loss_reconstructed"] = loss_reconstructed.item()
-        out["loss_zero"] = loss_zero.item()
-        out["frac_recovered"] = frac_recovered.item()
+        out["loss_original"] += loss_original.item()
+        out["loss_reconstructed"] += loss_reconstructed.item()
+        out["loss_zero"] += loss_zero.item()
+        out["frac_recovered"] += frac_recovered.item()
+
+    out = {key: value / n_batches for key, value in out.items()}
+    frac_alive = (active_features != 0).float().sum() / dictionary.dict_size
+    out["frac_alive"] = frac_alive.item()
 
-        return out
+    return out
\ No newline at end of file
diff --git a/interp.py b/interp.py
index 283965b..e721eb9 100644
--- a/interp.py
+++ b/interp.py
@@ -101,7 +101,7 @@ def _list_decode(x):
     inputs = buffer.tokenized_batch(batch_size=n_inputs)
 
     with t.no_grad(), model.trace(inputs, **tracer_kwargs):
-        tokens = model.input[1][
+        tokens = model.inputs[1][
             "input_ids"
         ].save() # if you're getting errors, check here; might only work for pythia models
         activations = submodule.output
diff --git a/requirements.txt b/requirements.txt
index 7366e63..bda16d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,11 +2,12 @@ circuitsvis>=1.43.2
 datasets>=2.18.0
 einops>=0.7.0
 matplotlib>=3.8.3
-nnsight>=0.2.11
+nnsight>=0.3.0
 pandas>=2.2.1
 plotly>=5.18.0
 torch>=2.1.2
 tqdm>=4.66.1
 umap-learn>=0.5.6
 zstandard>=0.22.0
-wandb
+wandb>=0.12.0
+pytest>=6.2.4
\ No newline at end of file
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
new file mode 100644
index 0000000..8e93cab
--- /dev/null
+++ b/tests/test_end_to_end.py
@@ -0,0 +1,282 @@
+import torch as t
+from nnsight import LanguageModel
+import os
+import json
+import random
+
+from dictionary_learning.training import trainSAE
+from dictionary_learning.trainers.standard import StandardTrainer
+from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK
+from dictionary_learning.utils import hf_dataset_to_generator
+from dictionary_learning.buffer import ActivationBuffer
+from dictionary_learning.dictionary import (
+    AutoEncoder,
+    GatedAutoEncoder,
+    AutoEncoderNew,
+    JumpReluAutoEncoder,
+)
+from dictionary_learning.evaluation import evaluate
+
+EXPECTED_RESULTS = {
+    "AutoEncoderTopK": {
+        "l2_loss": 4.325331306457519,
+        "l1_loss": 47.92763671875,
+        "l0": 40.0,
+        "frac_variance_explained": 0.9584966480731965,
+        "cossim": 0.948570293188095,
+        "l2_ratio": 0.94872345328331,
+        "relative_reconstruction_bias": 0.9998040139675141,
+        "loss_original": 3.328495955467224,
+        "loss_reconstructed": 3.819682216644287,
+        "loss_zero": 13.250199031829833,
+        "frac_recovered": 0.9503251194953919,
+        "frac_alive": 0.99951171875,
+    },
+    "AutoEncoder": {
+        "l2_loss": 6.5741173267364506,
+        "l1_loss": 32.06615734100342,
+        "l0": 60.9147216796875,
+        "frac_variance_explained": 0.9042629599571228,
+        "cossim": 0.8782194256782532,
+        "l2_ratio": 0.814234834909439,
+        "relative_reconstruction_bias": 0.9813631415367127,
+        "loss_original": 3.328495955467224,
+        "loss_reconstructed": 5.7899915218353275,
+        "loss_zero": 13.250199031829833,
+        "frac_recovered": 0.754741370677948,
+        "frac_alive": 0.9921875,
+    },
+}
+
+DEVICE = "cuda:0"
+SAVE_DIR = "./test_data"
+MODEL_NAME = "EleutherAI/pythia-70m-deduped"
+RANDOM_SEED = 42
+LAYER = 3
+DATASET_NAME = "monology/pile-uncopyrighted"
+
+EVAL_TOLERANCE = 0.01
+
+
+def get_nested_folders(path: str) -> list[str]:
+    """
+    Recursively get a list of folders that contain an ae.pt file, starting the search from the given path
+    """
+    folder_names = []
+
+    for root, dirs, files in os.walk(path):
+        if "ae.pt" in files:
+            folder_names.append(root)
+
+    return folder_names
+
+
+def load_dictionary(base_path: str, device: str) -> tuple:
+    ae_path = f"{base_path}/ae.pt"
+    config_path = f"{base_path}/config.json"
+
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # TODO: Save the submodule name in the config?
+    # submodule_str = config["trainer"]["submodule_name"]
+    dict_class = config["trainer"]["dict_class"]
+
+    if dict_class == "AutoEncoder":
+        dictionary = AutoEncoder.from_pretrained(ae_path, device=device)
+    elif dict_class == "GatedAutoEncoder":
+        dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device)
+    elif dict_class == "AutoEncoderNew":
+        dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device)
+    elif dict_class == "AutoEncoderTopK":
+        k = config["trainer"]["k"]
+        dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device)
+    elif dict_class == "JumpReluAutoEncoder":
+        dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device)
+    else:
+        raise ValueError(f"Dictionary class {dict_class} not supported")
+
+    return dictionary, config
+
+
+def test_sae_training():
+    """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090.
+    This isn't a nice suite of unit tests, but it's better than nothing."""
+    random.seed(RANDOM_SEED)
+    t.manual_seed(RANDOM_SEED)
+
+    model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE)
+
+    context_length = 128
+    llm_batch_size = 512  # Fits on a 24GB GPU
+    sae_batch_size = 8192
+    num_contexts_per_sae_batch = sae_batch_size // context_length
+
+    num_inputs_in_buffer = num_contexts_per_sae_batch * 20
+
+    num_tokens = 10_000_000
+
+    # sae training parameters
+    k = 40
+    sparsity_penalty = 0.05
+    expansion_factor = 8
+
+    steps = int(num_tokens / sae_batch_size)  # Total number of batches to train
+    save_steps = None
+    warmup_steps = 1000  # Warmup period at start of training and after each resample
+    resample_steps = None
+
+    # standard sae training parameters
+    learning_rate = 3e-4
+
+    # topk sae training parameters
+    decay_start = 24000
+    auxk_alpha = 1 / 32
+
+    submodule = model.gpt_neox.layers[LAYER]
+    submodule_name = f"resid_post_layer_{LAYER}"
+    io = "out"
+    activation_dim = model.config.hidden_size
+
+    generator = hf_dataset_to_generator(DATASET_NAME)
+
+    activation_buffer = ActivationBuffer(
+        generator,
+        model,
+        submodule,
+        n_ctxs=num_inputs_in_buffer,
+        ctx_len=context_length,
+        refresh_batch_size=llm_batch_size,
+        out_batch_size=sae_batch_size,
+        io=io,
+        d_submodule=activation_dim,
+        device=DEVICE,
+    )
+
+    # create the list of configs
+    trainer_configs = []
+    trainer_configs.extend(
+        [
+            {
+                "trainer": TrainerTopK,
+                "dict_class": AutoEncoderTopK,
+                "activation_dim": activation_dim,
+                "dict_size": expansion_factor * activation_dim,
+                "k": k,
+                "auxk_alpha": auxk_alpha,  # see Appendix A.2
+                "decay_start": decay_start,  # when does the lr decay start
+                "steps": steps,  # when does training end
+                "seed": RANDOM_SEED,
+                "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}",
+                "device": DEVICE,
+                "layer": LAYER,
+                "lm_name": MODEL_NAME,
+                "submodule_name": submodule_name,
+            },
+        ]
+    )
+    trainer_configs.extend(
+        [
+            {
+                "trainer": StandardTrainer,
+                "dict_class": AutoEncoder,
+                "activation_dim": activation_dim,
+                "dict_size": expansion_factor * activation_dim,
+                "lr": learning_rate,
+                "l1_penalty": sparsity_penalty,
+                "warmup_steps": warmup_steps,
+                "resample_steps": resample_steps,
+                "seed": RANDOM_SEED,
+                "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}",
+                "layer": LAYER,
+                "lm_name": MODEL_NAME,
+                "device": DEVICE,
+                "submodule_name": submodule_name,
+            },
+        ]
+    )
+
+    print(f"len trainer configs: {len(trainer_configs)}")
+    output_dir = f"{SAVE_DIR}/{submodule_name}"
+
+    trainSAE(
+        data=activation_buffer,
+        trainer_configs=trainer_configs,
+        steps=steps,
+        save_steps=save_steps,
+        save_dir=output_dir,
+    )
+
+    folders = get_nested_folders(output_dir)
+
+    assert len(folders) == 2
+
+    for folder in folders:
+        dictionary, config = load_dictionary(folder, DEVICE)
+
+        assert dictionary is not None
+        assert config is not None
+
+
+def test_evaluation():
+    random.seed(RANDOM_SEED)
+    t.manual_seed(RANDOM_SEED)
+
+    model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE)
+    ae_paths = get_nested_folders(SAVE_DIR)
+
+    context_length = 128
+    llm_batch_size = 100
+    sae_batch_size = 4096
+    n_batches = 10
+    buffer_size = 256
+    io = "out"
+
+    generator = hf_dataset_to_generator(DATASET_NAME)
+    submodule = model.gpt_neox.layers[LAYER]
+
+    input_strings = []
+    for i, example in enumerate(generator):
+        input_strings.append(example)
+        if i > buffer_size * n_batches:
+            break
+
+    for ae_path in ae_paths:
+        dictionary, config = load_dictionary(ae_path, DEVICE)
+        dictionary = dictionary.to(dtype=model.dtype)
+
+        activation_dim = config["trainer"]["activation_dim"]
+        context_length = config["buffer"]["ctx_len"]
+
+        activation_buffer_data = iter(input_strings)
+
+        activation_buffer = ActivationBuffer(
+            activation_buffer_data,
+            model,
+            submodule,
+            n_ctxs=buffer_size,
+            ctx_len=context_length,
+            refresh_batch_size=llm_batch_size,
+            out_batch_size=sae_batch_size,
+            io=io,
+            d_submodule=activation_dim,
+            device=DEVICE,
+        )
+
+        eval_results = evaluate(
+            dictionary,
+            activation_buffer,
+            context_length,
+            llm_batch_size,
+            io=io,
+            device=DEVICE,
+            n_batches=n_batches,
+        )
+
+        print(eval_results)
+
+        dict_class = config["trainer"]["dict_class"]
+        expected_results = EXPECTED_RESULTS[dict_class]
+
+        for key, value in expected_results.items():
+            assert abs(eval_results[key] - value) < EVAL_TOLERANCE
diff --git a/training.py b/training.py
index f100fee..13fd4b3 100644
--- a/training.py
+++ b/training.py
@@ -6,6 +6,7 @@
 import multiprocessing as mp
 import os
 from queue import Empty
+from typing import Optional
 
 import torch as t
 from tqdm import tqdm
@@ -75,17 +76,17 @@ def log_stats(
 
 def trainSAE(
     data,
-    trainer_configs,
-    use_wandb=False,
-    wandb_entity="",
-    wandb_project="",
-    steps=None,
-    save_steps=None,
-    save_dir=None,
-    log_steps=None,
-    activations_split_by_head=False,
-    transcoder=False,
-    run_cfg={},
+    trainer_configs: list[dict],
+    use_wandb:bool=False,
+    wandb_entity:str="",
+    wandb_project:str="",
+    steps:Optional[int]=None,
+    save_steps:Optional[list[int]]=None,
+    save_dir:Optional[str]=None,
+    log_steps:Optional[int]=None,
+    activations_split_by_head:bool=False,
+    transcoder:bool=False,
+    run_cfg:dict={},
 ):
     """
     Train SAEs using the given trainers
@@ -140,7 +141,7 @@ def trainSAE(
             )
 
         # saving
-        if save_steps is not None and step % save_steps == 0:
+        if save_steps is not None and step in save_steps:
             for dir, trainer in zip(save_dirs, trainers):
                 if dir is not None:
                     if not os.path.exists(os.path.join(dir, "checkpoints")):