adamkarvonen (Collaborator)
A handful of updates here:

The first commit adds an end-to-end test that trains a Standard and a TopK SAE on Pythia for 10M tokens (it takes ~2 minutes on an RTX 3090), evaluates them, and asserts that all eval_results values are within a tolerance of 0.01. The end-to-end test is fully deterministic, so I could check for exact equality instead, but that comes with the risk of failures due to very minor changes in dependencies.
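As a rough sketch of what the tolerance check looks like (the function and dict names here are illustrative, not the actual test code), assuming eval_results is a flat dict mapping metric names to floats and expected holds the stored reference values:

    EVAL_TOLERANCE = 0.01

    def check_eval_results(eval_results: dict[str, float], expected: dict[str, float]) -> None:
        # Compare every evaluation metric against its stored reference value
        # within an absolute tolerance, so tiny dependency-level changes
        # don't break the test.
        for key, expected_value in expected.items():
            actual_value = eval_results[key]
            assert abs(actual_value - expected_value) < EVAL_TOLERANCE, (
                f"{key}: {actual_value} differs from {expected_value} "
                f"by more than {EVAL_TOLERANCE}"
            )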

The next two commits upgrade for compatibility with nnsight 0.3.0. First, I renamed .input to .inputs, since that was a breaking change in nnsight 0.3.0.
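For illustration, the rename amounts to something like this inside a trace (simplified sketch; model, submodule, and text are stand-ins for the actual objects in the codebase):

    with model.trace(text):
        # nnsight 0.2.x:
        # saved_input = submodule.input.save()
        # nnsight 0.3.0:
        saved_input = submodule.inputs.save()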

Second, this if statement in evaluate: `if type(submodule.input.shape) == tuple: x = x[0]` started raising errors once I upgraded nnsight. I replaced it with the following, which works correctly whether or not the submodule outputs a tuple:

    with model.trace("_"):
        temp_output = submodule.output.save()

    output_is_tuple = False
    # Note: isinstance() won't work here as torch.Size is a subclass of tuple,
    # so isinstance(temp_output.shape, tuple) would return True even for torch.Size.
    if type(temp_output.shape) == tuple:
        output_is_tuple = True

In the fourth commit, I first fixed the frac_alive calculation. The old implementation called t.flatten(f, start_dim=0, end_dim=1).any(dim=0) on the feature matrix of shape (batch_size, dict_size), which flattened the entire matrix into a 1D vector before calling any(). This produced a single scalar boolean indicating whether any feature was active anywhere in the batch, rather than tracking per-feature activity. The fix sums across the batch dimension and then checks for non-zero activation, giving accurate per-feature activity via (active_features != 0).float().sum() / dictionary.dict_size.

Before, frac_alive was always 1/dict_size. Now it behaves sensibly, e.g. increasing with the number of datapoints we evaluate on.
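A minimal sketch of the corrected computation (variable names illustrative; f is the feature activation matrix):

    import torch as t

    def frac_alive(f: t.Tensor, dict_size: int) -> t.Tensor:
        # f has shape (batch_size, dict_size).
        # Old (buggy): t.flatten(f, start_dim=0, end_dim=1).any(dim=0) collapsed
        # the whole matrix to a single boolean, so the fraction was always 1/dict_size.
        # New: sum over the batch dimension to get per-feature totals, then count
        # how many features fired at least once.
        active_features = f.sum(dim=0)  # shape: (dict_size,)
        return (active_features != 0).float().sum() / dict_size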

In the fourth and fifth commits, I also added the ability to evaluate SAEs over multiple batches of data. Behavior is completely unchanged with the default n_batches = 1, but n_batches can be increased for a more accurate evaluation. I increased the evaluation size for the end-to-end test so the estimate is less subject to random seeds.
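The aggregation is conceptually just an average over batches, roughly like the sketch below (signature simplified; the real evaluate() takes more arguments, and evaluate_single_batch is a hypothetical helper standing in for the per-batch metric computation):

    from collections import defaultdict

    def evaluate_multi_batch(dictionary, activation_buffer, n_batches: int = 1) -> dict:
        # Accumulate each metric over n_batches batches, then average.
        totals = defaultdict(float)
        for _ in range(n_batches):
            batch_metrics = evaluate_single_batch(dictionary, activation_buffer)  # hypothetical helper
            for key, value in batch_metrics.items():
                totals[key] += value
        return {key: value / n_batches for key, value in totals.items()}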

In the sixth commit, I added early stopping (introduced in nnsight 0.3.0) when refreshing the activation buffer, for faster SAE training.
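The idea: once the target submodule's output has been saved during a buffer refresh, the rest of the forward pass is wasted work and can be cut short. A simplified sketch of this, assuming nnsight 0.3's .stop() call on the output proxy (not the exact buffer code, which also passes tokenization and tracer kwargs):

    import torch as t

    with t.no_grad(), model.trace(text):
        hidden_states = submodule.output.save()
        # Early stopping (nnsight 0.3.0): halt the forward pass once this
        # submodule's output is available, skipping all later layers.
        submodule.output.stop()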

In the seventh commit, I changed save_steps to an Optional list of ints instead of an Optional int. With save_steps as a single int, we could only take linearly spaced checkpoints, but we often care about the evolution of the SAE at log-spaced checkpoints. Making save_steps a list of ints gives us the flexibility to do both.
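For example, the save_steps list for a run can now be built either way (numbers illustrative):

    total_steps = 30_000
    n_checkpoints = 10

    # Log-spaced checkpoints from step 1 to total_steps, useful for studying
    # early training dynamics; deduplicated and sorted.
    log_spaced = sorted({int(round(total_steps ** (i / (n_checkpoints - 1)))) for i in range(n_checkpoints)})

    # Linearly spaced checkpoints, matching the old behavior.
    linear_spaced = list(range(0, total_steps + 1, total_steps // n_checkpoints))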

For every change, I verified that the end-to-end test passed. I did notice minor changes in some values when upgrading from nnsight 0.2 to 0.3, such as l0 going from 61.12999725341797 to 61.29999923706055. The initial end-to-end test used a relatively small amount of data, so very minor differences could account for this.

adamkarvonen merged commit c4eed3c into saprmarks:main on Dec 18, 2024