@@ -149,17 +149,17 @@ def resample_neurons(deads, activations, ae, optimizer):
 
 def trainSAE(
     buffer, # an ActivationBuffer
-    activation_dims, # dictionary of activation dimensions for each submodule (or a single int)
-    dictionary_sizes, # dictionary of dictionary sizes for each submodule (or a single int)
-    lrs, # dictionary of learning rates for each submodule (or a single float)
+    activation_dims, # list of activation dimensions for each submodule (or a single int)
+    dictionary_sizes, # list of dictionary sizes for each submodule (or a single int)
+    lrs, # list of learning rates for each submodule (or a single float)
     sparsity_penalty,
     entropy=False,
     steps=None, # if None, train until activations are exhausted
     warmup_steps=1000, # linearly increase the learning rate for this many steps
     resample_steps=None, # how often to resample dead neurons
-    ghost_thresholds=None, # dictionary of how many steps a neuron has to be dead for it to turn into a ghost (or a single int)
+    ghost_thresholds=None, # list of how many steps a neuron has to be dead for it to turn into a ghost (or a single int)
     save_steps=None, # how often to save checkpoints
-    save_dirs=None, # dictionary of directories to save checkpoints to
+    save_dirs=None, # list of directories to save checkpoints to
     checkpoint_offset=0, # if resuming training, the step number of the last checkpoint
     load_dirs=None, # if initializing from a pretrained dictionary, directories to load from
     log_steps=None, # how often to print statistics
@@ -168,49 +168,45 @@ def trainSAE(
     Train and return sparse autoencoders for each submodule in the buffer.
     """
     if isinstance(activation_dims, int):
-        activation_dims = {submodule: activation_dims for submodule in buffer.submodules}
+        activation_dims = [activation_dims for submodule in buffer.submodules]
     if isinstance(dictionary_sizes, int):
-        dictionary_sizes = {submodule: dictionary_sizes for submodule in buffer.submodules}
+        dictionary_sizes = [dictionary_sizes for submodule in buffer.submodules]
     if isinstance(lrs, float):
-        lrs = {submodule: lrs for submodule in buffer.submodules}
+        lrs = [lrs for submodule in buffer.submodules]
     if isinstance(ghost_thresholds, int):
-        ghost_thresholds = {submodule: ghost_thresholds for submodule in buffer.submodules}
+        ghost_thresholds = [ghost_thresholds for submodule in buffer.submodules]
 
-    aes = {}
-    num_samples_since_activateds = {}
-    for submodule in buffer.submodules:
-        ae = AutoEncoder(activation_dims[submodule], dictionary_sizes[submodule]).to(device)
+    aes = [None for submodule in buffer.submodules]
+    num_samples_since_activateds = [None for submodule in buffer.submodules]
+    for i, submodule in enumerate(buffer.submodules):
+        ae = AutoEncoder(activation_dims[i], dictionary_sizes[i]).to(device)
         if load_dirs is not None:
-            ae.load_state_dict(t.load(os.path.join(load_dirs[submodule])))
-        aes[submodule] = ae
-        num_samples_since_activateds[submodule] = t.zeros(dictionary_sizes[submodule], dtype=int, device=device)
+            ae.load_state_dict(t.load(os.path.join(load_dirs[i])))
+        aes[i] = ae
+        num_samples_since_activateds[i] = t.zeros(dictionary_sizes[i], dtype=int, device=device)
 
     # set up optimizer and scheduler
-    optimizers = {
-        submodule: ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[submodule]) for submodule, ae in aes.items()
-    }
+    optimizers = [ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[i]) for i, ae in enumerate(aes)]
     if resample_steps is None:
         def warmup_fn(step):
             return min(step / warmup_steps, 1.)
     else:
         def warmup_fn(step):
             return min((step % resample_steps) / warmup_steps, 1.)
 
-    schedulers = {
-        submodule: t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for submodule, optimizer in optimizers.items()
-    }
+    schedulers = [t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for optimizer in optimizers]
 
     for step, acts in enumerate(tqdm(buffer, total=steps)):
         real_step = step + checkpoint_offset
         if steps is not None and real_step >= steps:
             break
 
-        for submodule, act in acts.items():
+        for i, act in enumerate(acts):
             act = act.to(device)
             ae, num_samples_since_activated, optimizer, scheduler \
-                = aes[submodule], num_samples_since_activateds[submodule], optimizers[submodule], schedulers[submodule]
+                = aes[i], num_samples_since_activateds[i], optimizers[i], schedulers[i]
             optimizer.zero_grad()
-            loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[submodule])
+            loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i])
             loss.backward()
             optimizer.step()
             scheduler.step()
@@ -223,8 +219,8 @@ def warmup_fn(step):
             # logging
             if log_steps is not None and step % log_steps == 0:
                 with t.no_grad():
-                    losses = sae_loss(acts, ae, sparsity_penalty, entropy, separate=True, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold)
-                    if ghost_threshold is None:
+                    losses = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i], separate=True)
+                    if ghost_thresholds is None:
                         mse_loss, sparsity_loss = losses
                         print(f"step {step} MSE loss: {mse_loss}, sparsity loss: {sparsity_loss}")
                     else:
@@ -239,11 +235,11 @@ def warmup_fn(step):
 
             # saving
             if save_steps is not None and save_dirs is not None and real_step % save_steps == 0:
-                if not os.path.exists(os.path.join(save_dirs[submodule], "checkpoints")):
-                    os.mkdir(os.path.join(save_dirs[submodule], "checkpoints"))
+                if not os.path.exists(os.path.join(save_dirs[i], "checkpoints")):
+                    os.mkdir(os.path.join(save_dirs[i], "checkpoints"))
                 t.save(
                     ae.state_dict(),
-                    os.path.join(save_dirs[submodule], "checkpoints", f"ae_{real_step}.pt")
+                    os.path.join(save_dirs[i], "checkpoints", f"ae_{real_step}.pt")
                 )
 
     return aes
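
A minimal usage sketch of the list-based interface after this change, not part of the commit: per-submodule hyperparameters are now lists aligned with `buffer.submodules`, and a single scalar is broadcast to every submodule. The concrete dimensions, learning rate, and directories below are hypothetical; only `trainSAE`'s signature comes from the diff above. Note that the `save_dirs` directories are assumed to already exist, since the code only creates the `checkpoints` subdirectory inside each.

```python
# Hypothetical call against the new signature; buffer is an ActivationBuffer
# over two submodules, constructed elsewhere.
aes = trainSAE(
    buffer,
    activation_dims=512,            # a single int is broadcast to every submodule
    dictionary_sizes=[8192, 8192],  # or given explicitly, one entry per submodule
    lrs=1e-3,
    sparsity_penalty=1e-3,
    ghost_thresholds=4000,          # steps a neuron may stay dead before ghost grads apply
    save_steps=10000,
    save_dirs=["./sae_layer0", "./sae_layer1"],
    log_steps=100,
)
ae_layer0, ae_layer1 = aes          # trained AutoEncoders, one per submodule
```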