4 changes: 2 additions & 2 deletions buffer.py
@@ -40,7 +40,7 @@ def __init__(self,
d_submodule = submodule.out_features
except:
raise ValueError("d_submodule cannot be inferred and must be specified directly")
self.activations = t.empty(0, d_submodule, device=device)
self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype)
self.read = t.zeros(0).bool()

self.data = data
@@ -105,7 +105,7 @@ def refresh(self):
self.activations = self.activations[~self.read]

current_idx = len(self.activations)
new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device)
new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device, dtype=self.model.dtype)

new_activations[: len(self.activations)] = self.activations
self.activations = new_activations
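The buffer change above threads the model's dtype into the preallocated activation tensors. A minimal sketch of the idea (sizes and dtype are hypothetical, not taken from the PR):

```python
import torch as t

model_dtype = t.bfloat16   # assumed: whatever dtype the loaded model runs in
d_submodule = 512          # hypothetical submodule width
buffer_size = 8192         # hypothetical number of stored activation rows

# Allocate the buffer in the model's dtype so copies from the model neither
# upcast (wasting memory) nor downcast (losing precision) silently.
activations = t.empty(buffer_size, d_submodule, dtype=model_dtype)
new_batch = t.randn(128, d_submodule, dtype=model_dtype)
activations[:128] = new_batch
assert activations.dtype == new_batch.dtype == model_dtype
```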
32 changes: 26 additions & 6 deletions dictionary.py
@@ -47,12 +47,16 @@ def __init__(self, activation_dim, dict_size):
self.dict_size = dict_size
self.bias = nn.Parameter(t.zeros(activation_dim))
self.encoder = nn.Linear(activation_dim, dict_size, bias=True)

# rows of decoder weight matrix are unit vectors
self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
dec_weight = t.randn_like(self.decoder.weight)
dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True)
self.decoder.weight = nn.Parameter(dec_weight)

# initialize encoder and decoder weights
w = t.randn(activation_dim, dict_size)
## normalize columns of w
w = w / w.norm(dim=0, keepdim=True) * 0.1
## set encoder and decoder weights
self.encoder.weight = nn.Parameter(w.clone().T)
self.decoder.weight = nn.Parameter(w.clone())


def encode(self, x):
return nn.ReLU()(self.encoder(x - self.bias))
@@ -86,6 +90,10 @@ def forward(self, x, output_features=False, ghost_mask=None):
return x_hat, x_ghost, f
else:
return x_hat, x_ghost

def scale_biases(self, scale: float):
self.encoder.bias.data *= scale
self.bias.data *= scale

@classmethod
def from_pretrained(cls, path, dtype=t.float, device=None):
@@ -204,6 +212,11 @@ def forward(self, x, output_features=False):
else:
return x_hat

def scale_biases(self, scale: float):
self.decoder_bias.data *= scale
self.mag_bias.data *= scale
self.gate_bias.data *= scale

def from_pretrained(path, device=None):
"""
Load a pretrained autoencoder from a file.
@@ -215,6 +228,7 @@ def from_pretrained(path, device=None):
if device is not None:
autoencoder.to(device)
return autoencoder


class JumpReluAutoEncoder(Dictionary, nn.Module):
"""
@@ -267,6 +281,11 @@ def forward(self, x, output_features=False):
return x_hat, f
else:
return x_hat

def scale_biases(self, scale: float):
self.b_dec.data *= scale
self.b_enc.data *= scale
self.threshold.data *= scale

@classmethod
def from_pretrained(
@@ -284,9 +303,10 @@ def from_pretrained(
"""
if not load_from_sae_lens:
state_dict = t.load(path)
dict_size, activation_dim = state_dict['W_enc'].shape
activation_dim, dict_size = state_dict['W_enc'].shape
autoencoder = JumpReluAutoEncoder(activation_dim, dict_size)
autoencoder.load_state_dict(state_dict)
autoencoder = autoencoder.to(dtype=dtype, device=device)
else:
from sae_lens import SAE
sae, cfg_dict, _ = SAE.from_pretrained(**kwargs)
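The initialization change above ties encoder and decoder at startup: one random matrix with unit-norm columns scaled by 0.1 serves as the decoder weight, and its transpose as the encoder weight. A standalone sketch with hypothetical dimensions:

```python
import torch as t
import torch.nn as nn

activation_dim, dict_size = 512, 4096        # hypothetical sizes

encoder = nn.Linear(activation_dim, dict_size, bias=True)
decoder = nn.Linear(dict_size, activation_dim, bias=False)

w = t.randn(activation_dim, dict_size)
w = w / w.norm(dim=0, keepdim=True) * 0.1    # unit-norm columns, scaled to 0.1
encoder.weight = nn.Parameter(w.clone().T)   # (dict_size, activation_dim)
decoder.weight = nn.Parameter(w.clone())     # (activation_dim, dict_size)

# Encoder and decoder now start as transposes of each other.
assert t.allclose(encoder.weight.data.T, decoder.weight.data)
```

The `scale_biases` methods added to each dictionary class multiply every bias-like parameter (and, for JumpReLU, the threshold) by a single factor, presumably so a trained dictionary can track a rescaling of its input activations.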
60 changes: 14 additions & 46 deletions tests/test_end_to_end.py
@@ -6,8 +6,8 @@

from dictionary_learning.training import trainSAE
from dictionary_learning.trainers.standard import StandardTrainer
from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK
from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary
from dictionary_learning.buffer import ActivationBuffer
from dictionary_learning.dictionary import (
AutoEncoder,
@@ -58,50 +58,11 @@
EVAL_TOLERANCE = 0.01


def get_nested_folders(path: str) -> list[str]:
"""
Recursively get a list of folders that contain an ae.pt file, starting the search from the given path
"""
folder_names = []

for root, dirs, files in os.walk(path):
if "ae.pt" in files:
folder_names.append(root)

return folder_names


def load_dictionary(base_path: str, device: str) -> tuple:
ae_path = f"{base_path}/ae.pt"
config_path = f"{base_path}/config.json"

with open(config_path, "r") as f:
config = json.load(f)

# TODO: Save the submodule name in the config?
# submodule_str = config["trainer"]["submodule_name"]
dict_class = config["trainer"]["dict_class"]

if dict_class == "AutoEncoder":
dictionary = AutoEncoder.from_pretrained(ae_path, device=device)
elif dict_class == "GatedAutoEncoder":
dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device)
elif dict_class == "AutoEncoderNew":
dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device)
elif dict_class == "AutoEncoderTopK":
k = config["trainer"]["k"]
dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device)
elif dict_class == "JumpReluAutoEncoder":
dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device)
else:
raise ValueError(f"Dictionary class {dict_class} not supported")

return dictionary, config


def test_sae_training():
"""End to end test for training an SAE. Takes ~2 minutes on an RTX 3090.
This isn't a nice suite of unit tests, but it's better than nothing."""
This isn't a nice suite of unit tests, but it's better than nothing.
I have observed that results can slightly vary with library versions. For full determinism,
use pytorch 2.2.0 and nnsight 0.3.3."""
random.seed(RANDOM_SEED)
t.manual_seed(RANDOM_SEED)

@@ -158,7 +119,7 @@ def test_sae_training():
trainer_configs.extend(
[
{
"trainer": TrainerTopK,
"trainer": TopKTrainer,
"dict_class": AutoEncoderTopK,
"activation_dim": activation_dim,
"dict_size": expansion_factor * activation_dim,
@@ -278,5 +239,12 @@ def test_evaluation():
dict_class = config["trainer"]["dict_class"]
expected_results = EXPECTED_RESULTS[dict_class]

max_diff = 0
max_diff_percent = 0
for key, value in expected_results.items():
assert abs(eval_results[key] - value) < EVAL_TOLERANCE
diff = abs(eval_results[key] - value)
max_diff = max(max_diff, diff)
max_diff_percent = max(max_diff_percent, diff / value)

print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}")
assert max_diff < EVAL_TOLERANCE
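`get_nested_folders` and `load_dictionary` are now imported from `dictionary_learning.utils` rather than redefined in the test. A usage sketch, assuming the moved helpers keep the signatures of the local versions removed here (the output path is hypothetical):

```python
from dictionary_learning.utils import get_nested_folders, load_dictionary

# Find every folder under a training run that contains an ae.pt checkpoint,
# then load each dictionary together with its config for evaluation.
for folder in get_nested_folders("./test_output"):
    dictionary, config = load_dictionary(folder, device="cpu")
    print(folder, config["trainer"]["dict_class"])
```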
6 changes: 3 additions & 3 deletions trainers/__init__.py
@@ -2,6 +2,6 @@
from .gdm import GatedSAETrainer
from .p_anneal import PAnnealTrainer
from .gated_anneal import GatedAnnealTrainer
from .top_k import TrainerTopK
from .jumprelu import TrainerJumpRelu
from .batch_top_k import TrainerBatchTopK, BatchTopKSAE
from .top_k import TopKTrainer
from .jumprelu import JumpReluTrainer
from .batch_top_k import BatchTopKTrainer, BatchTopKSAE
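With these renames, downstream code imports the trainers under their new names. A sketch of the updated imports, using the package layout seen in the tests:

```python
from dictionary_learning.trainers import (
    BatchTopKSAE,
    BatchTopKTrainer,
    JumpReluTrainer,
    TopKTrainer,
)
```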
64 changes: 47 additions & 17 deletions trainers/batch_top_k.py
@@ -16,6 +16,7 @@ def __init__(self, activation_dim: int, dict_size: int, k: int):

assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
self.register_buffer("k", t.tensor(k))
self.register_buffer("threshold", t.tensor(-1.0))

self.encoder = nn.Linear(activation_dim, dict_size)
self.encoder.bias.data.zero_()
@@ -24,9 +25,16 @@ def __init__(self, activation_dim: int, dict_size: int, k: int):
self.set_decoder_norm_to_unit_norm()
self.b_dec = nn.Parameter(t.zeros(activation_dim))

def encode(self, x: t.Tensor, return_active: bool = False):
def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True):
post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))

if use_threshold:
encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold)
if return_active:
return encoded_acts_BF, encoded_acts_BF.sum(0) > 0
else:
return encoded_acts_BF

# Flatten and perform batch top-k
flattened_acts = post_relu_feat_acts_BF.flatten()
post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1)
@@ -75,13 +83,19 @@ def remove_gradient_parallel_to_decoder_directions(self):
"d_sae, d_in d_sae -> d_in d_sae",
)

def scale_biases(self, scale: float):
self.encoder.bias.data *= scale
self.b_dec.data *= scale
if self.threshold >= 0:
self.threshold *= scale

@classmethod
def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE":
state_dict = t.load(path)
dict_size, activation_dim = state_dict['encoder.weight'].shape
dict_size, activation_dim = state_dict["encoder.weight"].shape
if k is None:
k = state_dict['k'].item()
elif 'k' in state_dict and k != state_dict['k'].item():
k = state_dict["k"].item()
elif "k" in state_dict and k != state_dict["k"].item():
raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

autoencoder = cls(activation_dim, dict_size, k)
@@ -91,7 +105,7 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE":
return autoencoder


class TrainerBatchTopK(SAETrainer):
class BatchTopKTrainer(SAETrainer):
def __init__(
self,
dict_class=BatchTopKSAE,
@@ -100,6 +114,8 @@ def __init__(
k=8,
auxk_alpha=1 / 32,
decay_start=24000,
threshold_beta=0.999,
threshold_start_step=1000,
steps=30000,
top_k_aux=512,
seed=None,
@@ -117,6 +133,8 @@ def __init__(
self.wandb_name = wandb_name
self.steps = steps
self.k = k
self.threshold_beta = threshold_beta
self.threshold_start_step = threshold_start_step

if seed is not None:
t.manual_seed(seed)
@@ -136,9 +154,7 @@ def __init__(
self.dead_feature_threshold = 10_000_000
self.top_k_aux = top_k_aux

self.optimizer = t.optim.Adam(
self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)
)
self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))

def lr_fn(step):
if step < decay_start:
@@ -165,20 +181,34 @@ def get_auxiliary_loss(self, x, x_reconstruct, acts):
acts_aux = t.zeros_like(acts[:, dead_features]).scatter(
-1, acts_topk_aux.indices, acts_topk_aux.values
)
x_reconstruct_aux = F.linear(
acts_aux, self.ae.decoder.weight[:, dead_features]
)
x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features])
l2_loss_aux = (
self.auxk_alpha
* (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
)
return l2_loss_aux
else:
return t.tensor(0, dtype=x.dtype, device=x.device)

def loss(self, x, step=None, logging=False):
f, active_indices = self.ae.encode(x, return_active=True)
l0 = (f != 0).float().sum(dim=-1).mean().item()
f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False)
# l0 = (f != 0).float().sum(dim=-1).mean().item()

if step > self.threshold_start_step:
with t.no_grad():
active = f[f > 0]

if active.size(0) == 0:
min_activation = 0.0
else:
min_activation = active.min().detach()

if self.ae.threshold < 0:
self.ae.threshold = min_activation
else:
self.ae.threshold = (self.threshold_beta * self.ae.threshold) + (
(1 - self.threshold_beta) * min_activation
)

x_hat = self.ae.decode(f)

e = x_hat - x
@@ -230,14 +260,14 @@ def update(self, step, x):
@property
def config(self):
return {
"trainer_class": "TrainerBatchTopK",
"trainer_class": "BatchTopKTrainer",
"dict_class": "BatchTopKSAE",
"lr": self.lr,
"steps": self.steps,
"seed": self.seed,
"activation_dim": self.ae.activation_dim,
"dict_size": self.ae.dict_size,
"k": self.ae.k,
"k": self.ae.k.item(),
"device": self.device,
"layer": self.layer,
"lm_name": self.lm_name,
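The main behavioral change above is the running threshold: after `threshold_start_step`, the trainer folds the smallest positive activation of each batch into an exponential moving average, and `BatchTopKSAE.encode(..., use_threshold=True)` can then apply that single threshold at inference instead of an explicit batch top-k. A minimal standalone sketch of the update rule (tensor values are illustrative):

```python
import torch as t

threshold_beta = 0.999
threshold = t.tensor(-1.0)   # -1.0 marks "not yet initialized", as in the registered buffer

def update_threshold(f: t.Tensor) -> t.Tensor:
    """Fold the smallest positive activation of this batch into the EMA."""
    global threshold
    with t.no_grad():
        active = f[f > 0]
        min_activation = active.min() if active.numel() > 0 else t.tensor(0.0)
        if threshold < 0:
            threshold = min_activation
        else:
            threshold = threshold_beta * threshold + (1 - threshold_beta) * min_activation
    return threshold

f = t.tensor([[0.0, 0.3, 1.2], [0.0, 0.0, 0.7]])  # post-top-k feature activations
print(update_threshold(f))                        # tensor(0.3000) on the first call, EMA afterwards
```

Because `threshold` is registered as a buffer it is saved and loaded with the state dict, and `scale_biases` only rescales it once it has been initialized (`threshold >= 0`).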
14 changes: 10 additions & 4 deletions trainers/jumprelu.py
@@ -60,7 +60,7 @@ def backward(ctx, grad_output):
return x_grad, threshold_grad, None # None for bandwidth


class TrainerJumpRelu(nn.Module, SAETrainer):
class JumpReluTrainer(nn.Module, SAETrainer):
"""
Trains a JumpReLU autoencoder.

@@ -77,7 +77,8 @@ def __init__(
# TODO: What's the default lr use in the paper?
lr=7e-5,
bandwidth=0.001,
sparsity_penalty=0.1,
sparsity_penalty=1.0,
target_l0=20.0,
device="cpu",
layer=None,
lm_name=None,
@@ -99,6 +100,7 @@ def __init__(

self.bandwidth = bandwidth
self.sparsity_coefficient = sparsity_penalty
self.target_l0 = target_l0

# TODO: Better auto-naming (e.g. in BatchTopK package)
self.wandb_name = wandb_name
@@ -123,7 +125,8 @@ def loss(self, x, logging=False, **_):

recon_loss = (x - recon).pow(2).sum(dim=-1).mean()
l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean()
sparsity_loss = self.sparsity_coefficient * l0

sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2)
loss = recon_loss + sparsity_loss

if not logging:
@@ -153,7 +156,7 @@ def update(self, step, x):
@property
def config(self):
return {
"trainer_class": "TrainerJumpRelu",
"trainer_class": "JumpReluTrainer",
"dict_class": "JumpReluAutoEncoder",
"lr": self.lr,
"steps": self.steps,
@@ -165,4 +168,7 @@ def config(self):
"lm_name": self.lm_name,
"wandb_name": self.wandb_name,
"submodule_name": self.submodule_name,
"bandwidth": self.bandwidth,
"sparsity_penalty": self.sparsity_coefficient,
"target_l0": self.target_l0,
}
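The sparsity term changes from a plain L0 penalty to a squared relative deviation from `target_l0`: the penalty vanishes when the mean number of active features hits the target and grows quadratically as it drifts away. A small numeric sketch:

```python
import torch as t

sparsity_coefficient = 1.0   # new default for sparsity_penalty
target_l0 = 20.0             # new default target

l0 = t.tensor(35.0)          # illustrative mean number of active features per token
print(sparsity_coefficient * ((l0 / target_l0) - 1).pow(2))   # (35/20 - 1)^2 = 0.5625

l0 = t.tensor(20.0)
print(sparsity_coefficient * ((l0 / target_l0) - 1).pow(2))   # 0.0 exactly at the target
```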