Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_sae_training():
"lr": learning_rate,
"l1_penalty": sparsity_penalty,
"warmup_steps": warmup_steps,
"sparsity_warmup_steps": None,
"resample_steps": resample_steps,
"seed": RANDOM_SEED,
"wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}",
Expand Down
43 changes: 27 additions & 16 deletions trainers/gdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""

import torch as t
from typing import Optional

from ..trainers.trainer import SAETrainer
from ..config import DEBUG
from ..dictionary import GatedAutoEncoder
Expand Down Expand Up @@ -33,19 +35,19 @@ class GatedSAETrainer(SAETrainer):
Gated SAE training scheme.
"""
def __init__(self,
dict_class=GatedAutoEncoder,
activation_dim=512,
dict_size=64*512,
lr=5e-5,
l1_penalty=1e-1,
warmup_steps=1000, # lr warmup period at start of training and after each resample
resample_steps=None, # how often to resample neurons
seed=None,
device=None,
layer=None,
lm_name=None,
wandb_name='GatedSAETrainer',
submodule_name=None,
dict_class = GatedAutoEncoder,
activation_dim: int = 512,
dict_size: int = 64*512,
lr: float = 5e-5,
l1_penalty: float = 1e-1,
warmup_steps: int = 1000, # lr warmup period at start of training and after each resample
sparsity_warmup_steps: int = 2000,
seed: Optional[int] = None,
device: Optional[str] = None,
layer: Optional[int] = None,
lm_name: Optional[str] = None,
wandb_name: Optional[str] = 'GatedSAETrainer',
submodule_name: Optional[str] = None,
):
super().__init__(seed)

Expand All @@ -64,6 +66,7 @@ def __init__(self,
self.lr = lr
self.l1_penalty=l1_penalty
self.warmup_steps = warmup_steps
self.sparsity_warmup_steps = sparsity_warmup_steps
self.wandb_name = wandb_name

if device is None:
Expand All @@ -81,7 +84,13 @@ def warmup_fn(step):
return min(1, step / warmup_steps)
self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn)

def loss(self, x, logging=False, **kwargs):
def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs):

if self.sparsity_warmup_steps is not None:
sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0)
else:
sparsity_scale = 1.0

f, f_gate = self.ae.encode(x, return_gate=True)
x_hat = self.ae.decode(f)
x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach()
Expand All @@ -90,7 +99,7 @@ def loss(self, x, logging=False, **kwargs):
L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean()
L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean()

loss = L_recon + self.l1_penalty * L_sparse + L_aux
loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux

if not logging:
return loss
Expand All @@ -108,7 +117,7 @@ def loss(self, x, logging=False, **kwargs):
def update(self, step, x):
x = x.to(self.device)
self.optimizer.zero_grad()
loss = self.loss(x)
loss = self.loss(x, step)
loss.backward()
self.optimizer.step()
self.scheduler.step()
Expand All @@ -123,6 +132,8 @@ def config(self):
'lr' : self.lr,
'l1_penalty' : self.l1_penalty,
'warmup_steps' : self.warmup_steps,
'sparsity_warmup_steps' : self.sparsity_warmup_steps,
'seed' : self.seed,
'device' : self.device,
'layer' : self.layer,
'lm_name' : self.lm_name,
Expand Down
40 changes: 25 additions & 15 deletions trainers/jumprelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import torch.autograd as autograd
from torch import nn
from typing import Optional

from ..dictionary import Dictionary, JumpReluAutoEncoder
from .trainer import SAETrainer
Expand Down Expand Up @@ -69,21 +70,22 @@ class JumpReluTrainer(nn.Module, SAETrainer):
def __init__(
self,
dict_class=JumpReluAutoEncoder,
activation_dim=512,
dict_size=8192,
steps=30000,
activation_dim: int = 512,
dict_size: int = 8192,
steps: int = 30000,
# XXX: Training decay is not implemented
seed=None,
seed: Optional[int] = None,
# TODO: What's the default lr used in the paper?
lr=7e-5,
bandwidth=0.001,
sparsity_penalty=1.0,
target_l0=20.0,
device="cpu",
layer=None,
lm_name=None,
wandb_name="JumpRelu",
submodule_name=None,
lr: float = 7e-5,
bandwidth: float = 0.001,
sparsity_penalty: float = 1.0,
sparsity_warmup_steps: int = 2000,
target_l0: float = 20.0,
device: str = "cpu",
layer: Optional[int] = None,
lm_name: Optional[str] = None,
wandb_name: str = "JumpRelu",
submodule_name: Optional[str] = None,
):
super().__init__()

Expand All @@ -100,6 +102,7 @@ def __init__(

self.bandwidth = bandwidth
self.sparsity_coefficient = sparsity_penalty
self.sparsity_warmup_steps = sparsity_warmup_steps
self.target_l0 = target_l0

# TODO: Better auto-naming (e.g. in BatchTopK package)
Expand All @@ -119,14 +122,20 @@ def __init__(

self.logging_parameters = []

def loss(self, x, logging=False, **_):
def loss(self, x: torch.Tensor, step: int, logging=False, **_):

if self.sparsity_warmup_steps is not None:
sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0)
else:
sparsity_scale = 1.0

f = self.ae.encode(x)
recon = self.ae.decode(f)

recon_loss = (x - recon).pow(2).sum(dim=-1).mean()
l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean()

sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2)
sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale
loss = recon_loss + sparsity_loss

if not logging:
Expand Down Expand Up @@ -170,5 +179,6 @@ def config(self):
"submodule_name": self.submodule_name,
"bandwidth": self.bandwidth,
"sparsity_penalty": self.sparsity_coefficient,
"sparsity_warmup_steps": self.sparsity_warmup_steps,
"target_l0": self.target_l0,
}
43 changes: 27 additions & 16 deletions trainers/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Implements the standard SAE training scheme.
"""
import torch as t
from typing import Optional

