From f0bb66d1c25bcb7dc8df62d8dbc3bfd47d26b14c Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 27 Dec 2024 05:00:16 +0000 Subject: [PATCH 1/3] Track lr decay implementation --- trainers/standard_lr_decay.py | 216 ++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 trainers/standard_lr_decay.py diff --git a/trainers/standard_lr_decay.py b/trainers/standard_lr_decay.py new file mode 100644 index 0000000..1119ab5 --- /dev/null +++ b/trainers/standard_lr_decay.py @@ -0,0 +1,216 @@ +""" +Implements the standard SAE training scheme. +""" +import torch as t +from typing import Optional + +from ..trainers.trainer import SAETrainer +from ..config import DEBUG +from ..dictionary import AutoEncoder +from collections import namedtuple + +class ConstrainedAdam(t.optim.Adam): + """ + A variant of Adam where some of the parameters are constrained to have unit norm. + """ + def __init__(self, params, constrained_params, lr): + super().__init__(params, lr=lr) + self.constrained_params = list(constrained_params) + + def step(self, closure=None): + with t.no_grad(): + for p in self.constrained_params: + normed_p = p / p.norm(dim=0, keepdim=True) + # project away the parallel component of the gradient + p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p + super().step(closure=closure) + with t.no_grad(): + for p in self.constrained_params: + # renormalize the constrained parameters + p /= p.norm(dim=0, keepdim=True) + +class StandardTrainer(SAETrainer): + """ + Standard SAE training scheme. + """ + def __init__(self, + dict_class=AutoEncoder, + activation_dim:int=512, + dict_size:int=64*512, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + lr_decay_steps_fraction:Optional[float]=0.2, + final_lr_fraction:Optional[float]=0.1, + steps: Optional[int]=None, # total of steps to train for + resample_steps:Optional[int]=None, # how often to resample neurons + seed:Optional[int]=None, + device=None, + layer:Optional[int]=None, + lm_name:Optional[str]=None, + wandb_name:Optional[str]='StandardTrainer', + submodule_name:Optional[str]=None, + ): + """Options: + warump_steps: LR linear warmup period at start of training and after each resample + sparsity_warmup_steps: Sparsity linear warmup period at start of training + lr_decay_steps_fraction: LR linear decay for the last fraction of training""" + super().__init__(seed) + + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + # initialize dictionary + self.ae = dict_class(activation_dim, dict_size) + + self.lr = lr + self.l1_penalty=l1_penalty + self.warmup_steps = warmup_steps + self.wandb_name = wandb_name + + if device is None: + self.device = 'cuda' if t.cuda.is_available() else 'cpu' + else: + self.device = device + self.ae.to(self.device) + + if lr_decay_steps_fraction is not None: + assert steps is not None, "total number of steps must be specified for lr decay" + assert resample_steps is None, "lr decay not implemented for resampling" + assert lr_decay_steps_fraction < 1 and lr_decay_steps_fraction > 0 + assert final_lr_fraction <= 1 and final_lr_fraction >= 0 + + self.steps = steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.lr_decay_steps_fraction = lr_decay_steps_fraction + self.final_lr_fraction = final_lr_fraction + + self.resample_steps = resample_steps + if self.resample_steps is not None: + # how many steps since each neuron was last activated? + self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) + else: + self.steps_since_active = None + + self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + if resample_steps is None: + def warmup_fn(step): + warmup_scale = min(step / warmup_steps, 1.) + + if self.lr_decay_steps_fraction is not None: + cooldown_start = self.steps * (1 - self.lr_decay_steps_fraction) + if step >= cooldown_start: + cooldown = 1.0 + (self.final_lr_fraction - 1.0) * (step - cooldown_start) / (self.steps - cooldown_start) + return max(cooldown, self.final_lr_fraction) + return warmup_scale + else: + def warmup_fn(step): + return min((step % resample_steps) / warmup_steps, 1.) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + + def resample_neurons(self, deads, activations): + with t.no_grad(): + if deads.sum() == 0: return + print(f"resampling {deads.sum().item()} neurons") + + # compute loss for each activation + losses = (activations - self.ae(activations)).norm(dim=-1) + + # sample input to create encoder/decoder weights from + n_resample = min([deads.sum(), losses.shape[0]]) + indices = t.multinomial(losses, num_samples=n_resample, replacement=False) + sampled_vecs = activations[indices] + + # get norm of the living neurons + alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() + + # resample first n_resample dead neurons + deads[deads.nonzero()[n_resample:]] = False + self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 + self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T + self.ae.encoder.bias[deads] = 0. + + + # reset Adam parameters for dead neurons + state_dict = self.optimizer.state_dict()['state'] + ## encoder weight + state_dict[1]['exp_avg'][deads] = 0. + state_dict[1]['exp_avg_sq'][deads] = 0. + ## encoder bias + state_dict[2]['exp_avg'][deads] = 0. + state_dict[2]['exp_avg_sq'][deads] = 0. + ## decoder weight + state_dict[3]['exp_avg'][:,deads] = 0. + state_dict[3]['exp_avg_sq'][:,deads] = 0. + + def loss(self, x, step: int, logging=False, **kwargs): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.) + else: + sparsity_scale = 1. + + x_hat, f = self.ae(x, output_features=True) + l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() + l1_loss = f.norm(p=1, dim=-1).mean() + + if self.steps_since_active is not None: + # update steps_since_active + deads = (f == 0).all(dim=0) + self.steps_since_active[deads] += 1 + self.steps_since_active[~deads] = 0 + + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss + + if not logging: + return loss + else: + return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( + x, x_hat, f, + { + 'l2_loss' : l2_loss.item(), + 'mse_loss' : recon_loss.item(), + 'sparsity_loss' : l1_loss.item(), + 'loss' : loss.item() + } + ) + + + def update(self, step, activations): + activations = activations.to(self.device) + + self.optimizer.zero_grad() + loss = self.loss(activations, step=step) + loss.backward() + self.optimizer.step() + self.scheduler.step() + + if self.resample_steps is not None and step % self.resample_steps == 0: + self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + + @property + def config(self): + return { + 'dict_class': 'AutoEncoder', + 'trainer_class' : 'StandardTrainer', + 'activation_dim': self.ae.activation_dim, + 'dict_size': self.ae.dict_size, + 'lr' : self.lr, + 'l1_penalty' : self.l1_penalty, + 'warmup_steps' : self.warmup_steps, + 'resample_steps' : self.resample_steps, + 'device' : self.device, + 'layer' : self.layer, + 'lm_name' : self.lm_name, + 'wandb_name': self.wandb_name, + 'submodule_name': self.submodule_name, + } + From e0db40b8fadcdd1e24c1945829ecd4eb57451fa8 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 27 Dec 2024 05:00:34 +0000 Subject: [PATCH 2/3] Clean up lr decay --- trainers/standard_lr_decay.py | 216 ---------------------------------- 1 file changed, 216 deletions(-) delete mode 100644 trainers/standard_lr_decay.py diff --git a/trainers/standard_lr_decay.py b/trainers/standard_lr_decay.py deleted file mode 100644 index 1119ab5..0000000 --- a/trainers/standard_lr_decay.py +++ /dev/null @@ -1,216 +0,0 @@ -""" -Implements the standard SAE training scheme. -""" -import torch as t -from typing import Optional - -from ..trainers.trainer import SAETrainer -from ..config import DEBUG -from ..dictionary import AutoEncoder -from collections import namedtuple - -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - -class StandardTrainer(SAETrainer): - """ - Standard SAE training scheme. - """ - def __init__(self, - dict_class=AutoEncoder, - activation_dim:int=512, - dict_size:int=64*512, - lr:float=1e-3, - l1_penalty:float=1e-1, - warmup_steps:int=1000, # lr warmup period at start of training and after each resample - sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training - lr_decay_steps_fraction:Optional[float]=0.2, - final_lr_fraction:Optional[float]=0.1, - steps: Optional[int]=None, # total of steps to train for - resample_steps:Optional[int]=None, # how often to resample neurons - seed:Optional[int]=None, - device=None, - layer:Optional[int]=None, - lm_name:Optional[str]=None, - wandb_name:Optional[str]='StandardTrainer', - submodule_name:Optional[str]=None, - ): - """Options: - warump_steps: LR linear warmup period at start of training and after each resample - sparsity_warmup_steps: Sparsity linear warmup period at start of training - lr_decay_steps_fraction: LR linear decay for the last fraction of training""" - super().__init__(seed) - - assert layer is not None and lm_name is not None - self.layer = layer - self.lm_name = lm_name - self.submodule_name = submodule_name - - if seed is not None: - t.manual_seed(seed) - t.cuda.manual_seed_all(seed) - - # initialize dictionary - self.ae = dict_class(activation_dim, dict_size) - - self.lr = lr - self.l1_penalty=l1_penalty - self.warmup_steps = warmup_steps - self.wandb_name = wandb_name - - if device is None: - self.device = 'cuda' if t.cuda.is_available() else 'cpu' - else: - self.device = device - self.ae.to(self.device) - - if lr_decay_steps_fraction is not None: - assert steps is not None, "total number of steps must be specified for lr decay" - assert resample_steps is None, "lr decay not implemented for resampling" - assert lr_decay_steps_fraction < 1 and lr_decay_steps_fraction > 0 - assert final_lr_fraction <= 1 and final_lr_fraction >= 0 - - self.steps = steps - self.sparsity_warmup_steps = sparsity_warmup_steps - self.lr_decay_steps_fraction = lr_decay_steps_fraction - self.final_lr_fraction = final_lr_fraction - - self.resample_steps = resample_steps - if self.resample_steps is not None: - # how many steps since each neuron was last activated? - self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) - else: - self.steps_since_active = None - - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if resample_steps is None: - def warmup_fn(step): - warmup_scale = min(step / warmup_steps, 1.) - - if self.lr_decay_steps_fraction is not None: - cooldown_start = self.steps * (1 - self.lr_decay_steps_fraction) - if step >= cooldown_start: - cooldown = 1.0 + (self.final_lr_fraction - 1.0) * (step - cooldown_start) / (self.steps - cooldown_start) - return max(cooldown, self.final_lr_fraction) - return warmup_scale - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) - - def resample_neurons(self, deads, activations): - with t.no_grad(): - if deads.sum() == 0: return - print(f"resampling {deads.sum().item()} neurons") - - # compute loss for each activation - losses = (activations - self.ae(activations)).norm(dim=-1) - - # sample input to create encoder/decoder weights from - n_resample = min([deads.sum(), losses.shape[0]]) - indices = t.multinomial(losses, num_samples=n_resample, replacement=False) - sampled_vecs = activations[indices] - - # get norm of the living neurons - alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() - - # resample first n_resample dead neurons - deads[deads.nonzero()[n_resample:]] = False - self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 - self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T - self.ae.encoder.bias[deads] = 0. - - - # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] - ## encoder weight - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. - ## encoder bias - state_dict[2]['exp_avg'][deads] = 0. - state_dict[2]['exp_avg_sq'][deads] = 0. - ## decoder weight - state_dict[3]['exp_avg'][:,deads] = 0. - state_dict[3]['exp_avg_sq'][:,deads] = 0. - - def loss(self, x, step: int, logging=False, **kwargs): - - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.) - else: - sparsity_scale = 1. - - x_hat, f = self.ae(x, output_features=True) - l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() - recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() - l1_loss = f.norm(p=1, dim=-1).mean() - - if self.steps_since_active is not None: - # update steps_since_active - deads = (f == 0).all(dim=0) - self.steps_since_active[deads] += 1 - self.steps_since_active[~deads] = 0 - - loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss - - if not logging: - return loss - else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, - { - 'l2_loss' : l2_loss.item(), - 'mse_loss' : recon_loss.item(), - 'sparsity_loss' : l1_loss.item(), - 'loss' : loss.item() - } - ) - - - def update(self, step, activations): - activations = activations.to(self.device) - - self.optimizer.zero_grad() - loss = self.loss(activations, step=step) - loss.backward() - self.optimizer.step() - self.scheduler.step() - - if self.resample_steps is not None and step % self.resample_steps == 0: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) - - @property - def config(self): - return { - 'dict_class': 'AutoEncoder', - 'trainer_class' : 'StandardTrainer', - 'activation_dim': self.ae.activation_dim, - 'dict_size': self.ae.dict_size, - 'lr' : self.lr, - 'l1_penalty' : self.l1_penalty, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'device' : self.device, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name': self.wandb_name, - 'submodule_name': self.submodule_name, - } - From 911b95890e20998df92710a01d158f4663d6834b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 27 Dec 2024 05:01:39 +0000 Subject: [PATCH 3/3] Add sparsity warmup for trainers with a sparsity penalty --- tests/test_end_to_end.py | 1 + trainers/gdm.py | 43 +++++++++++++++++++++++++--------------- trainers/jumprelu.py | 40 +++++++++++++++++++++++-------------- trainers/standard.py | 43 +++++++++++++++++++++++++--------------- 4 files changed, 80 insertions(+), 47 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 31cb314..8aa6cfc 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -149,6 +149,7 @@ def test_sae_training(): "lr": learning_rate, "l1_penalty": sparsity_penalty, "warmup_steps": warmup_steps, + "sparsity_warmup_steps": None, "resample_steps": resample_steps, "seed": RANDOM_SEED, "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", diff --git a/trainers/gdm.py b/trainers/gdm.py index 47ea772..792e64e 100644 --- a/trainers/gdm.py +++ b/trainers/gdm.py @@ -3,6 +3,8 @@ """ import torch as t +from typing import Optional + from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import GatedAutoEncoder @@ -33,19 +35,19 @@ class GatedSAETrainer(SAETrainer): Gated SAE training scheme. """ def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=5e-5, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='GatedSAETrainer', - submodule_name=None, + dict_class = GatedAutoEncoder, + activation_dim: int = 512, + dict_size: int = 64*512, + lr: float = 5e-5, + l1_penalty: float = 1e-1, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: int = 2000, + seed: Optional[int] = None, + device: Optional[str] = None, + layer: Optional[int] = None, + lm_name: Optional[str] = None, + wandb_name: Optional[str] = 'GatedSAETrainer', + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -64,6 +66,7 @@ def __init__(self, self.lr = lr self.l1_penalty=l1_penalty self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps self.wandb_name = wandb_name if device is None: @@ -81,7 +84,13 @@ def warmup_fn(step): return min(1, step / warmup_steps) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) - def loss(self, x, logging=False, **kwargs): + def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() @@ -90,7 +99,7 @@ def loss(self, x, logging=False, **kwargs): L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean() L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() - loss = L_recon + self.l1_penalty * L_sparse + L_aux + loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux if not logging: return loss @@ -108,7 +117,7 @@ def loss(self, x, logging=False, **kwargs): def update(self, step, x): x = x.to(self.device) self.optimizer.zero_grad() - loss = self.loss(x) + loss = self.loss(x, step) loss.backward() self.optimizer.step() self.scheduler.step() @@ -123,6 +132,8 @@ def config(self): 'lr' : self.lr, 'l1_penalty' : self.l1_penalty, 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, 'lm_name' : self.lm_name, diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index e27e1b3..586313e 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -3,6 +3,7 @@ import torch import torch.autograd as autograd from torch import nn +from typing import Optional from ..dictionary import Dictionary, JumpReluAutoEncoder from .trainer import SAETrainer @@ -69,21 +70,22 @@ class JumpReluTrainer(nn.Module, SAETrainer): def __init__( self, dict_class=JumpReluAutoEncoder, - activation_dim=512, - dict_size=8192, - steps=30000, + activation_dim: int = 512, + dict_size: int = 8192, + steps: int = 30000, # XXX: Training decay is not implemented - seed=None, + seed: Optional[int] = None, # TODO: What's the default lr use in the paper? - lr=7e-5, - bandwidth=0.001, - sparsity_penalty=1.0, - target_l0=20.0, - device="cpu", - layer=None, - lm_name=None, - wandb_name="JumpRelu", - submodule_name=None, + lr: float = 7e-5, + bandwidth: float = 0.001, + sparsity_penalty: float = 1.0, + sparsity_warmup_steps: int = 2000, + target_l0: float = 20.0, + device: str = "cpu", + layer: Optional[int] = None, + lm_name: Optional[str] = None, + wandb_name: str = "JumpRelu", + submodule_name: Optional[str] = None, ): super().__init__() @@ -100,6 +102,7 @@ def __init__( self.bandwidth = bandwidth self.sparsity_coefficient = sparsity_penalty + self.sparsity_warmup_steps = sparsity_warmup_steps self.target_l0 = target_l0 # TODO: Better auto-naming (e.g. in BatchTopK package) @@ -119,14 +122,20 @@ def __init__( self.logging_parameters = [] - def loss(self, x, logging=False, **_): + def loss(self, x: torch.Tensor, step: int, logging=False, **_): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + f = self.ae.encode(x) recon = self.ae.decode(f) recon_loss = (x - recon).pow(2).sum(dim=-1).mean() l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() - sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) + sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale loss = recon_loss + sparsity_loss if not logging: @@ -170,5 +179,6 @@ def config(self): "submodule_name": self.submodule_name, "bandwidth": self.bandwidth, "sparsity_penalty": self.sparsity_coefficient, + "sparsity_warmup_steps": self.sparsity_warmup_steps, "target_l0": self.target_l0, } diff --git a/trainers/standard.py b/trainers/standard.py index 506a5c0..8b5157f 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -2,6 +2,8 @@ Implements the standard SAE training scheme. """ import torch as t +from typing import Optional + from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import AutoEncoder @@ -33,18 +35,19 @@ class StandardTrainer(SAETrainer): """ def __init__(self, dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, + activation_dim:int=512, + dict_size:int=64*512, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + resample_steps:Optional[int]=None, # how often to resample neurons + seed:Optional[int]=None, device=None, - layer=None, - lm_name=None, - wandb_name='StandardTrainer', - submodule_name=None, + layer:Optional[int]=None, + lm_name:Optional[str]=None, + wandb_name:Optional[str]='StandardTrainer', + submodule_name:Optional[str]=None, ): super().__init__(seed) @@ -70,10 +73,10 @@ def __init__(self, else: self.device = device self.ae.to(self.device) + + self.sparsity_warmup_steps = sparsity_warmup_steps self.resample_steps = resample_steps - - if self.resample_steps is not None: # how many steps since each neuron was last activated? self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) @@ -124,7 +127,13 @@ def resample_neurons(self, deads, activations): state_dict[3]['exp_avg'][:,deads] = 0. state_dict[3]['exp_avg_sq'][:,deads] = 0. - def loss(self, x, logging=False, **kwargs): + def loss(self, x, step: int, logging=False, **kwargs): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + x_hat, f = self.ae(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() @@ -136,7 +145,7 @@ def loss(self, x, logging=False, **kwargs): self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - loss = recon_loss + self.l1_penalty * l1_loss + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss if not logging: return loss @@ -156,7 +165,7 @@ def update(self, step, activations): activations = activations.to(self.device) self.optimizer.zero_grad() - loss = self.loss(activations) + loss = self.loss(activations, step=step) loss.backward() self.optimizer.step() self.scheduler.step() @@ -175,6 +184,8 @@ def config(self): 'l1_penalty' : self.l1_penalty, 'warmup_steps' : self.warmup_steps, 'resample_steps' : self.resample_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, 'lm_name' : self.lm_name,