From 944edd1b0c588ba288df556a4cfbc90523b78fda Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 02:02:02 +0000 Subject: [PATCH 1/4] Add a pytorch activation buffer, enable model truncation --- dictionary_learning/pytorch_buffer.py | 225 ++++++++++++++++++++++ dictionary_learning/utils.py | 62 +++++- tests/test_end_to_end.py | 8 +- tests/test_pytorch_end_to_end.py | 261 ++++++++++++++++++++++++++ 4 files changed, 548 insertions(+), 8 deletions(-) create mode 100644 dictionary_learning/pytorch_buffer.py create mode 100644 tests/test_pytorch_end_to_end.py diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py new file mode 100644 index 0000000..720101f --- /dev/null +++ b/dictionary_learning/pytorch_buffer.py @@ -0,0 +1,225 @@ +import torch as t +from transformers import AutoModelForCausalLM, AutoTokenizer +import gc +from tqdm import tqdm +import contextlib + + +class EarlyStopException(Exception): + """Custom exception for stopping model forward pass early.""" + + pass + + +def collect_activations( + model: AutoModelForCausalLM, + submodule: t.nn.Module, + inputs_BL: dict[str, t.Tensor], + use_no_grad: bool = True, +) -> t.Tensor: + """ + Registers a forward hook on the submodule to capture the residual (or hidden) + activations. We then raise an EarlyStopException to skip unneeded computations. + + Args: + model: The model to run. + submodule: The submodule to hook into. + inputs_BL: The inputs to the model. + use_no_grad: Whether to run the forward pass within a `t.no_grad()` context. Defaults to True. + """ + activations_BLD = None + + def gather_target_act_hook(module, inputs, outputs): + nonlocal activations_BLD + # For many models, the submodule outputs are a tuple or a single tensor: + # If "outputs" is a tuple, pick the relevant item: + # e.g. if your layer returns (hidden, something_else), you'd do outputs[0] + # Otherwise just do outputs + if isinstance(outputs, tuple): + activations_BLD = outputs[0] + else: + activations_BLD = outputs + + raise EarlyStopException("Early stopping after capturing activations") + + handle = submodule.register_forward_hook(gather_target_act_hook) + + # Determine the context manager based on the flag + context_manager = t.no_grad() if use_no_grad else contextlib.nullcontext() + + try: + # Use the selected context manager + with context_manager: + _ = model(**inputs_BL) + except EarlyStopException: + pass + except Exception as e: + print(f"Unexpected error during forward pass: {str(e)}") + raise + finally: + handle.remove() + + if activations_BLD is None: + # This should ideally not happen if the hook worked and EarlyStopException was raised, + # but handle it just in case. + raise RuntimeError( + "Failed to collect activations. The hook might not have run correctly." + ) + + return activations_BLD + + +class ActivationBuffer: + """ + Implements a buffer of activations. The buffer stores activations from a model, + yields them in batches, and refreshes them when the buffer is less than half full. 
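+
+    A minimal usage sketch (hypothetical text generator and settings; the buffer
+    only assumes a HuggingFace causal LM, one of its submodules, and an iterator
+    of strings):
+
+        model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
+        submodule = model.gpt_neox.layers[3]
+        buffer = ActivationBuffer(
+            text_generator, model, submodule,
+            d_submodule=model.config.hidden_size,  # required when it can't be inferred from the submodule
+        )
+        acts = next(buffer)  # (out_batch_size, d_submodule) activations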
+ """ + + def __init__( + self, + data, # generator which yields text data + model: AutoModelForCausalLM, # Language Model from which to extract activations + submodule, # submodule of the model from which to extract activations + d_submodule=None, # submodule dimension; if None, try to detect automatically + io="out", # can be 'in' or 'out'; whether to extract input or output activations + n_ctxs=3e4, # approximate number of contexts to store in the buffer + ctx_len=128, # length of each context + refresh_batch_size=512, # size of batches in which to process the data when adding to buffer + out_batch_size=8192, # size of batches in which to yield activations + device="cpu", # device on which to store the activations + remove_bos: bool = False, + add_special_tokens: bool = True, + ): + if io not in ["in", "out"]: + raise ValueError("io must be either 'in' or 'out'") + + if d_submodule is None: + try: + if io == "in": + d_submodule = submodule.in_features + else: + d_submodule = submodule.out_features + except: + raise ValueError( + "d_submodule cannot be inferred and must be specified directly" + ) + self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype) + self.read = t.zeros(0).bool() + + self.data = data + self.model = model + self.submodule = submodule + self.d_submodule = d_submodule + self.io = io + self.n_ctxs = n_ctxs + self.ctx_len = ctx_len + self.activation_buffer_size = n_ctxs * ctx_len + self.refresh_batch_size = refresh_batch_size + self.out_batch_size = out_batch_size + self.device = device + self.remove_bos = remove_bos + self.add_special_tokens = add_special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path) + + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def __iter__(self): + return self + + def __next__(self): + """ + Return a batch of activations + """ + with t.no_grad(): + # if buffer is less than half full, refresh + if (~self.read).sum() < self.activation_buffer_size // 2: + self.refresh() + + # return a batch + unreads = (~self.read).nonzero().squeeze() + idxs = unreads[ + t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size] + ] + self.read[idxs] = True + return self.activations[idxs] + + def text_batch(self, batch_size=None): + """ + Return a list of text + """ + if batch_size is None: + batch_size = self.refresh_batch_size + try: + return [next(self.data) for _ in range(batch_size)] + except StopIteration: + raise StopIteration("End of data stream reached") + + def tokenized_batch(self, batch_size=None): + """ + Return a batch of tokenized inputs. + """ + texts = self.text_batch(batch_size=batch_size) + return self.tokenizer( + texts, + return_tensors="pt", + max_length=self.ctx_len, + padding=True, + truncation=True, + add_special_tokens=self.add_special_tokens, + ).to(self.device) + + def refresh(self): + gc.collect() + t.cuda.empty_cache() + self.activations = self.activations[~self.read] + + current_idx = len(self.activations) + new_activations = t.empty( + self.activation_buffer_size, + self.d_submodule, + device=self.device, + dtype=self.model.dtype, + ) + + new_activations[: len(self.activations)] = self.activations + self.activations = new_activations + + # Optional progress bar when filling buffer. At larger models / buffer sizes (e.g. gemma-2-2b, 1M tokens on a 4090) this can take a couple minutes. 
+ # pbar = tqdm(total=self.activation_buffer_size, initial=current_idx, desc="Refreshing activations") + + while current_idx < self.activation_buffer_size: + with t.no_grad(): + input = self.tokenized_batch() + hidden_states = collect_activations(self.model, self.submodule, input) + attn_mask = input["attention_mask"] + if self.remove_bos: + hidden_states = hidden_states[:, 1:, :] + attn_mask = attn_mask[:, 1:] + hidden_states = hidden_states[attn_mask != 0] + + remaining_space = self.activation_buffer_size - current_idx + assert remaining_space > 0 + hidden_states = hidden_states[:remaining_space] + + self.activations[current_idx : current_idx + len(hidden_states)] = ( + hidden_states.to(self.device) + ) + current_idx += len(hidden_states) + + # pbar.update(len(hidden_states)) + + # pbar.close() + self.read = t.zeros(len(self.activations), dtype=t.bool, device=self.device) + + @property + def config(self): + return { + "d_submodule": self.d_submodule, + "io": self.io, + "n_ctxs": self.n_ctxs, + "ctx_len": self.ctx_len, + "refresh_batch_size": self.refresh_batch_size, + "out_batch_size": self.out_batch_size, + "device": self.device, + } diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py index 3b1077e..537754c 100644 --- a/dictionary_learning/utils.py +++ b/dictionary_learning/utils.py @@ -3,7 +3,11 @@ import io import json import os -from nnsight import LanguageModel +from transformers import AutoModelForCausalLM +from fractions import Fraction +import random +from transformers import AutoTokenizer +import torch as t from .trainers.top_k import AutoEncoderTopK from .trainers.batch_top_k import BatchTopKSAE @@ -88,13 +92,61 @@ def load_dictionary(base_path: str, device: str) -> tuple: return dictionary, config -def get_submodule(model: LanguageModel, layer: int): +def get_submodule(model: AutoModelForCausalLM, layer: int): """Gets the residual stream submodule""" - model_name = model._model_key + model_name = model.name_or_path - if "pythia" in model_name: + if model.config.architectures[0] == "GPTNeoXForCausalLM": return model.gpt_neox.layers[layer] - elif "gemma" in model_name: + elif ( + model.config.architectures[0] == "Qwen2ForCausalLM" + or model.config.architectures[0] == "Gemma2ForCausalLM" + ): return model.model.layers[layer] else: raise ValueError(f"Please add submodule for model {model_name}") + + +def truncate_model(model: AutoModelForCausalLM, layer: int): + """From tilde-research/activault + https://github.com/tilde-research/activault/blob/db6d1e4e36c2d3eb4fdce79e72be94f387eccee1/pipeline/setup.py#L74 + This provides significant memory savings by deleting all layers that aren't needed for the given layer. 
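+
+    A hypothetical sketch of the intended pattern (truncate before collecting
+    activations from the layer you hook):
+
+        model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
+        model = truncate_model(model, layer=3)  # keeps layers 0..3, swaps the unembedding for nn.Identity
+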
+ You should probably test this before using it""" + import gc + + total_params_before = sum(p.numel() for p in model.parameters()) + print(f"Model parameters before truncation: {total_params_before:,}") + + if ( + model.config.architectures[0] == "Qwen2ForCausalLM" + or model.config.architectures[0] == "Gemma2ForCausalLM" + ): + removed_layers = model.model.layers[layer + 1 :] + + model.model.layers = model.model.layers[: layer + 1] + + del removed_layers + del model.lm_head + + model.lm_head = t.nn.Identity() + + elif model.config.architectures[0] == "GPTNeoXForCausalLM": + removed_layers = model.gpt_neox.layers[layer + 1 :] + + model.gpt_neox.layers = model.gpt_neox.layers[: layer + 1] + + del removed_layers + del model.embed_out + + model.embed_out = t.nn.Identity() + + else: + raise ValueError(f"Please add truncation for model {model.name_or_path}") + + total_params_after = sum(p.numel() for p in model.parameters()) + print(f"Model parameters after truncation: {total_params_after:,}") + + gc.collect() + t.cuda.empty_cache() + + return model diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 797cbab..055db18 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -59,14 +59,16 @@ LAYER = 3 DATASET_NAME = "monology/pile-uncopyrighted" -EVAL_TOLERANCE = 0.01 +EVAL_TOLERANCE_PERCENT = 0.005 def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. This isn't a nice suite of unit tests, but it's better than nothing. I have observed that results can slightly vary with library versions. For full determinism, - use pytorch 2.5.1 and nnsight 0.3.7.""" + use pytorch 2.5.1 and nnsight 0.3.7. + Unfortunately an RTX 3090 is also required for full determinism. On an H100 the results are off by ~0.3%, meaning this test will + not be within the EVAL_TOLERANCE.""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) @@ -257,4 +259,4 @@ def test_evaluation(): max_diff_percent = max(max_diff_percent, diff / value) print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") - assert max_diff < EVAL_TOLERANCE + assert max_diff_percent < EVAL_TOLERANCE_PERCENT diff --git a/tests/test_pytorch_end_to_end.py b/tests/test_pytorch_end_to_end.py new file mode 100644 index 0000000..79ef5b3 --- /dev/null +++ b/tests/test_pytorch_end_to_end.py @@ -0,0 +1,261 @@ +import torch as t +from transformers import AutoModelForCausalLM, AutoTokenizer +import os +import json +import random + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK +from dictionary_learning.utils import ( + hf_dataset_to_generator, + get_nested_folders, + load_dictionary, +) + +# from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.pytorch_buffer import ActivationBuffer +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) +from dictionary_learning.evaluation import evaluate + +EXPECTED_RESULTS = { + "AutoEncoderTopK": { + "l2_loss": 4.358876752853393, + "l1_loss": 50.90618553161621, + "l0": 40.0, + "frac_variance_explained": 0.9577824175357819, + "cossim": 0.9476200461387634, + "l2_ratio": 0.9476299166679383, + "relative_reconstruction_bias": 0.9996505916118622, + "frac_alive": 1.0, + }, + "AutoEncoder": { + "l2_loss": 6.8308186531066895, + "l1_loss": 19.398421669006346, + "l0": 37.4469970703125, + "frac_variance_explained": 
0.9003101229667664, + "cossim": 0.8782103300094605, + "l2_ratio": 0.7444103538990021, + "relative_reconstruction_bias": 0.960041344165802, + "frac_alive": 0.9970703125, + }, +} + +DEVICE = "cuda:0" +SAVE_DIR = "./test_data" +MODEL_NAME = "EleutherAI/pythia-70m-deduped" +RANDOM_SEED = 42 +LAYER = 3 +DATASET_NAME = "monology/pile-uncopyrighted" + +EVAL_TOLERANCE_PERCENT = 0.005 + + +def test_sae_training(): + """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. + This isn't a nice suite of unit tests, but it's better than nothing. + I have observed that results can slightly vary with library versions. For full determinism, + use pytorch 2.5.1 and nnsight 0.3.7. + Unfortunately an RTX 3090 is also required for full determinism. On an H100 the results are off by ~0.3%, meaning this test will + not be within the EVAL_TOLERANCE.""" + + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + # model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, device_map="auto", torch_dtype=t.float32 + ).to(DEVICE) + + context_length = 128 + llm_batch_size = 512 # Fits on a 24GB GPU + sae_batch_size = 8192 + num_contexts_per_sae_batch = sae_batch_size // context_length + + num_inputs_in_buffer = num_contexts_per_sae_batch * 20 + + num_tokens = 10_000_000 + + # sae training parameters + k = 40 + sparsity_penalty = 2.0 + expansion_factor = 8 + + steps = int(num_tokens / sae_batch_size) # Total number of batches to train + save_steps = None + warmup_steps = 1000 # Warmup period at start of training and after each resample + resample_steps = None + + # standard sae training parameters + learning_rate = 3e-4 + + # topk sae training parameters + decay_start = None + auxk_alpha = 1 / 32 + + submodule = model.gpt_neox.layers[LAYER] + submodule_name = f"resid_post_layer_{LAYER}" + io = "out" + activation_dim = model.config.hidden_size + + generator = hf_dataset_to_generator(DATASET_NAME) + + activation_buffer = ActivationBuffer( + generator, + model, + submodule, + n_ctxs=num_inputs_in_buffer, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + # create the list of configs + trainer_configs = [] + trainer_configs.extend( + [ + { + "trainer": TopKTrainer, + "dict_class": AutoEncoderTopK, + "lr": None, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "k": k, + "auxk_alpha": auxk_alpha, # see Appendix A.2 + "warmup_steps": 0, + "decay_start": decay_start, # when does the lr decay start + "steps": steps, # when when does training end + "seed": RANDOM_SEED, + "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", + "device": DEVICE, + "layer": LAYER, + "lm_name": MODEL_NAME, + "submodule_name": submodule_name, + }, + ] + ) + trainer_configs.extend( + [ + { + "trainer": StandardTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": sparsity_penalty, + "warmup_steps": warmup_steps, + "sparsity_warmup_steps": None, + "decay_start": decay_start, + "steps": steps, + "resample_steps": resample_steps, + "seed": RANDOM_SEED, + "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", + "layer": LAYER, + "lm_name": MODEL_NAME, + "device": DEVICE, + "submodule_name": submodule_name, + }, + ] + ) + + print(f"len trainer configs: {len(trainer_configs)}") + output_dir = 
f"{SAVE_DIR}/{submodule_name}" + + trainSAE( + data=activation_buffer, + trainer_configs=trainer_configs, + steps=steps, + save_steps=save_steps, + save_dir=output_dir, + ) + + folders = get_nested_folders(output_dir) + + assert len(folders) == 2 + + for folder in folders: + dictionary, config = load_dictionary(folder, DEVICE) + + assert dictionary is not None + assert config is not None + + +def test_evaluation(): + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, device_map="auto", torch_dtype=t.float32 + ).to(DEVICE) + ae_paths = get_nested_folders(SAVE_DIR) + + context_length = 128 + llm_batch_size = 100 + sae_batch_size = 4096 + n_batches = 10 + buffer_size = 256 + io = "out" + + generator = hf_dataset_to_generator(DATASET_NAME) + submodule = model.gpt_neox.layers[LAYER] + + input_strings = [] + for i, example in enumerate(generator): + input_strings.append(example) + if i > buffer_size * n_batches: + break + + for ae_path in ae_paths: + dictionary, config = load_dictionary(ae_path, DEVICE) + dictionary = dictionary.to(dtype=model.dtype) + + activation_dim = config["trainer"]["activation_dim"] + context_length = config["buffer"]["ctx_len"] + + activation_buffer_data = iter(input_strings) + + activation_buffer = ActivationBuffer( + activation_buffer_data, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + llm_batch_size, + io=io, + device=DEVICE, + n_batches=n_batches, + ) + + print(eval_results) + + dict_class = config["trainer"]["dict_class"] + expected_results = EXPECTED_RESULTS[dict_class] + + max_diff = 0 + max_diff_percent = 0 + for key, value in expected_results.items(): + diff = abs(eval_results[key] - value) + max_diff = max(max_diff, diff) + max_diff_percent = max(max_diff_percent, diff / value) + + print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") + assert max_diff_percent < EVAL_TOLERANCE_PERCENT From c644ccd20d6b17e1bc1199e981cf8539c5c06f0a Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 02:02:20 +0000 Subject: [PATCH 2/4] Add better dataset generators --- dictionary_learning/buffer.py | 10 +- dictionary_learning/utils.py | 189 ++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 3 deletions(-) diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py index 7d304b7..67e1c0f 100644 --- a/dictionary_learning/buffer.py +++ b/dictionary_learning/buffer.py @@ -28,6 +28,7 @@ def __init__(self, out_batch_size=8192, # size of batches in which to yield activations device='cpu', # device on which to store the activations remove_bos: bool = False, + add_special_tokens: bool = True, ): if io not in ['in', 'out']: @@ -56,7 +57,8 @@ def __init__(self, self.out_batch_size = out_batch_size self.device = device self.remove_bos = remove_bos - + self.add_special_tokens = add_special_tokens + def __iter__(self): return self @@ -98,7 +100,8 @@ def tokenized_batch(self, batch_size=None): return_tensors='pt', max_length=self.ctx_len, padding=True, - truncation=True + truncation=True, + add_special_tokens=self.add_special_tokens ) def refresh(self): @@ -117,8 +120,9 @@ def refresh(self): while current_idx < self.activation_buffer_size: with t.no_grad(): + tokens = self.tokenized_batch() with self.model.trace( - self.text_batch(), + 
tokens, **tracer_kwargs, invoker_args={"truncation": True, "max_length": self.ctx_len}, ): diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py index 537754c..6f2d2c0 100644 --- a/dictionary_learning/utils.py +++ b/dictionary_learning/utils.py @@ -47,6 +47,195 @@ def generator(): return generator() +def randomly_remove_system_prompt( + text: str, freq: float, system_prompt: str | None = None +) -> str: + if system_prompt and random.random() < freq: + assert system_prompt in text + text = text.replace(system_prompt, "") + return text + + +def hf_mixed_dataset_to_generator( + tokenizer: AutoTokenizer, + pretrain_dataset: str = "HuggingFaceFW/fineweb", + chat_dataset: str = "lmsys/lmsys-chat-1m", + min_chars: int = 1, + pretrain_frac: float = 0.9, # 0.9 → 90 % pretrain, 10 % chat + split: str = "train", + streaming: bool = True, + pretrain_key: str = "text", + chat_key: str = "conversation", + sequence_pack_pretrain: bool = True, + sequence_pack_chat: bool = False, + system_prompt_to_remove: str | None = None, + system_prompt_removal_freq: float = 0.9, +): + """Get a mix of pretrain and chat data at a specified ratio. By default, 90% of the data will be pretrain and 10% will be chat. + + Default datasets: + pretrain_dataset: "HuggingFaceFW/fineweb" + chat_dataset: "lmsys/lmsys-chat-1m" + + Note that you will have to request permission for lmsys (instant approval on HuggingFace). + + min_chars: minimum number of characters per sample. To perform sequence packing, set it to ~4x sequence length in tokens. + Samples will be joined with the eos token. + If it's low (like 1), each sample will just be a single row from the dataset, padded to the max length. Sometimes this will fill the context, sometimes it won't. + + Why use strings instead of tokens? Because dictionary learning expects an iterator of strings, and this is simple and good enough. + + Implicit assumption: each sample will be truncated to sequence length when tokenized. + + By default, we sequence pack the pretrain data and DO NOT sequence pack the chat data, as it would look kind of weird. The EOS token is used to separate + user / assistant messages, not to separate conversations from different users. + If you want to sequence pack the chat data, set sequence_pack_chat to True. + + Pretrain format will be: texttexttext... + Chat format will be Optionally: ... + + Other parameters: + - system_prompt_to_remove: an optional string that will be removed from the chat data with a given frequency. + You probably want to verify that the system prompt you pass in is correct. + - system_prompt_removal_freq: the frequency with which the system prompt will be removed + + Why? Well, we probably don't want to have 1000's of copies of the system prompt in the training dataset. But we also may not want to remove it entirely. + And we may want to use the LLM with no system prompt when comparing between models. + IDK, this is a complicated and annoying detail. At least this constrains the complexity to the dataset generator. + """ + if not 0 < pretrain_frac < 1: + raise ValueError("main_frac must be between 0 and 1 (exclusive)") + + assert min_chars > 0 + + # Load both datasets as iterable streams + pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming)) + chat_ds = iter(load_dataset(chat_dataset, split=split, streaming=streaming)) + + # Convert the fraction to two small integers (e.g. 
0.9 → 9 / 10) + frac = Fraction(pretrain_frac).limit_denominator() + n_pretrain = frac.numerator + n_chat = frac.denominator - n_pretrain + eos_token = tokenizer.eos_token + + bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token + + def gen(): + while True: + for _ in range(n_pretrain): + if sequence_pack_pretrain: + length = 0 + samples = [] + while length < min_chars: + # Add bos token to the beginning of the sample + sample = next(pretrain_ds)[pretrain_key] + samples.append(sample) + length += len(sample) + samples = bos_token + eos_token.join(samples) + yield samples + else: + sample = bos_token + next(pretrain_ds)[pretrain_key] + yield sample + for _ in range(n_chat): + if sequence_pack_chat: + length = 0 + samples = [] + while length < min_chars: + sample = next(chat_ds)[chat_key] + # Apply chat template also includes bos token + sample = tokenizer.apply_chat_template(sample, tokenize=False) + sample = randomly_remove_system_prompt( + sample, system_prompt_removal_freq, system_prompt_to_remove + ) + samples.append(sample) + length += len(sample) + samples = "".join(samples) + yield samples + else: + sample = tokenizer.apply_chat_template( + next(chat_ds)[chat_key], tokenize=False + ) + sample = randomly_remove_system_prompt( + sample, system_prompt_removal_freq, system_prompt_to_remove + ) + yield sample + + return gen() + + +def hf_sequence_packing_dataset_to_generator( + tokenizer: AutoTokenizer, + pretrain_dataset: str = "HuggingFaceFW/fineweb", + min_chars: int = 1, + split: str = "train", + streaming: bool = True, + pretrain_key: str = "text", + sequence_pack_pretrain: bool = True, +): + """min_chars: minimum number of characters per sample. To perform sequence packing, set it to ~4x sequence length in tokens. + Samples will be joined with the eos token. + If it's low (like 1), each sample will just be a single row from the dataset, padded to the max length. Sometimes this will fill the context, sometimes it won't.""" + assert min_chars > 0 + + # Load both datasets as iterable streams + pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming)) + + eos_token = tokenizer.eos_token + + bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token + + def gen(): + while True: + if sequence_pack_pretrain: + length = 0 + samples = [] + while length < min_chars: + # Add bos token to the beginning of the sample + sample = next(pretrain_ds)[pretrain_key] + samples.append(sample) + length += len(sample) + samples = bos_token + eos_token.join(samples) + yield samples + else: + sample = bos_token + next(pretrain_ds)[pretrain_key] + yield sample + + return gen() + + +def simple_hf_mixed_dataset_to_generator( + main_name: str, + aux_name: str, + main_frac: float = 0.9, # 0.9 → 90 % main, 10 % aux + split: str = "train", + streaming: bool = True, + main_key: str = "text", + aux_key: str = "text", +): + if not 0 < main_frac < 1: + raise ValueError("main_frac must be between 0 and 1 (exclusive)") + + # Load both datasets as iterable streams + main_ds = iter(load_dataset(main_name, split=split, streaming=streaming)) + aux_ds = iter(load_dataset(aux_name, split=split, streaming=streaming)) + + # Convert the fraction to two small integers (e.g. 
0.9 → 9 / 10) + frac = Fraction(main_frac).limit_denominator() + n_main = frac.numerator + n_aux = frac.denominator - n_main + + def gen(): + while True: + # Yield `n_main` items from the main dataset + for _ in range(n_main): + yield next(main_ds)[main_key] + # Yield `n_aux` items from the auxiliary dataset + for _ in range(n_aux): + yield next(aux_ds)[aux_key] + + return gen() + + def get_nested_folders(path: str) -> list[str]: """ Recursively get a list of folders that contain an ae.pt file, starting the search from the given path From 17a41c76335c038a0da3ae6f6589ba2afc6a19b3 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 02:35:03 +0000 Subject: [PATCH 3/4] Add optional backup step --- dictionary_learning/training.py | 54 ++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/dictionary_learning/training.py b/dictionary_learning/training.py index 50f20ee..0671f31 100644 --- a/dictionary_learning/training.py +++ b/dictionary_learning/training.py @@ -126,6 +126,7 @@ def trainSAE( verbose:bool=False, device:str="cuda", autocast_dtype: t.dtype = t.float32, + backup_steps:Optional[int]=None, ): """ Train SAEs using the given trainers @@ -214,23 +215,42 @@ def trainSAE( # saving if save_steps is not None and step in save_steps: for dir, trainer in zip(save_dirs, trainers): - if dir is not None: - - if normalize_activations: - # Temporarily scale up biases for checkpoint saving - trainer.ae.scale_biases(norm_factor) - - if not os.path.exists(os.path.join(dir, "checkpoints")): - os.mkdir(os.path.join(dir, "checkpoints")) - - checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} - t.save( - checkpoint, - os.path.join(dir, "checkpoints", f"ae_{step}.pt"), - ) - - if normalize_activations: - trainer.ae.scale_biases(1 / norm_factor) + if dir is None: + continue + + if normalize_activations: + # Temporarily scale up biases for checkpoint saving + trainer.ae.scale_biases(norm_factor) + + if not os.path.exists(os.path.join(dir, "checkpoints")): + os.mkdir(os.path.join(dir, "checkpoints")) + + checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} + t.save( + checkpoint, + os.path.join(dir, "checkpoints", f"ae_{step}.pt"), + ) + + if normalize_activations: + trainer.ae.scale_biases(1 / norm_factor) + + # backup + if backup_steps is not None and step % backup_steps == 0: + for save_dir, trainer in zip(save_dirs, trainers): + if save_dir is None: + continue + # save the current state of the trainer for resume if training is interrupted + # this will be overwritten by the next checkpoint and at the end of training + t.save( + { + "step": step, + "ae": trainer.ae.state_dict(), + "optimizer": trainer.optimizer.state_dict(), + "config": trainer.config, + "norm_factor": norm_factor, + }, + os.path.join(save_dir, "ae.pt"), + ) # training for trainer in trainers: From fe9d8c7c144ec85eff87154f3704afaa6a24e671 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 03:29:26 +0000 Subject: [PATCH 4/4] add activault buffer implementation --- dictionary_learning/activault_s3_buffer.py | 744 +++++++++++++++++++++ 1 file changed, 744 insertions(+) create mode 100644 dictionary_learning/activault_s3_buffer.py diff --git a/dictionary_learning/activault_s3_buffer.py b/dictionary_learning/activault_s3_buffer.py new file mode 100644 index 0000000..1b94a7e --- /dev/null +++ b/dictionary_learning/activault_s3_buffer.py @@ -0,0 +1,744 @@ +"""Copyright (2025) Tilde Research Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import io +import json +import os +import random +import signal +import sys +import time +import warnings +from multiprocessing import Process, Queue, Value +from typing import Optional + +import einops +import aiohttp +import boto3 +import torch +import torch.nn as nn +import multiprocessing as mp +import warnings +import logging + +logger = logging.getLogger(__name__) + +# Constants for file sizes +KB = 1024 +MB = KB * KB + +# Cache directory constants +OUTER_CACHE_DIR = "cache" +INNER_CACHE_DIR = "cache" +BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") + + +def _metadata_path(run_name): + """Generate the metadata file path for a given run name.""" + return f"{run_name}/metadata.json" + + +def _statistics_path(run_name): + """Generate the statistics file path for a given run name.""" + return f"{run_name}/statistics.json" + + +async def download_chunks(session, url, total_size, chunk_size): + """Download file chunks asynchronously with retries.""" + tries_left = 5 + while tries_left > 0: + chunks = [ + (i, min(i + chunk_size - 1, total_size - 1)) + for i in range(0, total_size, chunk_size) + ] + tasks = [ + asyncio.create_task(request_chunk(session, url, start, end)) + for start, end in chunks + ] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + results = [] + retry = False + for response in responses: + if isinstance(response, Exception): + logger.error(f"Error occurred: {response}") + logger.error( + f"Session: {session}, URL: {url}, Tries left: {tries_left}" + ) + tries_left -= 1 + retry = True + break + else: + results.append(response) + + if not retry: + return results + + return None + + +async def request_chunk(session, url, start, end): + """Request a specific chunk of a file.""" + headers = {"Range": f"bytes={start}-{end}"} + try: + async with session.get(url, headers=headers) as response: + response.raise_for_status() + return start, await response.read() + except Exception as e: + return e + + +def download_loop(*args): + """Run the asynchronous download loop.""" + asyncio.run(_async_download(*args)) + + +def compile(byte_buffers, shuffle=True, seed=None, return_ids=False): + """Compile downloaded chunks into a tensor.""" + combined_bytes = b"".join( + chunk for _, chunk in sorted(byte_buffers, key=lambda x: x[0]) + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # n = np.frombuffer(combined_bytes, dtype=np.float16) + # t = torch.from_numpy(n) + # t = torch.frombuffer(combined_bytes, dtype=dtype) # torch.float32 + buffer = io.BytesIO(combined_bytes) + t = torch.load(buffer) + if ( + isinstance(t, dict) and "states" in t and not return_ids + ): # backward compatibility + t = t["states"] # ignore input_ids + buffer.close() + + if shuffle and not return_ids: + t = shuffle_megabatch_tokens(t, seed) + + return t + + +def shuffle_megabatch_tokens(t, seed=None): + """ + Shuffle within a megabatch (across batches and sequences), using each token as the unit of shuffling. 
+ + Args: + t (torch.Tensor): Input tensor of shape (batch_size * batches_per_file, sequence_length, d_in + 1) + seed (int): Seed for the random number generator + + Returns: + torch.Tensor: Shuffled tensor of the same shape as input + """ + original_shape = ( + t.shape + ) # (batch_size * batches_per_file, sequence_length, d_in + 1) + + total_tokens = ( + original_shape[0] * original_shape[1] + ) # reshape to (total_tokens, d_in + 1) + t_reshaped = t.reshape(total_tokens, -1) + + rng = torch.Generator() + if seed is not None: + rng.manual_seed(seed) + + shuffled_indices = torch.randperm(total_tokens, generator=rng) + t_shuffled = t_reshaped[shuffled_indices] + + t = t_shuffled.reshape(original_shape) # revert + + return t + + +def write_tensor(t, buffer, writeable_tensors, readable_tensors, ongoing_downloads): + """Write a tensor to the shared buffer.""" + idx = writeable_tensors.get(block=True) + if isinstance(buffer[0], SharedBuffer): + buffer[idx].states.copy_(t["states"]) + buffer[idx].input_ids.copy_(t["input_ids"]) + else: + buffer[idx] = t + + readable_tensors.put(idx, block=True) + with ongoing_downloads.get_lock(): + ongoing_downloads.value -= 1 + + +async def _async_download( + buffer, + file_index, + s3_paths, + stop, + readable_tensors, + writeable_tensors, + ongoing_downloads, + concurrency, + bytes_per_file, + chunk_size, + shuffle, + seed, + return_ids, +): + """Asynchronously download and process files from S3.""" + connector = aiohttp.TCPConnector(limit=concurrency) + async with aiohttp.ClientSession(connector=connector) as session: + while file_index.value < len(s3_paths) and not stop.value: + with ongoing_downloads.get_lock(): + ongoing_downloads.value += 1 + with file_index.get_lock(): + url = s3_paths[file_index.value] + file_index.value += 1 + bytes_results = await download_chunks( + session, url, bytes_per_file, chunk_size + ) + if bytes_results is not None: + try: + t = compile(bytes_results, shuffle, seed, return_ids) + write_tensor( + t, + buffer, + writeable_tensors, + readable_tensors, + ongoing_downloads, + ) + except Exception as e: + logger.error(f"Exception while downloading: {e}") + logger.error(f"Failed URL: {url}") + stop.value = True # Set stop flag + break # Exit the loop + else: + logger.error(f"Failed to download URL: {url}") + with ongoing_downloads.get_lock(): + ongoing_downloads.value -= 1 + + +class S3RCache: + """A cache that reads data from Amazon S3.""" + + @classmethod + def from_credentials( + self, aws_access_key_id, aws_secret_access_key, *args, **kwargs + ): + s3_client = boto3.client( + "s3", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + endpoint_url=os.environ.get("S3_ENDPOINT_URL"), + ) + return S3RCache(s3_client, *args, **kwargs) + + def __init__( + self, + s3_client, + s3_prefix, + bucket_name=BUCKET_NAME, + device="cpu", + concurrency=100, + chunk_size=MB * 16, + buffer_size=2, + shuffle=True, + preserve_file_order=False, + seed=42, + paths=None, + n_workers=1, + return_ids=False, + ) -> None: + """Initialize S3 cache.""" + ensure_spawn_context() + + # Configure S3 client with correct signature version + self.s3_client = ( + boto3.client( + "s3", + region_name="eu-north1", # Make sure this matches your bucket region + config=boto3.session.Config(signature_version="s3v4"), + ) + if s3_client is None + else s3_client + ) + + self.s3_prefix = s3_prefix + self.bucket_name = bucket_name + self.device = device + self.concurrency = concurrency + self.chunk_size = chunk_size + self.buffer_size = 
buffer_size + self.shuffle = shuffle + self.preserve_file_order = preserve_file_order + self.seed = seed + self.return_ids = return_ids + + random.seed(self.seed) + torch.manual_seed(self.seed) # unclear if this has effect + # but we drill down the seed to download loop anyway + + self.paths = paths + self._s3_paths = self._list_s3_files() + if isinstance(self.s3_prefix, list): + target_prefix = self.s3_prefix[0] + else: + target_prefix = self.s3_prefix + response = self.s3_client.get_object( + Bucket=bucket_name, Key=_metadata_path(target_prefix) + ) + content = response["Body"].read() + self.metadata = json.loads(content) + # self.metadata["bytes_per_file"] = 1612711320 + self._activation_dtype = eval(self.metadata["dtype"]) + + self._running_processes = [] + self.n_workers = n_workers + + self.readable_tensors = Queue(maxsize=self.buffer_size) + self.writeable_tensors = Queue(maxsize=self.buffer_size) + + for i in range(self.buffer_size): + self.writeable_tensors.put(i) + + if self.return_ids: + self.buffer = [ + SharedBuffer( + self.metadata["shape"], + self.metadata["input_ids_shape"], + self._activation_dtype, + ) + for _ in range(self.buffer_size) + ] + for shared_buffer in self.buffer: + shared_buffer.share_memory() + else: + self.buffer = torch.empty( + (self.buffer_size, *self.metadata["shape"]), + dtype=self._activation_dtype, + ).share_memory_() + + self._stop = Value("b", False) + self._file_index = Value("i", 0) + self._ongoing_downloads = Value("i", 0) + + signal.signal(signal.SIGTERM, self._catch_stop) + signal.signal(signal.SIGINT, self._catch_stop) + + self._initial_file_index = 0 + + @property + def current_file_index(self): + return self._file_index.value + + def set_file_index(self, index): + self._initial_file_index = index + + def _catch_stop(self, *args, **kwargs): + logger.info("cleaning up before process is killed") + self._stop_downloading() + sys.exit(0) + + def sync(self): + self._s3_paths = self._list_s3_files() + + def _reset(self): + self._file_index.value = self._initial_file_index + self._ongoing_downloads.value = 0 + self._stop.value = False + + while not self.readable_tensors.empty(): + self.readable_tensors.get() + + while not self.writeable_tensors.empty(): + self.writeable_tensors.get() + for i in range(self.buffer_size): + self.writeable_tensors.put(i) + + def _list_s3_files(self): + """List and prepare all data files from one or more S3 prefixes.""" + paths = [] + combined_metadata = None + combined_config = None + + # Handle single prefix case for backward compatibility + prefixes = ( + [self.s3_prefix] if isinstance(self.s3_prefix, str) else self.s3_prefix + ) + + # Process each prefix + for prefix in prefixes: + # Get metadata for this prefix + response = self.s3_client.get_object( + Bucket=self.bucket_name, Key=_metadata_path(prefix) + ) + metadata = json.loads(response["Body"].read()) + + # Get config for this prefix + try: + config_response = self.s3_client.get_object( + Bucket=self.bucket_name, + Key=f"{'/'.join(prefix.split('/')[:-1])}/cfg.json", + ) + config = json.loads(config_response["Body"].read()) + except Exception as e: + logger.warning( + f"Warning: Could not load config for prefix {prefix}: {e}" + ) + config = {} + + # Initialize combined metadata and config from first prefix + if combined_metadata is None: + combined_metadata = metadata.copy() + combined_config = config.copy() + # Initialize accumulation fields + combined_config["total_tokens"] = 0 + combined_config["n_total_files"] = 0 + combined_config["batches_processed"] = 0 + 
else: + # Verify metadata compatibility + if metadata["shape"][1:] != combined_metadata["shape"][1:]: + raise ValueError( + f"Incompatible shapes between datasets: {metadata['shape']} vs {combined_metadata['shape']}" + ) + if metadata["dtype"] != combined_metadata["dtype"]: + raise ValueError(f"Incompatible dtypes between datasets") + + # Accumulate config fields + combined_config["total_tokens"] += config.get("total_tokens", 0) + combined_config["n_total_files"] += config.get("n_total_files", 0) + combined_config["batches_processed"] += config.get("batches_processed", 0) + + # List files for this prefix + paginator = self.s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) + + prefix_paths = [] + for page in page_iterator: + if "Contents" not in page: + continue + + for obj in page["Contents"]: + if ( + obj["Key"] != _metadata_path(prefix) + and obj["Key"] != _statistics_path(prefix) + and not obj["Key"].endswith("cfg.json") + ): + url = self.s3_client.generate_presigned_url( + "get_object", + Params={"Bucket": self.bucket_name, "Key": obj["Key"]}, + ExpiresIn=604700, + ) + prefix_paths.append(url) + + paths.extend(prefix_paths) + + # Store the combined metadata and config + self.metadata = combined_metadata + self.config = combined_config # Store combined config for potential later use + + if self.preserve_file_order: + # chronological upload order + return sorted(paths) + else: + # shuffle the file order + random.shuffle(paths) + return paths + + def __iter__(self): + self._reset() + + if self._running_processes: + raise ValueError( + "Cannot iterate over cache a second time while it is downloading" + ) + + if len(self._s3_paths) > self._initial_file_index: + while len(self._running_processes) < self.n_workers: + p = Process( + target=download_loop, + args=( + self.buffer, + self._file_index, + self._s3_paths[ + self._initial_file_index : + ], # Start from the initial index + self._stop, + self.readable_tensors, + self.writeable_tensors, + self._ongoing_downloads, + self.concurrency, + self.metadata["bytes_per_file"], + self.chunk_size, + self.shuffle, + self.seed, + self.return_ids, + ), + ) + p.start() + self._running_processes.append(p) + time.sleep(0.75) + + return self + + def _next_tensor(self): + try: + idx = self.readable_tensors.get(block=True) + if self.return_ids: + t = { + "states": self.buffer[idx].states.clone().detach(), + "input_ids": self.buffer[idx].input_ids.clone().detach(), + } + else: + t = self.buffer[idx].clone().detach() + + self.writeable_tensors.put(idx, block=True) + return t + except Exception as e: + logger.error(f"exception while iterating: {e}") + self._stop_downloading() + raise StopIteration + + def __next__(self): + while ( + self._file_index.value < len(self._s3_paths) + or not self.readable_tensors.empty() + or self._ongoing_downloads.value > 0 + ): + return self._next_tensor() + + if self._running_processes: + self._stop_downloading() + raise StopIteration + + def finalize(self): + self._stop_downloading() + + def _stop_downloading(self): + logger.info("stopping workers...") + self._file_index.value = len(self._s3_paths) + self._stop.value = True + + while not all([not p.is_alive() for p in self._running_processes]): + if not self.readable_tensors.empty(): + self.readable_tensors.get() + + if not self.writeable_tensors.full(): + self.writeable_tensors.put(0) + + time.sleep(0.25) + + for p in self._running_processes: + p.join() # still join to make sure all resources are cleaned up + 
+ self._ongoing_downloads.value = 0 + self._running_processes = [] + + +""" +tl;dr of why we need this: +shared memory is handled differently for nested structures -- see buffer intiialization +we can initialize a dict with two tensors with shared memory, and these tensors themselves are shared but NOT the dict +hence writing to buffer[idx] in write_tensor will not actually write to self.buffer[idx], which _next_tensor uses +(possibly a better fix, but for now this works) +""" + + +class SharedBuffer(nn.Module): + def __init__(self, shape, input_ids_shape, dtype): + super().__init__() + self.states = nn.Parameter(torch.ones(shape, dtype=dtype), requires_grad=False) + self.input_ids = nn.Parameter( + torch.ones(input_ids_shape, dtype=torch.int64), requires_grad=False + ) + + def forward(self): + return {"states": self.states, "input_ids": self.input_ids} + + +### mini-helper for multiprocessing +def ensure_spawn_context(): + """ + Ensures multiprocessing uses 'spawn' context if not already set. + Returns silently if already set to 'spawn'. + Issues warning if unable to set to 'spawn'. + """ + if mp.get_start_method(allow_none=True) != "spawn": + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + warnings.warn( + "Multiprocessing start method is not 'spawn'. This may cause issues." + ) + + +def create_s3_client( + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, + endpoint_url: Optional[str] = None, +) -> boto3.client: + """Create an S3 client configured for S3-compatible storage services. + + This function creates a boto3 S3 client with optimized settings for reliable + data transfer. It supports both direct credential passing and environment + variable configuration. + + Args: + access_key_id: S3 access key ID. If None, reads from AWS_ACCESS_KEY_ID env var + secret_access_key: S3 secret key. If None, reads from AWS_SECRET_ACCESS_KEY env var + endpoint_url: S3-compatible storage service endpoint URL + + Returns: + boto3.client: Configured S3 client with optimized settings + + Environment Variables: + - AWS_ACCESS_KEY_ID: S3 access key ID (if not provided as argument) + - AWS_SECRET_ACCESS_KEY: S3 secret key (if not provided as argument) + + Example: + ```python + # Using environment variables + s3_client = create_s3_client() + + # Using explicit credentials + s3_client = create_s3_client( + access_key_id="your_key", + secret_access_key="your_secret", + endpoint_url="your_endpoint_url" + ) + ``` + + Note: + The client is configured with path-style addressing and S3v4 signatures + for maximum compatibility with S3-compatible storage services. 
+ """ + access_key_id = access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") + secret_access_key = secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") + endpoint_url = endpoint_url or os.environ.get("S3_ENDPOINT_URL") + + if not access_key_id or not secret_access_key: + raise ValueError( + "S3 credentials must be provided either through arguments or " + "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables" + ) + + if not endpoint_url: + raise ValueError( + "S3 endpoint URL must be provided either through arguments or " + "S3_ENDPOINT_URL environment variable" + ) + + session = boto3.session.Session() + return session.client( + service_name="s3", + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + endpoint_url=endpoint_url, + use_ssl=True, + verify=True, + config=boto3.session.Config( + s3={"addressing_style": "path"}, + signature_version="s3v4", + # Advanced configuration options (currently commented out): + # retries=dict( + # max_attempts=3, # Number of retry attempts + # mode='adaptive' # Adds exponential backoff + # ), + # max_pool_connections=20, # Limits concurrent connections + # connect_timeout=60, # Connection timeout in seconds + # read_timeout=300, # Read timeout in seconds + # tcp_keepalive=True, # Enable TCP keepalive + ), + ) + + +class ActivaultS3ActivationBuffer: + def __init__( + self, + cache: S3RCache, + batch_size: int = 8192, + device: str = "cpu", + io: str = "out", + ): + self.cache = iter(cache) # Make sure it's an iterator + self.batch_size = batch_size + self.device = device + self.io = io + + self.states = None # Shape: [N, D] + self.read_mask = None # Shape: [N] + self.refresh() # Load the first batch + + def __iter__(self): + return self + + def __next__(self): + with torch.no_grad(): + if (~self.read_mask).sum() < self.batch_size: + self.refresh() + + if self.states is None or self.states.shape[0] == 0: + raise StopIteration + + unreads = (~self.read_mask).nonzero().squeeze() + if unreads.ndim == 0: + unreads = unreads.unsqueeze(0) + selected = unreads[ + torch.randperm(len(unreads), device=self.device)[: self.batch_size] + ] + self.read_mask[selected] = True + return self.states[selected] + + def refresh(self): + try: + next_batch = next(self.cache) # dict with "states" key + except StopIteration: + self.states = None + self.read_mask = None + return + + states = next_batch["states"].to(self.device) # [B, L, D] + flat_states = einops.rearrange(states, "b l d -> (b l) d").contiguous() + self.states = flat_states + self.read_mask = torch.zeros( + flat_states.shape[0], dtype=torch.bool, device=self.device + ) + + def close(self): + if hasattr(self.cache, "finalize"): + self.cache.finalize() + elif hasattr(self.cache, "close"): + self.cache.close() + + +if __name__ == "__main__": + device = "cuda" + sae_batch_size = 2048 + io = "out" + + # example activault usage + + BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") + s3_prefix = ["mistral.8b.fineweb/blocks.9.hook_resid_post"] + cache = S3RCache.from_credentials( + aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), + s3_prefix=s3_prefix, + bucket_name=BUCKET_NAME, + device=device, + buffer_size=2, + return_ids=True, + shuffle=True, + n_workers=2, + ) + + s3_buffer = ActivaultS3ActivationBuffer( + cache, batch_size=sae_batch_size, device=device, io=io + )