Commit 59785e2

change parallel to use list instead of dictionary
This allows one to train multiple autoencoders off a single layer, since activations are no longer keyed by the submodule itself (duplicate submodules would otherwise collide as dictionary keys). This could be useful for something like hyperparameter tuning where you only want to change one thing at a time. Here's an example:

```
submodules = [
    model.gpt_neox.layers[3].mlp,
    model.gpt_neox.layers[3].mlp,
    model.gpt_neox.layers[3].mlp,
]
activation_dim = 512  # output dimension of the MLP
dictionary_size = 16 * activation_dim
learning_rates = [3e-4, 1e-3, 3e-3]
```
1 parent cc70650 commit 59785e2
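To make the new calling convention concrete, here is a minimal sketch (not part of this commit) of how the example above could feed into the updated `trainSAE` below. It assumes an `ActivationBuffer` named `buffer` has already been built over `submodules`; the `sparsity_penalty`, `ghost_thresholds`, and `steps` values are illustrative only.

```
aes = trainSAE(
    buffer,                          # ActivationBuffer yielding a list of activation batches, one per submodule
    activation_dims=activation_dim,  # a single int is broadcast to every submodule
    dictionary_sizes=dictionary_size,
    lrs=learning_rates,              # matched to submodules by position, so repeated submodules are fine
    sparsity_penalty=1e-3,           # illustrative value
    ghost_thresholds=2000,           # illustrative value; a single int is broadcast like the other scalars
    steps=10000,
)
# aes is a list: aes[k] is the autoencoder trained with learning_rates[k]
```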

2 files changed: +40, -46 lines changed

buffer.py

Lines changed: 14 additions & 16 deletions
```
@@ -24,23 +24,23 @@ def __init__(self,
                 ):
 
         # dictionary of activations
-        self.activations = {}
-        for submodule in submodules:
+        self.activations = [None for _ in submodules]
+        for i, submodule in enumerate(submodules):
             if io == 'in':
                 if in_feats is None:
                     try:
                         in_feats = submodule.in_features
                     except:
                         raise ValueError("in_feats cannot be inferred and must be specified directly")
-                self.activations[submodule] = t.empty(0, in_feats, device=device)
+                self.activations[i] = t.empty(0, in_feats, device=device)
 
             elif io == 'out':
                 if out_feats is None:
                     try:
                         out_feats = submodule.out_features
                     except:
                         raise ValueError("out_feats cannot be inferred and must be specified directly")
-                self.activations[submodule] = t.empty(0, out_feats, device=device)
+                self.activations[i] = t.empty(0, out_feats, device=device)
             elif io == 'in_to_out':
                 raise ValueError("Support for in_to_out is depricated")
         self.read = t.zeros(0, dtype=t.bool, device=device)
@@ -71,9 +71,7 @@ def __next__(self):
         unreads = (~self.read).nonzero().squeeze()
         idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]]
         self.read[idxs] = True
-        return {
-            submodule : activations[idxs] for submodule, activations in self.activations.items()
-        }
+        return [self.activations[i][idxs] for i in range(len(self.activations))]
 
     def text_batch(self, batch_size=None):
         """
@@ -102,34 +100,34 @@ def tokenized_batch(self, batch_size=None):
         )
 
     def refresh(self):
-        for submodule, activations in self.activations.items():
-            self.activations[submodule] = activations[~self.read].contiguous()
+        for i, activations in enumerate(self.activations):
+            self.activations[i] = activations[~self.read].contiguous()
         self._n_activations = (~self.read).sum().item()
 
         while self._n_activations < self.n_ctxs * self.ctx_len:
 
             with self.model.invoke(self.text_batch(), truncation=True, max_length=self.ctx_len) as invoker:
-                hidden_states = {}
-                for submodule in self.submodules:
+                hidden_states = [None for _ in self.submodules]
+                for i, submodule in enumerate(self.submodules):
                     if self.io == 'in':
                         x = submodule.input
                     else:
                         x = submodule.output
                     if (type(x.shape) == tuple):
                         x = x[0]
-                    hidden_states[submodule] = x.save()
+                    hidden_states[i] = x.save()
 
             attn_mask = invoker.input['attention_mask']
 
             self._n_activations += (attn_mask != 0).sum().item()
 
-            for submodule, activations in self.activations.items():
-                self.activations[submodule] = t.cat((
+            for i, activations in enumerate(self.activations):
+                self.activations[i] = t.cat((
                     activations,
-                    hidden_states[submodule].value[attn_mask != 0].to(activations.device)),
+                    hidden_states[i].value[attn_mask != 0].to(activations.device)),
                     dim=0
                 )
-                assert len(self.activations[submodule]) == self._n_activations
+                assert len(self.activations[i]) == self._n_activations
 
         self.read = t.zeros(self._n_activations, dtype=t.bool, device=self.device)
```

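Since `__next__` now returns a list rather than a dict, downstream code indexes batches by position. A small sketch of iterating the buffer after this change (assuming the `buffer` and `submodules` names from the example above):

```
for acts in buffer:
    # acts is a list of tensors, one per entry of `submodules`, in the same order
    for i, act in enumerate(acts):
        print(i, act.shape)  # roughly (out_batch_size, activation_dim)
    break  # just inspect a single batch
```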
training.py

Lines changed: 26 additions & 30 deletions
```
@@ -149,17 +149,17 @@ def resample_neurons(deads, activations, ae, optimizer):
 
 def trainSAE(
         buffer, # an ActivationBuffer
-        activation_dims, # dictionary of activation dimensions for each submodule (or a single int)
-        dictionary_sizes, # dictionary of dictionary sizes for each submodule (or a single int)
-        lrs, # dictionary of learning rates for each submodule (or a single float)
+        activation_dims, # list of activation dimensions for each submodule (or a single int)
+        dictionary_sizes, # list of dictionary sizes for each submodule (or a single int)
+        lrs, # list of learning rates for each submodule (or a single float)
         sparsity_penalty,
         entropy=False,
         steps=None, # if None, train until activations are exhausted
         warmup_steps=1000, # linearly increase the learning rate for this many steps
         resample_steps=None, # how often to resample dead neurons
-        ghost_thresholds=None, # dictionary of how many steps a neuron has to be dead for it to turn into a ghost (or a single int)
+        ghost_thresholds=None, # list of how many steps a neuron has to be dead for it to turn into a ghost (or a single int)
         save_steps=None, # how often to save checkpoints
-        save_dirs=None, # dictionary of directories to save checkpoints to
+        save_dirs=None, # list of directories to save checkpoints to
         checkpoint_offset=0, # if resuming training, the step number of the last checkpoint
         load_dirs=None, # if initializing from a pretrained dictionary, directories to load from
         log_steps=None, # how often to print statistics
@@ -168,49 +168,45 @@ def trainSAE(
     Train and return sparse autoencoders for each submodule in the buffer.
     """
     if isinstance(activation_dims, int):
-        activation_dims = {submodule: activation_dims for submodule in buffer.submodules}
+        activation_dims = [activation_dims for submodule in buffer.submodules]
     if isinstance(dictionary_sizes, int):
-        dictionary_sizes = {submodule: dictionary_sizes for submodule in buffer.submodules}
+        dictionary_sizes = [dictionary_sizes for submodule in buffer.submodules]
     if isinstance(lrs, float):
-        lrs = {submodule: lrs for submodule in buffer.submodules}
+        lrs = [lrs for submodule in buffer.submodules]
     if isinstance(ghost_thresholds, int):
-        ghost_thresholds = {submodule: ghost_thresholds for submodule in buffer.submodules}
+        ghost_thresholds = [ghost_thresholds for submodule in buffer.submodules]
 
-    aes = {}
-    num_samples_since_activateds = {}
-    for submodule in buffer.submodules:
-        ae = AutoEncoder(activation_dims[submodule], dictionary_sizes[submodule]).to(device)
+    aes = [None for submodule in buffer.submodules]
+    num_samples_since_activateds = [None for submodule in buffer.submodules]
+    for i, submodule in enumerate(buffer.submodules):
+        ae = AutoEncoder(activation_dims[i], dictionary_sizes[i]).to(device)
         if load_dirs is not None:
-            ae.load_state_dict(t.load(os.path.join(load_dirs[submodule])))
-        aes[submodule] = ae
-        num_samples_since_activateds[submodule] = t.zeros(dictionary_sizes[submodule], dtype=int, device=device)
+            ae.load_state_dict(t.load(os.path.join(load_dirs[i])))
+        aes[i] = ae
+        num_samples_since_activateds[i] = t.zeros(dictionary_sizes[i], dtype=int, device=device)
 
     # set up optimizer and scheduler
-    optimizers = {
-        submodule: ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[submodule]) for submodule, ae in aes.items()
-    }
+    optimizers = [ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[i]) for i, ae in enumerate(aes)]
     if resample_steps is None:
         def warmup_fn(step):
             return min(step / warmup_steps, 1.)
     else:
         def warmup_fn(step):
             return min((step % resample_steps) / warmup_steps, 1.)
 
-    schedulers = {
-        submodule: t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for submodule, optimizer in optimizers.items()
-    }
+    schedulers = [t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for optimizer in optimizers]
 
     for step, acts in enumerate(tqdm(buffer, total=steps)):
         real_step = step + checkpoint_offset
         if steps is not None and real_step >= steps:
             break
 
-        for submodule, act in acts.items():
+        for i, act in enumerate(acts):
             act = act.to(device)
             ae, num_samples_since_activated, optimizer, scheduler \
-                = aes[submodule], num_samples_since_activateds[submodule], optimizers[submodule], schedulers[submodule]
+                = aes[i], num_samples_since_activateds[i], optimizers[i], schedulers[i]
             optimizer.zero_grad()
-            loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[submodule])
+            loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i])
             loss.backward()
             optimizer.step()
             scheduler.step()
@@ -223,8 +219,8 @@ def warmup_fn(step):
             # logging
             if log_steps is not None and step % log_steps == 0:
                 with t.no_grad():
-                    losses = sae_loss(acts, ae, sparsity_penalty, entropy, separate=True, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold)
-                    if ghost_threshold is None:
+                    losses = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i], separate=True)
+                    if ghost_thresholds is None:
                         mse_loss, sparsity_loss = losses
                         print(f"step {step} MSE loss: {mse_loss}, sparsity loss: {sparsity_loss}")
                     else:
@@ -239,11 +235,11 @@ def warmup_fn(step):
 
             # saving
            if save_steps is not None and save_dirs is not None and real_step % save_steps == 0:
-                if not os.path.exists(os.path.join(save_dirs[submodule], "checkpoints")):
-                    os.mkdir(os.path.join(save_dirs[submodule], "checkpoints"))
+                if not os.path.exists(os.path.join(save_dirs[i], "checkpoints")):
+                    os.mkdir(os.path.join(save_dirs[i], "checkpoints"))
                 t.save(
                     ae.state_dict(),
-                    os.path.join(save_dirs[submodule], "checkpoints", f"ae_{real_step}.pt")
+                    os.path.join(save_dirs[i], "checkpoints", f"ae_{real_step}.pt")
                 )
 
     return aes
```
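Checkpointing and resuming follow the same positional convention: `save_dirs` and `load_dirs` are lists aligned with `buffer.submodules`, and checkpoints are written to `save_dirs[i]/checkpoints/ae_{step}.pt`. A hedged sketch of resuming a run (all paths and hyperparameter values below are illustrative, and `buffer` is assumed from the example above):

```
save_dirs = ["run/lr3e-4", "run/lr1e-3", "run/lr3e-3"]  # one directory per autoencoder, by index
load_dirs = [
    "run/lr3e-4/checkpoints/ae_5000.pt",
    "run/lr1e-3/checkpoints/ae_5000.pt",
    "run/lr3e-3/checkpoints/ae_5000.pt",
]

aes = trainSAE(
    buffer,
    activation_dims=512,
    dictionary_sizes=16 * 512,
    lrs=[3e-4, 1e-3, 3e-3],
    sparsity_penalty=1e-3,      # illustrative value
    ghost_thresholds=2000,      # illustrative value
    load_dirs=load_dirs,        # each SAE resumes from its own state dict, matched by index
    checkpoint_offset=5000,     # continue step counting from the loaded checkpoint
    save_steps=1000,
    save_dirs=save_dirs,
)
```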
