33"""
44
55import torch as t
6+ from collections import defaultdict
7+
68from .buffer import ActivationBuffer , NNsightActivationBuffer
79from nnsight import LanguageModel
810from .config import DEBUG
@@ -22,12 +24,21 @@ def loss_recovered(
     How much of the model's loss is recovered by replacing the component output
     with the reconstruction by the autoencoder?
     """
-
+
     if max_len is None:
         invoker_args = {}
     else:
         invoker_args = {"truncation": True, "max_length": max_len}
 
+    with model.trace("_"):
+        temp_output = submodule.output.save()
+
+    output_is_tuple = False
+    # Note: isinstance() won't work here as torch.Size is a subclass of tuple,
+    # so isinstance(temp_output.shape, tuple) would return True even for torch.Size.
+    if type(temp_output.shape) == tuple:
+        output_is_tuple = True
+
     # unmodified logits
     with model.trace(text, invoker_args=invoker_args):
         logits_original = model.output.save()
@@ -36,57 +47,57 @@ def loss_recovered(
     # logits when replacing component activations with reconstruction by autoencoder
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
-            x = submodule.input[0]
-            if type(submodule.input.shape) == tuple: x = x[0]
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x = x * scale
         elif io == 'out':
             x = submodule.output
-            if type(submodule.output.shape) == tuple: x = x[0]
+            if output_is_tuple: x = x[0]
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x = x * scale
         elif io == 'in_and_out':
-            x = submodule.input[0]
-            if type(submodule.input.shape) == tuple: x = x[0]
-            print(f'x.shape: {x.shape}')
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x = x * scale
         else:
             raise ValueError(f"Invalid value for io: {io}")
         x = x.save()
 
-    # pull this out so dictionary can be written without FakeTensor (top_k needs this)
-    x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype)
+    # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here.
+    assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}"
+
+    x_hat = dictionary(x).to(model.dtype)
 
     # intervene with `x_hat`
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
-            x = submodule.input[0]
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            if type(submodule.input.shape) == tuple:
-                submodule.input[0][:] = x_hat
-            else:
-                submodule.input = x_hat
+            submodule.input[:] = x_hat
         elif io == 'out':
             x = submodule.output
+            if output_is_tuple: x = x[0]
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            if type(submodule.output.shape) == tuple:
-                submodule.output = (x_hat,)
+            if output_is_tuple:
+                submodule.output[0][:] = x_hat
             else:
-                submodule.output = x_hat
+                submodule.output[:] = x_hat
         elif io == 'in_and_out':
-            x = submodule.input[0]
+            x = submodule.input
             if normalize_batch:
                 scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                 x_hat = x_hat / scale
-            submodule.output = x_hat
+            if output_is_tuple:
+                submodule.output[0][:] = x_hat
+            else:
+                submodule.output[:] = x_hat
         else:
             raise ValueError(f"Invalid value for io: {io}")
 
@@ -96,22 +107,20 @@ def loss_recovered(
     # logits when replacing component activations with zeros
     with model.trace(text, **tracer_args, invoker_args=invoker_args):
         if io == 'in':
-            x = submodule.input[0]
-            if type(submodule.input.shape) == tuple:
-                submodule.input[0][:] = t.zeros_like(x[0])
-            else:
-                submodule.input = t.zeros_like(x)
+            x = submodule.input
+            submodule.input[:] = t.zeros_like(x)
         elif io in ['out', 'in_and_out']:
             x = submodule.output
-            if type(submodule.output.shape) == tuple:
+            if output_is_tuple:
                 submodule.output[0][:] = t.zeros_like(x[0])
             else:
-                submodule.output = t.zeros_like(x)
+                submodule.output[:] = t.zeros_like(x)
         else:
             raise ValueError(f"Invalid value for io: {io}")
 
-        input = model.input.save()
+        input = model.inputs.save()
         logits_zero = model.output.save()
+
     logits_zero = logits_zero.value
 
     # get everything into the right format
@@ -144,7 +153,7 @@ def loss_recovered(
 
     return tuple(losses)
 
-
+@t.no_grad()
 def evaluate(
     dictionary,  # a dictionary
     activations,  # a generator of activations; if an ActivationBuffer, also compute loss recovered
@@ -154,26 +163,31 @@ def evaluate(
     normalize_batch=False,  # normalize batch before passing through dictionary
     tracer_args={'use_cache': False, 'output_attentions': False},  # minimize cache during model trace.
     device="cpu",
+    n_batches: int = 1,
 ):
-    with t.no_grad():
-
-        out = {}  # dict of results
+    assert n_batches > 0
+    out = defaultdict(float)
+    active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device)
 
+    for _ in range(n_batches):
         try:
             x = next(activations).to(device)
             if normalize_batch:
                 x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)
-
         except StopIteration:
             raise StopIteration(
                 "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
             )
-
         x_hat, f = dictionary(x, output_features=True)
         l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
         l1_loss = f.norm(p=1, dim=-1).mean()
         l0 = (f != 0).float().sum(dim=-1).mean()
-        frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size
+
+        features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32)  # If f is shape (B, L, D), flatten to (B*L, D)
+        assert features_BF.shape[-1] == dictionary.dict_size
+        assert len(features_BF.shape) == 2
+
+        active_features += features_BF.sum(dim=0)
 
         # cosine similarity between x and x_hat
         x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
@@ -193,17 +207,16 @@ def evaluate(
         x_dot_x_hat = (x * x_hat).sum(dim=-1)
         relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean()
 
-        out["l2_loss"] = l2_loss.item()
-        out["l1_loss"] = l1_loss.item()
-        out["l0"] = l0.item()
-        out["frac_alive"] = frac_alive.item()
-        out["frac_variance_explained"] = frac_variance_explained.item()
-        out["cossim"] = cossim.item()
-        out["l2_ratio"] = l2_ratio.item()
-        out['relative_reconstruction_bias'] = relative_reconstruction_bias.item()
+        out["l2_loss"] += l2_loss.item()
+        out["l1_loss"] += l1_loss.item()
+        out["l0"] += l0.item()
+        out["frac_variance_explained"] += frac_variance_explained.item()
+        out["cossim"] += cossim.item()
+        out["l2_ratio"] += l2_ratio.item()
+        out['relative_reconstruction_bias'] += relative_reconstruction_bias.item()
 
         if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)):
-            return out
+            continue
 
         # compute loss recovered
         loss_original, loss_reconstructed, loss_zero = loss_recovered(
@@ -218,9 +231,13 @@ def evaluate(
         )
         frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)
 
-        out["loss_original"] = loss_original.item()
-        out["loss_reconstructed"] = loss_reconstructed.item()
-        out["loss_zero"] = loss_zero.item()
-        out["frac_recovered"] = frac_recovered.item()
+        out["loss_original"] += loss_original.item()
+        out["loss_reconstructed"] += loss_reconstructed.item()
+        out["loss_zero"] += loss_zero.item()
+        out["frac_recovered"] += frac_recovered.item()
+
+    out = {key: value / n_batches for key, value in out.items()}
+    frac_alive = (active_features != 0).float().sum() / dictionary.dict_size
+    out["frac_alive"] = frac_alive.item()
 
-        return out
+    return out
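
For context (not part of the commit itself): a minimal sketch of how the new n_batches argument might be used, assuming a trained dictionary named ae and an ActivationBuffer named buffer (both names are illustrative), with evaluate imported from this module. Only parameters visible in the hunks above are passed.

    # Average each metric over 10 evaluation batches drawn from the buffer.
    # Per-batch values are accumulated in `out`, then divided by n_batches at the end.
    metrics = evaluate(
        ae,            # the sparse dictionary / autoencoder being evaluated
        buffer,        # an ActivationBuffer, so loss-recovered stats are also computed
        device="cuda",
        n_batches=10,
    )
    print(metrics["l0"], metrics["frac_alive"], metrics["frac_recovered"])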