Add end to end test, upgrade nnsight to support 0.3.0, fix bugs #30
+371 −68
A handful of updates here:
The first commit adds an end to end test that trains a Standard and a TopK SAE on Pythia for 10M tokens (it takes ~2 minutes on an RTX 3090), evaluates them, and asserts that every eval_results value is within a tolerance of 0.01 of the expected value. The end to end test is fully deterministic, so I could check for exact equality instead, but that comes with the risk of failures due to very minor changes in dependencies.
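For reference, a minimal sketch of the kind of tolerance check the test performs, using pytest.approx; the metric names and reference values here are illustrative, not the actual numbers from the test:

```python
import pytest

# Hypothetical reference values; the real test stores its own expected eval_results.
EXPECTED_EVAL_RESULTS = {
    "l0": 61.13,
    "frac_variance_explained": 0.94,
    "frac_alive": 0.72,
}

def check_eval_results(eval_results: dict, tolerance: float = 0.01) -> None:
    """Assert every metric is within `tolerance` of its reference value."""
    for key, expected in EXPECTED_EVAL_RESULTS.items():
        assert eval_results[key] == pytest.approx(expected, abs=tolerance), (
            f"{key}: got {eval_results[key]}, expected {expected} +/- {tolerance}"
        )
```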
The next two commits upgrade for compatibility with nnsight 0.3.0. First, I rename .input to .inputs, as that was a breaking change in nnsight 0.3.0.
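A minimal before/after illustration of the rename inside an nnsight trace (the surrounding code is simplified; the actual call sites in the repo differ):

```python
with model.trace(tokens):
    # nnsight 0.2.x exposed the submodule's input tuple as `.input`:
    #   saved_inputs = submodule.input.save()
    # nnsight 0.3.0 renames that attribute to `.inputs`:
    saved_inputs = submodule.inputs.save()
```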
Second, this if statement in `evaluate`:

`if type(submodule.input.shape) == tuple: x = x[0]`

was raising errors once I upgraded nnsight. I replaced it with a check that correctly works for both submodules that do and do not output tuples; a rough sketch of that kind of check is below.
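For illustration, a minimal sketch of the kind of check that handles both cases (the helper name is hypothetical, and the actual change inside evaluate may differ):

```python
def unpack_activation(x):
    """Return the hidden-state tensor whether the submodule yields a tuple or a bare tensor."""
    if isinstance(x, tuple):
        return x[0]  # e.g. (hidden_states, ...) from transformer blocks
    return x
```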
In the fourth commit, I first fix the frac_alive calculation. The old implementation called `t.flatten(f, start_dim=0, end_dim=1).any(dim=0)` on the feature matrix of shape (batch_size, dict_size), which flattened the entire matrix into a 1D vector before calling `any()`. That produced a single scalar boolean indicating whether ANY feature was active in the batch, rather than per-feature activity, so frac_alive was always 1/dict_size. The fix sums across the batch dimension and then checks for non-zero totals, giving accurate per-feature activity via `(active_features != 0).float().sum() / dictionary.dict_size`. frac_alive now behaves sensibly, for example increasing with the number of datapoints we evaluate on.
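A minimal sketch of the corrected per-feature bookkeeping, assuming f is the (batch_size, dict_size) feature activation matrix (variable and function names are illustrative):

```python
import torch as t

def frac_alive(f: t.Tensor, dict_size: int) -> t.Tensor:
    """Fraction of dictionary features that fire at least once in the batch."""
    # Old (buggy) version: t.flatten(f, start_dim=0, end_dim=1).any(dim=0)
    # collapses the whole matrix to one boolean, so the metric was always 1/dict_size.
    active_features = f.sum(dim=0)  # shape (dict_size,): total activation per feature
    return (active_features != 0).float().sum() / dict_size
```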
In the fourth and fifth commits, I also added the ability to evaluate SAEs over multiple batches of data. Behavior is completely unchanged with the default n_batches = 1, but n_batches can be increased for a more accurate evaluation. I increased the evaluation size in the end to end test so the expected values are less subject to random seeds.
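As a rough sketch of how metrics can be aggregated over several batches (the function signature and names here are illustrative, not the repository's exact API):

```python
from collections import defaultdict

def evaluate_over_batches(evaluate_one_batch, activation_buffer, n_batches: int = 1) -> dict:
    """Average per-batch eval metrics over n_batches batches drawn from the buffer."""
    totals = defaultdict(float)
    for _ in range(n_batches):
        batch_metrics = evaluate_one_batch(next(activation_buffer))
        for key, value in batch_metrics.items():
            totals[key] += value
    # With the default n_batches=1 this reduces to a single-batch evaluation.
    return {key: value / n_batches for key, value in totals.items()}
```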
In the sixth commit, I added early stopping (introduced in nnsight 0.3.0) when refreshing the activation buffer: the forward pass halts once the needed activations have been gathered, which speeds up SAE training.
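A hedged sketch of what early stopping during a buffer refresh can look like in nnsight 0.3 (the model and submodule are placeholders, and the exact call used in the buffer may differ):

```python
with model.trace(tokens):
    # Save the activations the buffer needs...
    hidden_states = submodule.output.save()
    # ...then stop the trace early so layers after `submodule` never run.
    submodule.output.stop()
```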
In the seventh commit, I changed save_steps to an Optional list of ints instead of an Optional int. With save_steps as an int, we could only take linearly spaced checkpoints, but we often care about the evolution of the SAE at log-spaced checkpoints. Making save_steps a list of ints gives us the flexibility to do both.
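For example, log-spaced checkpoint steps can be generated like this (a small illustration, not code from the repo):

```python
import numpy as np

total_steps = 100_000   # hypothetical number of training steps
n_checkpoints = 10

# Log-spaced save_steps, deduplicated and sorted, e.g. [1, 3, 12, ..., 100000].
save_steps = sorted({int(s) for s in np.logspace(0, np.log10(total_steps), n_checkpoints)})

# Linearly spaced checkpoints remain a one-liner:
# save_steps = list(range(0, total_steps + 1, total_steps // n_checkpoints))
```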
For every change, I verified that the end to end test passed. I did notice some minor value changes when upgrading from nnsight 0.2 to 0.3, such as l0 going from 61.12999725341797 to 61.29999923706055. The initial end to end test used a relatively small amount of data, so very minor numerical differences could account for this.