Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions dictionary_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
__version__ = "0.1.0"

from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder
from .dictionary import (
AutoEncoder,
GatedAutoEncoder,
JumpReluAutoEncoder,
BatchTopKSAE,
MatryoshkaBatchTopKSAE,
AutoEncoderTopK,
)
from .buffer import ActivationBuffer

__all__ = ["AutoEncoder", "GatedAutoEncoder", "JumpReluAutoEncoder", "ActivationBuffer"]
__all__ = [
"AutoEncoder",
"GatedAutoEncoder",
"JumpReluAutoEncoder",
"BatchTopKSAE",
"MatryoshkaBatchTopKSAE",
"AutoEncoderTopK",
"ActivationBuffer",
]
300 changes: 300 additions & 0 deletions dictionary_learning/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,25 @@
import torch.nn as nn
import torch.nn.init as init
import einops
from typing import Optional

@t.no_grad()
def set_decoder_norm_to_unit_norm(
W_dec_DF: t.nn.Parameter, activation_dim: int, d_sae: int
) -> t.Tensor:
"""There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders.
nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in
to catch this error."""

D, F = W_dec_DF.shape

assert D == activation_dim
assert F == d_sae

eps = t.finfo(W_dec_DF.dtype).eps
norm = t.norm(W_dec_DF.data, dim=0, keepdim=True)
W_dec_DF.data /= norm + eps
return W_dec_DF.data

class Dictionary(ABC, nn.Module):
"""
Expand Down Expand Up @@ -379,6 +397,288 @@ def from_pretrained(
device = autoencoder.W_enc.device
return autoencoder.to(dtype=dtype, device=device)

class AutoEncoderTopK(Dictionary, nn.Module):
"""
The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093
NOTE: (From Adam Karvonen) There is an unmaintained implementation using Triton kernels in the topk-triton-implementation branch.
We abandoned it as we didn't notice a significant speedup and it added complications, which are noted
in the AutoEncoderTopK class docstring in that branch.

