From 7ec4227333cb9644a6fd8657beaca66003c3c989 Mon Sep 17 00:00:00 2001 From: Jacob G-W Date: Tue, 27 Feb 2024 19:24:47 -0500 Subject: [PATCH 1/3] add ability to vary more hyperparameters in parallel training --- training.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/training.py b/training.py index 3b81b5b..7632612 100644 --- a/training.py +++ b/training.py @@ -150,13 +150,13 @@ def trainSAE( buffer, # an ActivationBuffer activation_dims, # dictionary of activation dimensions for each submodule (or a single int) dictionary_sizes, # dictionary of dictionary sizes for each submodule (or a single int) - lr, + lrs, # dictionary of learning rates for each submodule (or a single float) sparsity_penalty, entropy=False, steps=None, # if None, train until activations are exhausted warmup_steps=1000, # linearly increase the learning rate for this many steps resample_steps=None, # how often to resample dead neurons - ghost_threshold=None, # how many steps a neuron has to be dead for it to turn into a ghost + ghost_thresholds=None, # dictionary of how many steps a neuron has to be dead for it to turn into a ghost (or a single int) save_steps=None, # how often to save checkpoints save_dirs=None, # dictionary of directories to save checkpoints to checkpoint_offset=0, # if resuming training, the step number of the last checkpoint @@ -170,6 +170,10 @@ def trainSAE( activation_dims = {submodule: activation_dims for submodule in buffer.submodules} if isinstance(dictionary_sizes, int): dictionary_sizes = {submodule: dictionary_sizes for submodule in buffer.submodules} + if isinstance(lrs, float): + lrs = {submodule: lrs for submodule in buffer.submodules} + if isinstance(ghost_thresholds, int): + ghost_thresholds = {submodule: ghost_thresholds for submodule in buffer.submodules} aes = {} num_samples_since_activateds = {} @@ -182,7 +186,7 @@ def trainSAE( # set up optimizer and scheduler optimizers = { - submodule: ConstrainedAdam(ae.parameters(), 
ae.decoder.parameters(), lr=lr) for submodule, ae in aes.items() + submodule: ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[submodule]) for submodule, ae in aes.items() } if resample_steps is None: def warmup_fn(step): return min(step / warmup_steps, 1.) @@ -205,7 +209,7 @@ def warmup_fn(step): ae, num_samples_since_activated, optimizer, scheduler \ = aes[submodule], num_samples_since_activateds[submodule], optimizers[submodule], schedulers[submodule] optimizer.zero_grad() - loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold) + loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[submodule]) loss.backward() optimizer.step() scheduler.step() From a23173f55c4516012a0bfe2ecfb4b8731954a5ed Mon Sep 17 00:00:00 2001 From: Jacob G-W Date: Tue, 27 Feb 2024 20:57:20 -0500 Subject: [PATCH 2/3] change parallel to use list instead of dictionary This allows one to train multiple autoencoders off of a single layer since we aren't indexing a dictionary off the same thing. This could be useful for something like hyperparameter tuning where you only want to change one thing at a time. 
Here's an example: ``` submodules = [model.gpt_neox.layers[3].mlp, model.gpt_neox.layers[3].mlp, model.gpt_neox.layers[3].mlp] activation_dim = 512 # output dimension of the MLP dictionary_size = 16 * activation_dim learning_rates = [3e-4, 1e-3, 3e-3] ``` --- buffer.py | 30 +++++++++++++-------------- training.py | 58 +++++++++++++++++++++++++---------------------------- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/buffer.py b/buffer.py index 6b81f75..f643568 100644 --- a/buffer.py +++ b/buffer.py @@ -24,15 +24,15 @@ def __init__(self, ): # dictionary of activations - self.activations = {} - for submodule in submodules: + self.activations = [None for _ in submodules] + for i, submodule in enumerate(submodules): if io == 'in': if in_feats is None: try: in_feats = submodule.in_features except: raise ValueError("in_feats cannot be inferred and must be specified directly") - self.activations[submodule] = t.empty(0, in_feats, device=device) + self.activations[i] = t.empty(0, in_feats, device=device) elif io == 'out': if out_feats is None: @@ -40,7 +40,7 @@ def __init__(self, out_feats = submodule.out_features except: raise ValueError("out_feats cannot be inferred and must be specified directly") - self.activations[submodule] = t.empty(0, out_feats, device=device) + self.activations[i] = t.empty(0, out_feats, device=device) elif io == 'in_to_out': raise ValueError("Support for in_to_out is depricated") self.read = t.zeros(0, dtype=t.bool, device=device) @@ -71,9 +71,7 @@ def __next__(self): unreads = (~self.read).nonzero().squeeze() idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]] self.read[idxs] = True - return { - submodule : activations[idxs] for submodule, activations in self.activations.items() - } + return [self.activations[i][idxs] for i in range(len(self.activations))] def text_batch(self, batch_size=None): """ @@ -102,34 +100,34 @@ def tokenized_batch(self, batch_size=None): ) def refresh(self): - for 
submodule, activations in self.activations.items(): - self.activations[submodule] = activations[~self.read].contiguous() + for i, activations in enumerate(self.activations): + self.activations[i] = activations[~self.read].contiguous() self._n_activations = (~self.read).sum().item() while self._n_activations < self.n_ctxs * self.ctx_len: with self.model.invoke(self.text_batch(), truncation=True, max_length=self.ctx_len) as invoker: - hidden_states = {} - for submodule in self.submodules: + hidden_states = [None for _ in self.submodules] + for i, submodule in enumerate(self.submodules): if self.io == 'in': x = submodule.input else: x = submodule.output if (type(x.shape) == tuple): x = x[0] - hidden_states[submodule] = x.save() + hidden_states[i] = x.save() attn_mask = invoker.input['attention_mask'] self._n_activations += (attn_mask != 0).sum().item() - for submodule, activations in self.activations.items(): - self.activations[submodule] = t.cat(( + for i, activations in enumerate(self.activations): + self.activations[i] = t.cat(( activations, - hidden_states[submodule].value[attn_mask != 0].to(activations.device)), + hidden_states[i].value[attn_mask != 0].to(activations.device)), dim=0 ) - assert len(self.activations[submodule]) == self._n_activations + assert len(self.activations[i]) == self._n_activations self.read = t.zeros(self._n_activations, dtype=t.bool, device=self.device) diff --git a/training.py b/training.py index 7632612..74fc68a 100644 --- a/training.py +++ b/training.py @@ -148,17 +148,17 @@ def resample_neurons(deads, activations, ae, optimizer): def trainSAE( buffer, # an ActivationBuffer - activation_dims, # dictionary of activation dimensions for each submodule (or a single int) - dictionary_sizes, # dictionary of dictionary sizes for each submodule (or a single int) - lrs, # dictionary of learning rates for each submodule (or a single float) + activation_dims, # list of activation dimensions for each submodule (or a single int) + dictionary_sizes, 
# list of dictionary sizes for each submodule (or a single int) + lrs, # list of learning rates for each submodule (or a single float) sparsity_penalty, entropy=False, steps=None, # if None, train until activations are exhausted warmup_steps=1000, # linearly increase the learning rate for this many steps resample_steps=None, # how often to resample dead neurons - ghost_thresholds=None, # dictionary of how many steps a neuron has to be dead for it to turn into a ghost (or a single int) + ghost_thresholds=None, # list of how many steps a neuron has to be dead for it to turn into a ghost (or a single int) save_steps=None, # how often to save checkpoints - save_dirs=None, # dictionary of directories to save checkpoints to + save_dirs=None, # list of directories to save checkpoints to checkpoint_offset=0, # if resuming training, the step number of the last checkpoint load_dirs=None, # if initializing from a pretrained dictionary, directories to load from log_steps=None, # how often to print statistics @@ -167,27 +167,25 @@ def trainSAE( Train and return sparse autoencoders for each submodule in the buffer. 
""" if isinstance(activation_dims, int): - activation_dims = {submodule: activation_dims for submodule in buffer.submodules} + activation_dims = [activation_dims for submodule in buffer.submodules] if isinstance(dictionary_sizes, int): - dictionary_sizes = {submodule: dictionary_sizes for submodule in buffer.submodules} + dictionary_sizes = [dictionary_sizes for submodule in buffer.submodules] if isinstance(lrs, float): - lrs = {submodule: lrs for submodule in buffer.submodules} + lrs = [lrs for submodule in buffer.submodules] if isinstance(ghost_thresholds, int): - ghost_thresholds = {submodule: ghost_thresholds for submodule in buffer.submodules} + ghost_thresholds = [ghost_thresholds for submodule in buffer.submodules] - aes = {} - num_samples_since_activateds = {} - for submodule in buffer.submodules: - ae = AutoEncoder(activation_dims[submodule], dictionary_sizes[submodule]).to(device) + aes = [None for submodule in buffer.submodules] + num_samples_since_activateds = [None for submodule in buffer.submodules] + for i, submodule in enumerate(buffer.submodules): + ae = AutoEncoder(activation_dims[i], dictionary_sizes[i]).to(device) if load_dirs is not None: - ae.load_state_dict(t.load(os.path.join(load_dirs[submodule]))) - aes[submodule] = ae - num_samples_since_activateds[submodule] = t.zeros(dictionary_sizes[submodule], dtype=int, device=device) + ae.load_state_dict(t.load(os.path.join(load_dirs[i]))) + aes[i] = ae + num_samples_since_activateds[i] = t.zeros(dictionary_sizes[i], dtype=int, device=device) # set up optimizer and scheduler - optimizers = { - submodule: ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[submodule]) for submodule, ae in aes.items() - } + optimizers = [ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[i]) for i, ae in enumerate(aes)] if resample_steps is None: def warmup_fn(step): return min(step / warmup_steps, 1.) 
@@ -195,21 +193,19 @@ def warmup_fn(step): def warmup_fn(step): return min((step % resample_steps) / warmup_steps, 1.) - schedulers = { - submodule: t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for submodule, optimizer in optimizers.items() - } + schedulers = [t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for optimizer in optimizers] for step, acts in enumerate(tqdm(buffer, total=steps)): real_step = step + checkpoint_offset if steps is not None and real_step >= steps: break - for submodule, act in acts.items(): + for i, act in enumerate(acts): act = act.to(device) ae, num_samples_since_activated, optimizer, scheduler \ - = aes[submodule], num_samples_since_activateds[submodule], optimizers[submodule], schedulers[submodule] + = aes[i], num_samples_since_activateds[i], optimizers[i], schedulers[i] optimizer.zero_grad() - loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[submodule]) + loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i]) loss.backward() optimizer.step() scheduler.step() @@ -222,8 +218,8 @@ def warmup_fn(step): # logging if log_steps is not None and step % log_steps == 0: with t.no_grad(): - losses = sae_loss(act, ae, sparsity_penalty, entropy, separate=True, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold) - if ghost_threshold is None: + losses = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i], separate=True) + if ghost_thresholds is None: mse_loss, sparsity_loss = losses print(f"step {step} MSE loss: {mse_loss}, sparsity loss: {sparsity_loss}") else: @@ -238,11 +234,11 @@ def warmup_fn(step): # saving if save_steps is not None and save_dirs is not None and real_step % 
save_steps == 0: - if not os.path.exists(os.path.join(save_dirs[submodule], "checkpoints")): - os.mkdir(os.path.join(save_dirs[submodule], "checkpoints")) + if not os.path.exists(os.path.join(save_dirs[i], "checkpoints")): + os.mkdir(os.path.join(save_dirs[i], "checkpoints")) t.save( ae.state_dict(), - os.path.join(save_dirs[submodule], "checkpoints", f"ae_{real_step}.pt") + os.path.join(save_dirs[i], "checkpoints", f"ae_{real_step}.pt") ) - return aes \ No newline at end of file + return aes From 72e23be638b3deb9626b7db0791bc8a5810a88fa Mon Sep 17 00:00:00 2001 From: Jacob G-W Date: Thu, 29 Feb 2024 19:39:09 +0000 Subject: [PATCH 3/3] add way to cache activations from layer This allows one to re-train a sparse autoencoder on the same layer without re-generating all of the activations to train on. --- buffer.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/buffer.py b/buffer.py index f643568..b889004 100644 --- a/buffer.py +++ b/buffer.py @@ -1,5 +1,8 @@ import torch as t import zstandard as zstd +import glob +from datetime import datetime +import os import json import io from nnsight import LanguageModel @@ -13,6 +16,8 @@ def __init__(self, data, # generator which yields text data model, # LanguageModel from which to extract activations submodules, # submodule of the model from which to extract activations + activation_save_dirs=None, # paths to save cached activations, one per submodule; if an individual path is None, do not cache for that submodule + activation_cache_dirs=None, # directories with cached activations to load in_feats=None, out_feats=None, io='out', # can be 'in', 'out', or 'in_to_out' @@ -22,9 +27,12 @@ def __init__(self, out_batch_size=8192, # size of batches in which to return activations device='cpu' # device on which to store the activations ): - + if activation_save_dirs is not None and activation_cache_dirs is not None: + raise ValueError("Cannot specify both activation_save_dirs and 
activation_cache_dirs because we cannot cache while using cached values. Choose one.") # dictionary of activations self.activations = [None for _ in submodules] + if activation_cache_dirs is not None: + self.file_iters = [iter(glob.glob(os.path.join(dir_path, '*.pt'))) for dir_path in (activation_cache_dirs)] for i, submodule in enumerate(submodules): if io == 'in': if in_feats is None: @@ -49,6 +57,8 @@ def __init__(self, self.data = data self.model = model # assumes nnsight model is already on the device self.submodules = submodules + self.activation_save_dirs = activation_save_dirs + self.activation_cache_dirs = activation_cache_dirs self.io = io self.n_ctxs = n_ctxs self.ctx_len = ctx_len @@ -63,6 +73,18 @@ def __next__(self): """ Return a batch of activations """ + if self.activation_cache_dirs is not None: + batch_activations = [] + for file_iter in self.file_iters: + try: + # Load next activation file from the current iterator + file_path = next(file_iter) + activations = t.load(file_path) + batch_activations.append(activations.to(self.device)) + except StopIteration: + # No more files to load, end of iteration + raise StopIteration + return batch_activations # if buffer is less than half full, refresh if (~self.read).sum() < self.n_ctxs * self.ctx_len // 2: self.refresh() @@ -71,7 +93,14 @@ def __next__(self): unreads = (~self.read).nonzero().squeeze() idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]] self.read[idxs] = True - return [self.activations[i][idxs] for i in range(len(self.activations))] + batch_activations = [self.activations[i][idxs] for i in range(len(self.activations))] + if self.activation_save_dirs is not None: + for i, (activations_batch, path) in enumerate(zip(batch_activations, self.activation_save_dirs)): + if path is not None: + filename = f"activations_{i}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.pt" + filepath = os.path.join(path, filename) + t.save(activations_batch.cpu(), filepath) + return 
batch_activations def text_batch(self, batch_size=None): """