Skip to content

Commit c4eed3c

Browse files
authored
Merge pull request #30 from adamkarvonen/add_tests
Add end to end test, upgrade nnsight to support 0.3.0, fix bugs
2 parents 2ec1890 + d350415 commit c4eed3c

File tree

6 files changed

+371
-68
lines changed

6 files changed

+371
-68
lines changed

buffer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,12 @@ def refresh(self):
121121
invoker_args={"truncation": True, "max_length": self.ctx_len},
122122
):
123123
if self.io == "in":
124-
hidden_states = self.submodule.input[0].save()
124+
hidden_states = self.submodule.inputs[0].save()
125125
else:
126126
hidden_states = self.submodule.output.save()
127-
input = self.model.input.save()
127+
input = self.model.inputs.save()
128+
129+
self.submodule.output.stop()
128130
attn_mask = input.value[1]["attention_mask"]
129131
hidden_states = hidden_states.value
130132
if isinstance(hidden_states, tuple):
@@ -251,8 +253,8 @@ def refresh(self):
251253
while len(self.activations) < self.n_ctxs * self.ctx_len:
252254
with t.no_grad():
253255
with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote):
254-
input = self.model.input.save()
255-
hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save()
256+
input = self.model.inputs.save()
257+
hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.inputs[0][0]#.save()
256258
if isinstance(hidden_states, tuple):
257259
hidden_states = hidden_states[0]
258260

@@ -416,7 +418,7 @@ def refresh(self):
416418
invoker_args={"truncation": True, "max_length": self.ctx_len},
417419
):
418420
if self.io in ["in", "in_and_out"]:
419-
hidden_states_in = self.submodule.input[0].save()
421+
hidden_states_in = self.submodule.inputs[0].save()
420422
if self.io in ["out", "in_and_out"]:
421423
hidden_states_out = self.submodule.output.save()
422424

evaluation.py

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44

