Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ def refresh(self):
invoker_args={"truncation": True, "max_length": self.ctx_len},
):
if self.io == "in":
hidden_states = self.submodule.input[0].save()
hidden_states = self.submodule.inputs[0].save()
else:
hidden_states = self.submodule.output.save()
input = self.model.input.save()
input = self.model.inputs.save()

self.submodule.output.stop()
attn_mask = input.value[1]["attention_mask"]
hidden_states = hidden_states.value
if isinstance(hidden_states, tuple):
Expand Down Expand Up @@ -251,8 +253,8 @@ def refresh(self):
while len(self.activations) < self.n_ctxs * self.ctx_len:
with t.no_grad():
with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote):
input = self.model.input.save()
hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save()
input = self.model.inputs.save()
hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.inputs[0][0]#.save()
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]

Expand Down Expand Up @@ -416,7 +418,7 @@ def refresh(self):
invoker_args={"truncation": True, "max_length": self.ctx_len},
):
if self.io in ["in", "in_and_out"]:
hidden_states_in = self.submodule.input[0].save()
hidden_states_in = self.submodule.inputs[0].save()
if self.io in ["out", "in_and_out"]:
hidden_states_out = self.submodule.output.save()

Expand Down
113 changes: 65 additions & 48 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""

import torch as t
from collections import defaultdict

