
Commit a11670f

Merge pull request #32 from saprmarks/add_sparsity_warmup
Add sparsity warmup
2 parents: 9687bb9 + 911b958

File tree: 4 files changed, +80 −47 lines
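This commit ramps each trainer's sparsity penalty linearly from 0 to its full value over the first `sparsity_warmup_steps` training steps; passing `None` disables the warmup. The sketch below mirrors the scaling logic the three trainers share; pulling it out into a standalone `sparsity_scale` helper is illustrative only and is not part of the diff.

```python
from typing import Optional

def sparsity_scale(step: int, sparsity_warmup_steps: Optional[int]) -> float:
    """Linear ramp used by the trainers in this commit: 0.0 at step 0,
    1.0 once `sparsity_warmup_steps` steps have elapsed; None disables it."""
    if sparsity_warmup_steps is not None:
        return min(step / sparsity_warmup_steps, 1.0)
    return 1.0

# e.g. with sparsity_warmup_steps=2000:
# step 0 -> 0.0, step 500 -> 0.25, step 2000 and beyond -> 1.0
```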

tests/test_end_to_end.py

Lines changed: 1 addition & 0 deletions
@@ -149,6 +149,7 @@ def test_sae_training():
         "lr": learning_rate,
         "l1_penalty": sparsity_penalty,
         "warmup_steps": warmup_steps,
+        "sparsity_warmup_steps": None,
         "resample_steps": resample_steps,
         "seed": RANDOM_SEED,
         "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}",

trainers/gdm.py

Lines changed: 27 additions & 16 deletions
@@ -3,6 +3,8 @@
 """
 
 import torch as t
+from typing import Optional
+
 from ..trainers.trainer import SAETrainer
 from ..config import DEBUG
 from ..dictionary import GatedAutoEncoder
@@ -33,19 +35,19 @@ class GatedSAETrainer(SAETrainer):
     Gated SAE training scheme.
     """
     def __init__(self,
-                 dict_class=GatedAutoEncoder,
-                 activation_dim=512,
-                 dict_size=64*512,
-                 lr=5e-5,
-                 l1_penalty=1e-1,
-                 warmup_steps=1000, # lr warmup period at start of training and after each resample
-                 resample_steps=None, # how often to resample neurons
-                 seed=None,
-                 device=None,
-                 layer=None,
-                 lm_name=None,
-                 wandb_name='GatedSAETrainer',
-                 submodule_name=None,
+                 dict_class = GatedAutoEncoder,
+                 activation_dim: int = 512,
+                 dict_size: int = 64*512,
+                 lr: float = 5e-5,
+                 l1_penalty: float = 1e-1,
+                 warmup_steps: int = 1000, # lr warmup period at start of training and after each resample
+                 sparsity_warmup_steps: int = 2000,
+                 seed: Optional[int] = None,
+                 device: Optional[str] = None,
+                 layer: Optional[int] = None,
+                 lm_name: Optional[str] = None,
+                 wandb_name: Optional[str] = 'GatedSAETrainer',
+                 submodule_name: Optional[str] = None,
     ):
         super().__init__(seed)

@@ -64,6 +66,7 @@ def __init__(self,
         self.lr = lr
         self.l1_penalty=l1_penalty
         self.warmup_steps = warmup_steps
+        self.sparsity_warmup_steps = sparsity_warmup_steps
         self.wandb_name = wandb_name

         if device is None:
@@ -81,7 +84,13 @@ def warmup_fn(step):
             return min(1, step / warmup_steps)
         self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn)

-    def loss(self, x, logging=False, **kwargs):
+    def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs):
+
+        if self.sparsity_warmup_steps is not None:
+            sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0)
+        else:
+            sparsity_scale = 1.0
+
         f, f_gate = self.ae.encode(x, return_gate=True)
         x_hat = self.ae.decode(f)
         x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach()
@@ -90,7 +99,7 @@ def loss(self, x, logging=False, **kwargs):
         L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean()
         L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean()

-        loss = L_recon + self.l1_penalty * L_sparse + L_aux
+        loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux

         if not logging:
             return loss
@@ -108,7 +117,7 @@ def loss(self, x, logging=False, **kwargs):
     def update(self, step, x):
         x = x.to(self.device)
         self.optimizer.zero_grad()
-        loss = self.loss(x)
+        loss = self.loss(x, step)
         loss.backward()
         self.optimizer.step()
         self.scheduler.step()
@@ -123,6 +132,8 @@ def config(self):
            'lr' : self.lr,
            'l1_penalty' : self.l1_penalty,
            'warmup_steps' : self.warmup_steps,
+           'sparsity_warmup_steps' : self.sparsity_warmup_steps,
+           'seed' : self.seed,
            'device' : self.device,
            'layer' : self.layer,
            'lm_name' : self.lm_name,
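A hedged usage sketch of the new `loss(x, step)` signature on `GatedSAETrainer`; the import path and tensor shapes are assumptions, and `layer`/`lm_name` are placeholder metadata, not values from the diff.

```python
import torch as t
from dictionary_learning.trainers.gdm import GatedSAETrainer  # import path assumed

trainer = GatedSAETrainer(
    activation_dim=512,
    dict_size=16 * 512,
    l1_penalty=1e-1,
    sparsity_warmup_steps=2000,
    device="cpu",
    layer=0,                     # placeholder metadata
    lm_name="example-model",     # placeholder metadata
)

x = t.randn(64, 512)                  # stand-in batch of activations
early = trainer.loss(x, step=0)       # sparsity term scaled by 0.0
late = trainer.loss(x, step=2000)     # full l1_penalty * L_sparse applied
```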

trainers/jumprelu.py

Lines changed: 25 additions & 15 deletions
@@ -3,6 +3,7 @@
 import torch
 import torch.autograd as autograd
 from torch import nn
+from typing import Optional

 from ..dictionary import Dictionary, JumpReluAutoEncoder
 from .trainer import SAETrainer
@@ -69,21 +70,22 @@ class JumpReluTrainer(nn.Module, SAETrainer):
     def __init__(
         self,
         dict_class=JumpReluAutoEncoder,
-        activation_dim=512,
-        dict_size=8192,
-        steps=30000,
+        activation_dim: int = 512,
+        dict_size: int = 8192,
+        steps: int = 30000,
         # XXX: Training decay is not implemented
-        seed=None,
+        seed: Optional[int] = None,
         # TODO: What's the default lr use in the paper?
-        lr=7e-5,
-        bandwidth=0.001,
-        sparsity_penalty=1.0,
-        target_l0=20.0,
-        device="cpu",
-        layer=None,
-        lm_name=None,
-        wandb_name="JumpRelu",
-        submodule_name=None,
+        lr: float = 7e-5,
+        bandwidth: float = 0.001,
+        sparsity_penalty: float = 1.0,
+        sparsity_warmup_steps: int = 2000,
+        target_l0: float = 20.0,
+        device: str = "cpu",
+        layer: Optional[int] = None,
+        lm_name: Optional[str] = None,
+        wandb_name: str = "JumpRelu",
+        submodule_name: Optional[str] = None,
     ):
         super().__init__()

@@ -100,6 +102,7 @@ def __init__(

         self.bandwidth = bandwidth
         self.sparsity_coefficient = sparsity_penalty
+        self.sparsity_warmup_steps = sparsity_warmup_steps
         self.target_l0 = target_l0

         # TODO: Better auto-naming (e.g. in BatchTopK package)
@@ -119,14 +122,20 @@ def __init__(

         self.logging_parameters = []

-    def loss(self, x, logging=False, **_):
+    def loss(self, x: torch.Tensor, step: int, logging=False, **_):
+
+        if self.sparsity_warmup_steps is not None:
+            sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0)
+        else:
+            sparsity_scale = 1.0
+
         f = self.ae.encode(x)
         recon = self.ae.decode(f)

         recon_loss = (x - recon).pow(2).sum(dim=-1).mean()
         l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean()

-        sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2)
+        sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale
         loss = recon_loss + sparsity_loss

         if not logging:
@@ -170,5 +179,6 @@ def config(self):
            "submodule_name": self.submodule_name,
            "bandwidth": self.bandwidth,
            "sparsity_penalty": self.sparsity_coefficient,
+           "sparsity_warmup_steps": self.sparsity_warmup_steps,
            "target_l0": self.target_l0,
        }
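For JumpReLU the warmup multiplies the whole L0-targeting penalty, `sparsity_coefficient * ((l0 / target_l0) - 1)^2 * sparsity_scale`. A worked example with assumed values (the numbers are illustrative, not from the diff):

```python
# Illustrative values only.
sparsity_coefficient = 1.0
target_l0 = 20.0
l0 = 60.0                              # current mean number of active features
step, sparsity_warmup_steps = 500, 2000

sparsity_scale = min(step / sparsity_warmup_steps, 1.0)   # 0.25
sparsity_loss = sparsity_coefficient * ((l0 / target_l0) - 1) ** 2 * sparsity_scale
# ((60/20) - 1)^2 = 4.0, scaled by 0.25 -> sparsity_loss == 1.0
```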

trainers/standard.py

Lines changed: 27 additions & 16 deletions
@@ -2,6 +2,8 @@
 Implements the standard SAE training scheme.
 """
 import torch as t
+from typing import Optional
+
 from ..trainers.trainer import SAETrainer
 from ..config import DEBUG
 from ..dictionary import AutoEncoder
@@ -33,18 +35,19 @@ class StandardTrainer(SAETrainer):
     """
     def __init__(self,
                  dict_class=AutoEncoder,
-                 activation_dim=512,
-                 dict_size=64*512,
-                 lr=1e-3,
-                 l1_penalty=1e-1,
-                 warmup_steps=1000, # lr warmup period at start of training and after each resample
-                 resample_steps=None, # how often to resample neurons
-                 seed=None,
+                 activation_dim:int=512,
+                 dict_size:int=64*512,
+                 lr:float=1e-3,
+                 l1_penalty:float=1e-1,
+                 warmup_steps:int=1000, # lr warmup period at start of training and after each resample
+                 sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training
+                 resample_steps:Optional[int]=None, # how often to resample neurons
+                 seed:Optional[int]=None,
                  device=None,
-                 layer=None,
-                 lm_name=None,
-                 wandb_name='StandardTrainer',
-                 submodule_name=None,
+                 layer:Optional[int]=None,
+                 lm_name:Optional[str]=None,
+                 wandb_name:Optional[str]='StandardTrainer',
+                 submodule_name:Optional[str]=None,
     ):
         super().__init__(seed)

@@ -70,10 +73,10 @@ def __init__(self,
         else:
             self.device = device
         self.ae.to(self.device)
+
+        self.sparsity_warmup_steps = sparsity_warmup_steps

         self.resample_steps = resample_steps
-
-
         if self.resample_steps is not None:
             # how many steps since each neuron was last activated?
             self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device)
@@ -124,7 +127,13 @@ def resample_neurons(self, deads, activations):
         state_dict[3]['exp_avg'][:,deads] = 0.
         state_dict[3]['exp_avg_sq'][:,deads] = 0.

-    def loss(self, x, logging=False, **kwargs):
+    def loss(self, x, step: int, logging=False, **kwargs):
+
+        if self.sparsity_warmup_steps is not None:
+            sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0)
+        else:
+            sparsity_scale = 1.0
+
         x_hat, f = self.ae(x, output_features=True)
         l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
         recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
@@ -136,7 +145,7 @@ def loss(self, x, logging=False, **kwargs):
             self.steps_since_active[deads] += 1
             self.steps_since_active[~deads] = 0

-        loss = recon_loss + self.l1_penalty * l1_loss
+        loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss

         if not logging:
             return loss
@@ -156,7 +165,7 @@ def update(self, step, activations):
         activations = activations.to(self.device)

         self.optimizer.zero_grad()
-        loss = self.loss(activations)
+        loss = self.loss(activations, step=step)
         loss.backward()
         self.optimizer.step()
         self.scheduler.step()
@@ -175,6 +184,8 @@ def config(self):
            'l1_penalty' : self.l1_penalty,
            'warmup_steps' : self.warmup_steps,
            'resample_steps' : self.resample_steps,
+           'sparsity_warmup_steps' : self.sparsity_warmup_steps,
+           'seed' : self.seed,
            'device' : self.device,
            'layer' : self.layer,
            'lm_name' : self.lm_name,
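A minimal training-loop sketch against the updated `StandardTrainer`, showing how `update()` now forwards `step` into `loss()`; the import path, batch source, and metadata values are assumptions.

```python
import torch as t
from dictionary_learning.trainers.standard import StandardTrainer  # import path assumed

trainer = StandardTrainer(
    activation_dim=512,
    dict_size=64 * 512,
    lr=1e-3,
    l1_penalty=1e-1,
    warmup_steps=1000,
    sparsity_warmup_steps=2000,   # ramp the l1 penalty over the first 2000 steps
    layer=0,                      # placeholder metadata
    lm_name="example-model",      # placeholder metadata
)

for step in range(10_000):
    activations = t.randn(256, 512)    # stand-in for real model activations
    trainer.update(step, activations)  # update() passes `step` through to loss()
```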