55
import torch as t
6+
from collections import defaultdict
7+
68
from .buffer import ActivationBuffer, NNsightActivationBuffer
79
from nnsight import LanguageModel
810
from .config import DEBUG
@@ -22,12 +24,21 @@ def loss_recovered(
2224
How much of the model's loss is recovered by replacing the component output
2325
with the reconstruction by the autoencoder?
2426
"""
25-
27+
2628
if max_len is None:
2729
invoker_args = {}
2830
else:
2931
invoker_args = {"truncation": True, "max_length": max_len }
3032

33+
with model.trace("_"):
34+
temp_output = submodule.output.save()
35+
36+
output_is_tuple = False
37+
# Note: isinstance() won't work here as torch.Size is a subclass of tuple,
38+
# so isinstance(temp_output.shape, tuple) would return True even for torch.Size.
39+
if type(temp_output.shape) == tuple:
40+
output_is_tuple = True
41+
3142
# unmodified logits
3243
with model.trace(text, invoker_args=invoker_args):
3344
logits_original = model.output.save()
@@ -36,57 +47,57 @@ def loss_recovered(
3647
# logits when replacing component activations with reconstruction by autoencoder
3748
with model.trace(text, **tracer_args, invoker_args=invoker_args):
3849
if io == 'in':
39-
x = submodule.input[0]
40-
if type(submodule.input.shape) == tuple: x = x[0]
50+
x = submodule.input
4151
if normalize_batch:
4252
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
4353
x = x * scale
4454
elif io == 'out':
4555
x = submodule.output
46-
if type(submodule.output.shape) == tuple: x = x[0]
56+
if output_is_tuple: x = x[0]
4757
if normalize_batch:
4858
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
4959
x = x * scale
5060
elif io == 'in_and_out':
51-
x = submodule.input[0]
52-
if type(submodule.input.shape) == tuple: x = x[0]
53-
print(f'x.shape: {x.shape}')
61+
x = submodule.input
5462
if normalize_batch:
5563
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
5664
x = x * scale
5765
else:
5866
raise ValueError(f"Invalid value for io: {io}")
5967
x = x.save()
6068

61-
# pull this out so dictionary can be written without FakeTensor (top_k needs this)
62-
x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype)
69+
# If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here.
70+
assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}"
71+
72+
x_hat = dictionary(x).to(model.dtype)
6373

6474
# intervene with `x_hat`
6575
with model.trace(text, **tracer_args, invoker_args=invoker_args):
6676
if io == 'in':
67-
x = submodule.input[0]
77+
x = submodule.input
6878
if normalize_batch:
6979
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
7080
x_hat = x_hat / scale
71-
if type(submodule.input.shape) == tuple:
72-
submodule.input[0][:] = x_hat
73-
else:
74-
submodule.input = x_hat
81+
submodule.input[:] = x_hat
7582
elif io == 'out':
7683
x = submodule.output
84+
if output_is_tuple: x = x[0]
7785
if normalize_batch:
7886
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
7987
x_hat = x_hat / scale
80-
if type(submodule.output.shape) == tuple:
81-
submodule.output = (x_hat,)
88+
if output_is_tuple:
89+
submodule.output[0][:] = x_hat
8290
else:
83-
submodule.output = x_hat
91+
submodule.output[:] = x_hat
8492
elif io == 'in_and_out':
85-
x = submodule.input[0]
93+
x = submodule.input
8694
if normalize_batch:
8795
scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
8896
x_hat = x_hat / scale
89-
submodule.output = x_hat
97+
if output_is_tuple:
98+
submodule.output[0][:] = x_hat
99+
else:
100+
submodule.output[:] = x_hat
90101
else:
91102
raise ValueError(f"Invalid value for io: {io}")
92103

@@ -96,22 +107,20 @@ def loss_recovered(
96107
# logits when replacing component activations with zeros
97108
with model.trace(text, **tracer_args, invoker_args=invoker_args):
98109
if io == 'in':
99-
x = submodule.input[0]
100-
if type(submodule.input.shape) == tuple:
101-
submodule.input[0][:] = t.zeros_like(x[0])
102-
else:
103-
submodule.input = t.zeros_like(x)
110+
x = submodule.input
111+
submodule.input[:] = t.zeros_like(x)
104112
elif io in ['out', 'in_and_out']:
105113
x = submodule.output
106-
if type(submodule.output.shape) == tuple:
114+
if output_is_tuple:
107115
submodule.output[0][:] = t.zeros_like(x[0])
108116
else:
109-
submodule.output = t.zeros_like(x)
117+
submodule.output[:] = t.zeros_like(x)
110118
else:
111119
raise ValueError(f"Invalid value for io: {io}")
112120

113-
input = model.input.save()
121+
input = model.inputs.save()
114122
logits_zero = model.output.save()
123+
115124
logits_zero = logits_zero.value
116125

117126
# get everything into the right format
@@ -144,7 +153,7 @@ def loss_recovered(
144153

145154
return tuple(losses)
146155

147-
156+
@t.no_grad()
148157
def evaluate(
149158
dictionary, # a dictionary
150159
activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered
@@ -154,26 +163,31 @@ def evaluate(
154163
normalize_batch=False, # normalize batch before passing through dictionary
155164
tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace.
156165
device="cpu",
166+
n_batches: int = 1,
157167
):
158-
with t.no_grad():
159-
160-
out = {} # dict of results
168+
assert n_batches > 0
169+
out = defaultdict(float)
170+
active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device)
161171

172+
for _ in range(n_batches):
162173
try:
163174
x = next(activations).to(device)
164175
if normalize_batch:
165176
x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)
166-
167177
except StopIteration:
168178
raise StopIteration(
169179
"Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
170180
)
171-
172181
x_hat, f = dictionary(x, output_features=True)
173182
l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
174183
l1_loss = f.norm(p=1, dim=-1).mean()
175184
l0 = (f != 0).float().sum(dim=-1).mean()
176-
frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size
185+
186+
features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D)
187+
assert features_BF.shape[-1] == dictionary.dict_size
188+
assert len(features_BF.shape) == 2
189+
190+
active_features += features_BF.sum(dim=0)
177191

178192
# cosine similarity between x and x_hat
179193
x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
@@ -193,17 +207,16 @@ def evaluate(
193207
x_dot_x_hat = (x * x_hat).sum(dim=-1)
194208
relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean()
195209

196-
out["l2_loss"] = l2_loss.item()
197-
out["l1_loss"] = l1_loss.item()
198-
out["l0"] = l0.item()
199-
out["frac_alive"] = frac_alive.item()
200-
out["frac_variance_explained"] = frac_variance_explained.item()
201-
out["cossim"] = cossim.item()
202-
out["l2_ratio"] = l2_ratio.item()
203-
out['relative_reconstruction_bias'] = relative_reconstruction_bias.item()
210+
out["l2_loss"] += l2_loss.item()
211+
out["l1_loss"] += l1_loss.item()
212+
out["l0"] += l0.item()
213+
out["frac_variance_explained"] += frac_variance_explained.item()
214+
out["cossim"] += cossim.item()
215+
out["l2_ratio"] += l2_ratio.item()
216+
out['relative_reconstruction_bias'] += relative_reconstruction_bias.item()
204217

205218
if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)):
206-
return out
219+
continue
207220

208221
# compute loss recovered
209222
loss_original, loss_reconstructed, loss_zero = loss_recovered(
@@ -218,9 +231,13 @@ def evaluate(
218231
)
219232
frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)
220233

221-
out["loss_original"] = loss_original.item()
222-
out["loss_reconstructed"] = loss_reconstructed.item()
223-
out["loss_zero"] = loss_zero.item()
224-
out["frac_recovered"] = frac_recovered.item()
234+
out["loss_original"] += loss_original.item()
235+
out["loss_reconstructed"] += loss_reconstructed.item()
236+
out["loss_zero"] += loss_zero.item()
237+
out["frac_recovered"] += frac_recovered.item()
238+
239+
out = {key: value / n_batches for key, value in out.items()}
240+
frac_alive = (active_features != 0).float().sum() / dictionary.dict_size
241+
out["frac_alive"] = frac_alive.item()
225242

226-
return out
243+
return out

interp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _list_decode(x):
101101
inputs = buffer.tokenized_batch(batch_size=n_inputs)
102102

103103
with t.no_grad(), model.trace(inputs, **tracer_kwargs):
104-
tokens = model.input[1][
104+
tokens = model.inputs[1][
105105
"input_ids"
106106
].save() # if you're getting errors, check here; might only work for pythia models
107107
activations = submodule.output

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ circuitsvis>=1.43.2
22
datasets>=2.18.0
33
einops>=0.7.0
44
matplotlib>=3.8.3
5-
nnsight>=0.2.11
5+
nnsight>=0.3.0
66
pandas>=2.2.1
77
plotly>=5.18.0
88
torch>=2.1.2
99
tqdm>=4.66.1
1010
umap-learn>=0.5.6
1111
zstandard>=0.22.0
12-
wandb
12+
wandb>=0.12.0
13+
pytest>=6.2.4

0 commit comments

Comments
 (0)