from ..trainers.trainer import SAETrainer
from ..config import DEBUG
from ..dictionary import AutoEncoder
Expand Down Expand Up @@ -33,18 +35,19 @@ class StandardTrainer(SAETrainer):
"""
def __init__(self,
dict_class=AutoEncoder,
activation_dim=512,
dict_size=64*512,
lr=1e-3,
l1_penalty=1e-1,
warmup_steps=1000, # lr warmup period at start of training and after each resample
resample_steps=None, # how often to resample neurons
seed=None,
activation_dim:int=512,
dict_size:int=64*512,
lr:float=1e-3,
l1_penalty:float=1e-1,
warmup_steps:int=1000, # lr warmup period at start of training and after each resample
sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training
resample_steps:Optional[int]=None, # how often to resample neurons
seed:Optional[int]=None,
device=None,
layer=None,
lm_name=None,
wandb_name='StandardTrainer',
submodule_name=None,
layer:Optional[int]=None,
lm_name:Optional[str]=None,
wandb_name:Optional[str]='StandardTrainer',
submodule_name:Optional[str]=None,
):
super().__init__(seed)

Expand All @@ -70,10 +73,10 @@ def __init__(self,
else:
self.device = device
self.ae.to(self.device)

self.sparsity_warmup_steps = sparsity_warmup_steps

self.resample_steps = resample_steps


if self.resample_steps is not None:
# how many steps since each neuron was last activated?
self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device)
Expand Down Expand Up @@ -124,7 +127,13 @@ def resample_neurons(self, deads, activations):
state_dict[3]['exp_avg'][:,deads] = 0.
state_dict[3]['exp_avg_sq'][:,deads] = 0.

def loss(self, x, logging=False, **kwargs):
def loss(self, x, step: int, logging=False, **kwargs):

if self.sparsity_warmup_steps is not None:
sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0)
else:
sparsity_scale = 1.0

x_hat, f = self.ae(x, output_features=True)
l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
Expand All @@ -136,7 +145,7 @@ def loss(self, x, logging=False, **kwargs):
self.steps_since_active[deads] += 1
self.steps_since_active[~deads] = 0

loss = recon_loss + self.l1_penalty * l1_loss
loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss

if not logging:
return loss
Expand All @@ -156,7 +165,7 @@ def update(self, step, activations):
activations = activations.to(self.device)

self.optimizer.zero_grad()
loss = self.loss(activations)
loss = self.loss(activations, step=step)
loss.backward()
self.optimizer.step()
self.scheduler.step()
Expand All @@ -175,6 +184,8 @@ def config(self):
'l1_penalty' : self.l1_penalty,
'warmup_steps' : self.warmup_steps,
'resample_steps' : self.resample_steps,
'sparsity_warmup_steps' : self.sparsity_warmup_steps,
'seed' : self.seed,
'device' : self.device,
'layer' : self.layer,
'lm_name' : self.lm_name,
Expand Down