With some additional effort, you can train a Top-K SAE with the Triton kernels and modify the state dict for compatibility with this class.
Notably, the Triton kernels currently have the decoder to be stored in nn.Parameter, not nn.Linear, and the decoder weights must also
be stored in the same shape as the encoder.
"""

def __init__(self, activation_dim: int, dict_size: int, k: int):
super().__init__()
self.activation_dim = activation_dim
self.dict_size = dict_size

assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
self.register_buffer("k", t.tensor(k, dtype=t.int))
self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32))

self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
self.decoder.weight.data = set_decoder_norm_to_unit_norm(
self.decoder.weight, activation_dim, dict_size
)

self.encoder = nn.Linear(activation_dim, dict_size)
self.encoder.weight.data = self.decoder.weight.T.clone()
self.encoder.bias.data.zero_()

self.b_dec = nn.Parameter(t.zeros(activation_dim))

def encode(
self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False
):
post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))

if use_threshold:
encoded_acts_BF = post_relu_feat_acts_BF * (
post_relu_feat_acts_BF > self.threshold
)
if return_topk:
post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)
return (
encoded_acts_BF,
post_topk.values,
post_topk.indices,
post_relu_feat_acts_BF,
)
else:
return encoded_acts_BF

post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)

# We can't split immediately due to nnsight
tops_acts_BK = post_topk.values
top_indices_BK = post_topk.indices

buffer_BF = t.zeros_like(post_relu_feat_acts_BF)
encoded_acts_BF = buffer_BF.scatter_(
dim=-1, index=top_indices_BK, src=tops_acts_BK
)

if return_topk:
return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF
else:
return encoded_acts_BF

def decode(self, x: t.Tensor) -> t.Tensor:
return self.decoder(x) + self.b_dec

def forward(self, x: t.Tensor, output_features: bool = False):
encoded_acts_BF = self.encode(x)
x_hat_BD = self.decode(encoded_acts_BF)
if not output_features:
return x_hat_BD
else:
return x_hat_BD, encoded_acts_BF

def scale_biases(self, scale: float):
self.encoder.bias.data *= scale
self.b_dec.data *= scale
if self.threshold >= 0:
self.threshold *= scale

def from_pretrained(path, k: Optional[int] = None, device=None):
"""
Load a pretrained autoencoder from a file.
"""
state_dict = t.load(path)
dict_size, activation_dim = state_dict["encoder.weight"].shape

if k is None:
k = state_dict["k"].item()
elif "k" in state_dict and k != state_dict["k"].item():
raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

autoencoder = AutoEncoderTopK(activation_dim, dict_size, k)
autoencoder.load_state_dict(state_dict)
if device is not None:
autoencoder.to(device)
return autoencoder

class BatchTopKSAE(Dictionary, nn.Module):
def __init__(self, activation_dim: int, dict_size: int, k: int):
super().__init__()
self.activation_dim = activation_dim
self.dict_size = dict_size

assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
self.register_buffer("k", t.tensor(k, dtype=t.int))
self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32))

self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
self.decoder.weight.data = set_decoder_norm_to_unit_norm(
self.decoder.weight, activation_dim, dict_size
)

self.encoder = nn.Linear(activation_dim, dict_size)
self.encoder.weight.data = self.decoder.weight.T.clone()
self.encoder.bias.data.zero_()
self.b_dec = nn.Parameter(t.zeros(activation_dim))

def encode(
self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True
):
post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))

if use_threshold:
encoded_acts_BF = post_relu_feat_acts_BF * (
post_relu_feat_acts_BF > self.threshold
)
else:
# Flatten and perform batch top-k
flattened_acts = post_relu_feat_acts_BF.flatten()
post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1)

encoded_acts_BF = (
t.zeros_like(post_relu_feat_acts_BF.flatten())
.scatter_(-1, post_topk.indices, post_topk.values)
.reshape(post_relu_feat_acts_BF.shape)
)

if return_active:
return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF
else:
return encoded_acts_BF

def decode(self, x: t.Tensor) -> t.Tensor:
return self.decoder(x) + self.b_dec

def forward(self, x: t.Tensor, output_features: bool = False):
encoded_acts_BF = self.encode(x)
x_hat_BD = self.decode(encoded_acts_BF)

if not output_features:
return x_hat_BD
else:
return x_hat_BD, encoded_acts_BF

def scale_biases(self, scale: float):
self.encoder.bias.data *= scale
self.b_dec.data *= scale
if self.threshold >= 0:
self.threshold *= scale

@classmethod
def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE":
state_dict = t.load(path)
dict_size, activation_dim = state_dict["encoder.weight"].shape
if k is None:
k = state_dict["k"].item()
elif "k" in state_dict and k != state_dict["k"].item():
raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

autoencoder = cls(activation_dim, dict_size, k)
autoencoder.load_state_dict(state_dict)
if device is not None:
autoencoder.to(device)
return autoencoder

class MatryoshkaBatchTopKSAE(Dictionary, nn.Module):
def __init__(
self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int]
):
super().__init__()
self.activation_dim = activation_dim
self.dict_size = dict_size

assert sum(group_sizes) == dict_size, "group sizes must sum to dict_size"
assert all(s > 0 for s in group_sizes), "all group sizes must be positive"

assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
self.register_buffer("k", t.tensor(k, dtype=t.int))
self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32))

self.active_groups = len(group_sizes)
group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0))
self.group_indices = group_indices

self.register_buffer("group_sizes", t.tensor(group_sizes))

self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size))
self.b_enc = nn.Parameter(t.zeros(dict_size))
self.W_dec = nn.Parameter(
t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim))
)
self.b_dec = nn.Parameter(t.zeros(activation_dim))

# We must transpose because we are using nn.Parameter, not nn.Linear
self.W_dec.data = set_decoder_norm_to_unit_norm(
self.W_dec.data.T, activation_dim, dict_size
).T
self.W_enc.data = self.W_dec.data.clone().T

def encode(
self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True
):
post_relu_feat_acts_BF = nn.functional.relu(
(x - self.b_dec) @ self.W_enc + self.b_enc
)

if use_threshold:
encoded_acts_BF = post_relu_feat_acts_BF * (
post_relu_feat_acts_BF > self.threshold
)
else:
# Flatten and perform batch top-k
flattened_acts = post_relu_feat_acts_BF.flatten()
post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1)

encoded_acts_BF = (
t.zeros_like(post_relu_feat_acts_BF.flatten())
.scatter_(-1, post_topk.indices, post_topk.values)
.reshape(post_relu_feat_acts_BF.shape)
)

max_act_index = self.group_indices[self.active_groups]
encoded_acts_BF[:, max_act_index:] = 0

if return_active:
return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF
else:
return encoded_acts_BF

def decode(self, x: t.Tensor) -> t.Tensor:
return x @ self.W_dec + self.b_dec

def forward(self, x: t.Tensor, output_features: bool = False):
encoded_acts_BF = self.encode(x)
x_hat_BD = self.decode(encoded_acts_BF)

if not output_features:
return x_hat_BD
else:
return x_hat_BD, encoded_acts_BF

@t.no_grad()
def scale_biases(self, scale: float):
self.b_enc.data *= scale
self.b_dec.data *= scale
if self.threshold >= 0:
self.threshold *= scale

@classmethod
def from_pretrained(
cls, path, k=None, device=None, **kwargs
) -> "MatryoshkaBatchTopKSAE":
state_dict = t.load(path)
activation_dim, dict_size = state_dict["W_enc"].shape
if k is None:
k = state_dict["k"].item()
elif "k" in state_dict and k != state_dict["k"].item():
raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

group_sizes = state_dict["group_sizes"].tolist()

autoencoder = cls(activation_dim, dict_size, k=k, group_sizes=group_sizes)
autoencoder.load_state_dict(state_dict)
if device is not None:
autoencoder.to(device)
return autoencoder

# TODO merge this with AutoEncoder
class AutoEncoderNew(Dictionary, nn.Module):
Expand Down
7 changes: 5 additions & 2 deletions dictionary_learning/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from .gated_anneal import GatedAnnealTrainer
from .top_k import TopKTrainer
from .jumprelu import JumpReluTrainer
from .batch_top_k import BatchTopKTrainer, BatchTopKSAE
from .batch_top_k import BatchTopKTrainer
from .matryoshka_batch_top_k import MatryoshkaBatchTopKTrainer
from .top_k import TopKTrainer


__all__ = [
Expand All @@ -15,5 +17,6 @@
"TopKTrainer",
"JumpReluTrainer",
"BatchTopKTrainer",
"BatchTopKSAE",
"MatryoshkaBatchTopKTrainer",
"TopKTrainer",
]
Loading