61 changes: 44 additions & 17 deletions buffer.py
@@ -1,5 +1,8 @@
import torch as t
import zstandard as zstd
import glob
from datetime import datetime
import os
import json
import io
from nnsight import LanguageModel
@@ -13,6 +16,8 @@ def __init__(self,
data, # generator which yields text data
model, # LanguageModel from which to extract activations
submodules, # submodule of the model from which to extract activations
activation_save_dirs=None, # paths to save cached activations, one per submodule; if an individual path is None, do not cache for that submodule
activation_cache_dirs=None, # directories with cached activations to load
in_feats=None,
out_feats=None,
io='out', # can be 'in', 'out', or 'in_to_out'
@@ -22,25 +27,28 @@ def __init__(self,
out_batch_size=8192, # size of batches in which to return activations
device='cpu' # device on which to store the activations
):

if activation_save_dirs is not None and activation_cache_dirs is not None:
raise ValueError("Cannot specify both activation_save_dirs and activation_cache_dirs because we cannot cache while using cached values. Choose one.")
# dictionary of activations
self.activations = {}
for submodule in submodules:
self.activations = [None for _ in submodules]
if activation_cache_dirs is not None:
self.file_iters = [iter(sorted(glob.glob(os.path.join(dir_path, '*.pt')))) for dir_path in activation_cache_dirs]  # sorted so cached batches replay in the order they were written
for i, submodule in enumerate(submodules):
if io == 'in':
if in_feats is None:
try:
in_feats = submodule.in_features
except:
raise ValueError("in_feats cannot be inferred and must be specified directly")
self.activations[submodule] = t.empty(0, in_feats, device=device)
self.activations[i] = t.empty(0, in_feats, device=device)

elif io == 'out':
if out_feats is None:
try:
out_feats = submodule.out_features
except:
raise ValueError("out_feats cannot be inferred and must be specified directly")
self.activations[submodule] = t.empty(0, out_feats, device=device)
self.activations[i] = t.empty(0, out_feats, device=device)
elif io == 'in_to_out':
raise ValueError("Support for in_to_out is depricated")
self.read = t.zeros(0, dtype=t.bool, device=device)
@@ -49,6 +57,8 @@ def __init__(self,
self.data = data
self.model = model # assumes nnsight model is already on the device
self.submodules = submodules
self.activation_save_dirs = activation_save_dirs
self.activation_cache_dirs = activation_cache_dirs
self.io = io
self.n_ctxs = n_ctxs
self.ctx_len = ctx_len
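
For reviewers, a minimal usage sketch of the two new constructor arguments (not part of this diff; the model, submodule, and directory names below are illustrative assumptions, and the remaining arguments follow the existing signature):

# Hypothetical setup: model, submodule, and paths are placeholders, not from this PR.
from nnsight import LanguageModel
from buffer import ActivationBuffer

model = LanguageModel("EleutherAI/pythia-70m", device_map="cuda:0")
submodules = [model.gpt_neox.layers[3].mlp]
data = iter(open("corpus.txt"))  # any generator yielding text

# First run: extract activations from the model and save each yielded batch to disk.
saving_buffer = ActivationBuffer(
    data, model, submodules,
    activation_save_dirs=["activations/layer3_mlp"],  # one entry per submodule; None skips caching for that submodule
    out_feats=512,
)

# Later run: replay the cached batches from disk instead of recomputing them.
replay_buffer = ActivationBuffer(
    data, model, submodules,
    activation_cache_dirs=["activations/layer3_mlp"],
)

# Passing both arguments at once raises the ValueError added in __init__ above.
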
@@ -63,6 +73,18 @@ def __next__(self):
"""
Return a batch of activations
"""
if self.activation_cache_dirs is not None:
batch_activations = []
for file_iter in self.file_iters:
try:
# Load next activation file from the current iterator
file_path = next(file_iter)
activations = t.load(file_path)
batch_activations.append(activations.to(self.device))
except StopIteration:
# No more files to load, end of iteration
raise StopIteration
return batch_activations
# if buffer is less than half full, refresh
if (~self.read).sum() < self.n_ctxs * self.ctx_len // 2:
self.refresh()
@@ -71,9 +93,14 @@ def __next__(self):
unreads = (~self.read).nonzero().squeeze()
idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]]
self.read[idxs] = True
return {
submodule : activations[idxs] for submodule, activations in self.activations.items()
}
batch_activations = [self.activations[i][idxs] for i in range(len(self.activations))]
if self.activation_save_dirs is not None:
for i, (activations_batch, path) in enumerate(zip(batch_activations, self.activation_save_dirs)):
if path is not None:
filename = f"activations_{i}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.pt"
filepath = os.path.join(path, filename)
t.save(activations_batch.cpu(), filepath)
return batch_activations
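
A short sketch of the new return type: iterating the buffer now yields a list of tensors ordered like buffer.submodules rather than a dict keyed by submodule (loop variable names are illustrative):

# Each batch is a list with one activation tensor per submodule,
# in the same order the submodules were passed to the buffer.
for batch in saving_buffer:
    for i, acts in enumerate(batch):
        # acts: (out_batch_size, feats) activations for saving_buffer.submodules[i].
        # With activation_save_dirs set, this batch has also been written to disk
        # as activations_{i}_{timestamp}.pt before being returned.
        print(i, acts.shape)
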

def text_batch(self, batch_size=None):
"""
@@ -102,34 +129,34 @@ def tokenized_batch(self, batch_size=None):
)

def refresh(self):
for submodule, activations in self.activations.items():
self.activations[submodule] = activations[~self.read].contiguous()
for i, activations in enumerate(self.activations):
self.activations[i] = activations[~self.read].contiguous()
self._n_activations = (~self.read).sum().item()

while self._n_activations < self.n_ctxs * self.ctx_len:

with self.model.invoke(self.text_batch(), truncation=True, max_length=self.ctx_len) as invoker:
hidden_states = {}
for submodule in self.submodules:
hidden_states = [None for _ in self.submodules]
for i, submodule in enumerate(self.submodules):
if self.io == 'in':
x = submodule.input
else:
x = submodule.output
if (type(x.shape) == tuple):
x = x[0]
hidden_states[submodule] = x.save()
hidden_states[i] = x.save()

attn_mask = invoker.input['attention_mask']

self._n_activations += (attn_mask != 0).sum().item()

for submodule, activations in self.activations.items():
self.activations[submodule] = t.cat((
for i, activations in enumerate(self.activations):
self.activations[i] = t.cat((
activations,
hidden_states[submodule].value[attn_mask != 0].to(activations.device)),
hidden_states[i].value[attn_mask != 0].to(activations.device)),
dim=0
)
assert len(self.activations[submodule]) == self._n_activations
assert len(self.activations[i]) == self._n_activations

self.read = t.zeros(self._n_activations, dtype=t.bool, device=self.device)
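
For reference, a rough sketch of inspecting the cached files outside the buffer (the directory is an assumed example; the naming matches the pattern written in __next__ above):

import glob
import os
import torch as t

cache_dir = "activations/layer3_mlp"  # hypothetical path
# Each .pt file holds one CPU tensor of shape (out_batch_size, feats),
# saved as activations_{submodule_index}_{timestamp}.pt.
for path in sorted(glob.glob(os.path.join(cache_dir, "*.pt"))):
    acts = t.load(path)
    print(os.path.basename(path), tuple(acts.shape), acts.dtype)
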

60 changes: 30 additions & 30 deletions training.py
@@ -148,17 +148,17 @@ def resample_neurons(deads, activations, ae, optimizer):

def trainSAE(
buffer, # an ActivationBuffer
activation_dims, # dictionary of activation dimensions for each submodule (or a single int)
dictionary_sizes, # dictionary of dictionary sizes for each submodule (or a single int)
lr,
activation_dims, # list of activation dimensions for each submodule (or a single int)
dictionary_sizes, # list of dictionary sizes for each submodule (or a single int)
lrs, # list of learning rates for each submodule (or a single float)
sparsity_penalty,
entropy=False,
steps=None, # if None, train until activations are exhausted
warmup_steps=1000, # linearly increase the learning rate for this many steps
resample_steps=None, # how often to resample dead neurons
ghost_threshold=None, # how many steps a neuron has to be dead for it to turn into a ghost
ghost_thresholds=None, # list of how many steps a neuron has to be dead for it to turn into a ghost (or a single int)
save_steps=None, # how often to save checkpoints
save_dirs=None, # dictionary of directories to save checkpoints to
save_dirs=None, # list of directories to save checkpoints to
checkpoint_offset=0, # if resuming training, the step number of the last checkpoint
load_dirs=None, # if initializing from a pretrained dictionary, directories to load from
log_steps=None, # how often to print statistics
@@ -167,45 +167,45 @@ def trainSAE(
Train and return sparse autoencoders for each submodule in the buffer.
"""
if isinstance(activation_dims, int):
activation_dims = {submodule: activation_dims for submodule in buffer.submodules}
activation_dims = [activation_dims for submodule in buffer.submodules]
if isinstance(dictionary_sizes, int):
dictionary_sizes = {submodule: dictionary_sizes for submodule in buffer.submodules}

aes = {}
num_samples_since_activateds = {}
for submodule in buffer.submodules:
ae = AutoEncoder(activation_dims[submodule], dictionary_sizes[submodule]).to(device)
dictionary_sizes = [dictionary_sizes for submodule in buffer.submodules]
if isinstance(lrs, float):
lrs = [lrs for submodule in buffer.submodules]
if isinstance(ghost_thresholds, int):
ghost_thresholds = [ghost_thresholds for submodule in buffer.submodules]

aes = [None for submodule in buffer.submodules]
num_samples_since_activateds = [None for submodule in buffer.submodules]
for i, submodule in enumerate(buffer.submodules):
ae = AutoEncoder(activation_dims[i], dictionary_sizes[i]).to(device)
if load_dirs is not None:
ae.load_state_dict(t.load(os.path.join(load_dirs[submodule])))
aes[submodule] = ae
num_samples_since_activateds[submodule] = t.zeros(dictionary_sizes[submodule], dtype=int, device=device)
ae.load_state_dict(t.load(os.path.join(load_dirs[i])))
aes[i] = ae
num_samples_since_activateds[i] = t.zeros(dictionary_sizes[i], dtype=int, device=device)

# set up optimizer and scheduler
optimizers = {
submodule: ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lr) for submodule, ae in aes.items()
}
optimizers = [ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[i]) for i, ae in enumerate(aes)]
if resample_steps is None:
def warmup_fn(step):
return min(step / warmup_steps, 1.)
else:
def warmup_fn(step):
return min((step % resample_steps) / warmup_steps, 1.)

schedulers = {
submodule: t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for submodule, optimizer in optimizers.items()
}
schedulers = [t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for optimizer in optimizers]

for step, acts in enumerate(tqdm(buffer, total=steps)):
real_step = step + checkpoint_offset
if steps is not None and real_step >= steps:
break

for submodule, act in acts.items():
for i, act in enumerate(acts):
act = act.to(device)
ae, num_samples_since_activated, optimizer, scheduler \
= aes[submodule], num_samples_since_activateds[submodule], optimizers[submodule], schedulers[submodule]
= aes[i], num_samples_since_activateds[i], optimizers[i], schedulers[i]
optimizer.zero_grad()
loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold)
loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i] if ghost_thresholds is not None else None)
loss.backward()
optimizer.step()
scheduler.step()
@@ -218,8 +218,8 @@ def warmup_fn(step):
# logging
if log_steps is not None and step % log_steps == 0:
with t.no_grad():
losses = sae_loss(act, ae, sparsity_penalty, entropy, separate=True, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold)
if ghost_threshold is None:
losses = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i] if ghost_thresholds is not None else None, separate=True)
if ghost_thresholds is None:
mse_loss, sparsity_loss = losses
print(f"step {step} MSE loss: {mse_loss}, sparsity loss: {sparsity_loss}")
else:
@@ -234,11 +234,11 @@ def warmup_fn(step):

# saving
if save_steps is not None and save_dirs is not None and real_step % save_steps == 0:
if not os.path.exists(os.path.join(save_dirs[submodule], "checkpoints")):
os.mkdir(os.path.join(save_dirs[submodule], "checkpoints"))
if not os.path.exists(os.path.join(save_dirs[i], "checkpoints")):
os.mkdir(os.path.join(save_dirs[i], "checkpoints"))
t.save(
ae.state_dict(),
os.path.join(save_dirs[submodule], "checkpoints", f"ae_{real_step}.pt")
os.path.join(save_dirs[i], "checkpoints", f"ae_{real_step}.pt")
)

return aes
return aes
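
A hypothetical call showing the new list-valued arguments (all numbers and paths are placeholders, not prescribed by this PR; scalars are still accepted and broadcast to every submodule, as shown above):

aes = trainSAE(
    buffer,                          # ActivationBuffer over two submodules
    activation_dims=[512, 512],      # or a single int
    dictionary_sizes=[8192, 8192],   # or a single int
    lrs=[1e-3, 3e-4],                # or a single float
    sparsity_penalty=1e-3,
    ghost_thresholds=[4000, 4000],   # or a single int; default None
    save_steps=1000,
    save_dirs=["ckpts/layer3", "ckpts/layer4"],
    log_steps=100,
)
# trainSAE now returns a list of AutoEncoders, indexed like buffer.submodules.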