from .buffer import ActivationBuffer, NNsightActivationBuffer
from nnsight import LanguageModel
from .config import DEBUG
Expand All @@ -22,12 +24,21 @@ def loss_recovered(
How much of the model's loss is recovered by replacing the component output
with the reconstruction by the autoencoder?
"""

if max_len is None:
invoker_args = {}
else:
invoker_args = {"truncation": True, "max_length": max_len }

with model.trace("_"):
temp_output = submodule.output.save()

output_is_tuple = False
# Note: isinstance() won't work here as torch.Size is a subclass of tuple,
# so isinstance(temp_output.shape, tuple) would return True even for torch.Size.
if type(temp_output.shape) == tuple:
output_is_tuple = True

# unmodified logits
with model.trace(text, invoker_args=invoker_args):
logits_original = model.output.save()
Expand All @@ -36,57 +47,57 @@ def loss_recovered(
# logits when replacing component activations with reconstruction by autoencoder
with model.trace(text, **tracer_args, invoker_args=invoker_args):
if io == 'in':
x = submodule.input[0]
if type(submodule.input.shape) == tuple: x = x[0]
x = submodule.input
if normalize_batch:
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
x = x * scale
elif io == 'out':
x = submodule.output
if type(submodule.output.shape) == tuple: x = x[0]
if output_is_tuple: x = x[0]
if normalize_batch:
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
x = x * scale
elif io == 'in_and_out':
x = submodule.input[0]
if type(submodule.input.shape) == tuple: x = x[0]
print(f'x.shape: {x.shape}')
x = submodule.input
if normalize_batch:
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
x = x * scale
else:
raise ValueError(f"Invalid value for io: {io}")
x = x.save()

# pull this out so dictionary can be written without FakeTensor (top_k needs this)
x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype)
# If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here.
assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}"

x_hat = dictionary(x).to(model.dtype)

# intervene with `x_hat`
with model.trace(text, **tracer_args, invoker_args=invoker_args):
if io == 'in':
x = submodule.input[0]
x = submodule.input
if normalize_batch:
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
x_hat = x_hat / scale
if type(submodule.input.shape) == tuple:
submodule.input[0][:] = x_hat
else:
submodule.input = x_hat
submodule.input[:] = x_hat
elif io == 'out':
x = submodule.output
if output_is_tuple: x = x[0]
if normalize_batch:
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
x_hat = x_hat / scale
if type(submodule.output.shape) == tuple:
submodule.output = (x_hat,)
if output_is_tuple:
submodule.output[0][:] = x_hat
else:
submodule.output = x_hat
submodule.output[:] = x_hat
elif io == 'in_and_out':
x = submodule.input[0]
x = submodule.input
if normalize_batch:
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
x_hat = x_hat / scale
submodule.output = x_hat
if output_is_tuple:
submodule.output[0][:] = x_hat
else:
submodule.output[:] = x_hat
else:
raise ValueError(f"Invalid value for io: {io}")

Expand All @@ -96,22 +107,20 @@ def loss_recovered(
# logits when replacing component activations with zeros
with model.trace(text, **tracer_args, invoker_args=invoker_args):
if io == 'in':
x = submodule.input[0]
if type(submodule.input.shape) == tuple:
submodule.input[0][:] = t.zeros_like(x[0])
else:
submodule.input = t.zeros_like(x)
x = submodule.input
submodule.input[:] = t.zeros_like(x)
elif io in ['out', 'in_and_out']:
x = submodule.output
if type(submodule.output.shape) == tuple:
if output_is_tuple:
submodule.output[0][:] = t.zeros_like(x[0])
else:
submodule.output = t.zeros_like(x)
submodule.output[:] = t.zeros_like(x)
else:
raise ValueError(f"Invalid value for io: {io}")

input = model.input.save()
input = model.inputs.save()
logits_zero = model.output.save()

logits_zero = logits_zero.value

# get everything into the right format
Expand Down Expand Up @@ -144,7 +153,7 @@ def loss_recovered(

return tuple(losses)


@t.no_grad()
def evaluate(
dictionary, # a dictionary
activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered
Expand All @@ -154,26 +163,31 @@ def evaluate(
normalize_batch=False, # normalize batch before passing through dictionary
tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace.
device="cpu",
n_batches: int = 1,
):
with t.no_grad():

out = {} # dict of results
assert n_batches > 0
out = defaultdict(float)
active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device)

for _ in range(n_batches):
try:
x = next(activations).to(device)
if normalize_batch:
x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)

except StopIteration:
raise StopIteration(
"Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
)

x_hat, f = dictionary(x, output_features=True)
l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
l1_loss = f.norm(p=1, dim=-1).mean()
l0 = (f != 0).float().sum(dim=-1).mean()
frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size

features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D)
assert features_BF.shape[-1] == dictionary.dict_size
assert len(features_BF.shape) == 2

active_features += features_BF.sum(dim=0)

# cosine similarity between x and x_hat
x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
Expand All @@ -193,17 +207,16 @@ def evaluate(
x_dot_x_hat = (x * x_hat).sum(dim=-1)
relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean()

out["l2_loss"] = l2_loss.item()
out["l1_loss"] = l1_loss.item()
out["l0"] = l0.item()
out["frac_alive"] = frac_alive.item()
out["frac_variance_explained"] = frac_variance_explained.item()
out["cossim"] = cossim.item()
out["l2_ratio"] = l2_ratio.item()
out['relative_reconstruction_bias'] = relative_reconstruction_bias.item()
out["l2_loss"] += l2_loss.item()
out["l1_loss"] += l1_loss.item()
out["l0"] += l0.item()
out["frac_variance_explained"] += frac_variance_explained.item()
out["cossim"] += cossim.item()
out["l2_ratio"] += l2_ratio.item()
out['relative_reconstruction_bias'] += relative_reconstruction_bias.item()

if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)):
return out
continue

# compute loss recovered
loss_original, loss_reconstructed, loss_zero = loss_recovered(
Expand All @@ -218,9 +231,13 @@ def evaluate(
)
frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)

out["loss_original"] = loss_original.item()
out["loss_reconstructed"] = loss_reconstructed.item()
out["loss_zero"] = loss_zero.item()
out["frac_recovered"] = frac_recovered.item()
out["loss_original"] += loss_original.item()
out["loss_reconstructed"] += loss_reconstructed.item()
out["loss_zero"] += loss_zero.item()
out["frac_recovered"] += frac_recovered.item()

out = {key: value / n_batches for key, value in out.items()}
frac_alive = (active_features != 0).float().sum() / dictionary.dict_size
out["frac_alive"] = frac_alive.item()

return out
return out
2 changes: 1 addition & 1 deletion interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _list_decode(x):
inputs = buffer.tokenized_batch(batch_size=n_inputs)

with t.no_grad(), model.trace(inputs, **tracer_kwargs):
tokens = model.input[1][
tokens = model.inputs[1][
"input_ids"
].save() # if you're getting errors, check here; might only work for pythia models
activations = submodule.output
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ circuitsvis>=1.43.2
datasets>=2.18.0
einops>=0.7.0
matplotlib>=3.8.3
nnsight>=0.2.11
nnsight>=0.3.0
pandas>=2.2.1
plotly>=5.18.0
torch>=2.1.2
tqdm>=4.66.1
umap-learn>=0.5.6
zstandard>=0.22.0
wandb
wandb>=0.12.0
pytest>=6.2.4
Loading