
Conversation

adamkarvonen
Collaborator

Changes

1. Activation Normalization (Optional)

  • Added optional activation normalization following the Gemma Scope paper's methodology
  • Before training, finds a fixed scalar that rescales activations to unit mean squared norm
  • Disabled by default

The scaling factors are folded into the SAE's thresholds and biases when the model is saved, so no normalization is needed at inference time. This significantly improves hyperparameter transfer across layers and models; JumpReLU in particular requires substantial hyperparameter tuning without it.
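As a minimal sketch of the normalization step (assumed names; `activation_batches` is a hypothetical iterable of activation tensors, not the repo's API), the fixed scalar can be estimated as the root mean squared norm of the activations, so that dividing by it gives unit mean squared norm:

```python
import torch

@torch.no_grad()
def estimate_norm_scale(activation_batches) -> float:
    """Return a scalar s such that (acts / s) has unit mean squared L2 norm."""
    total_sq_norm, count = 0.0, 0
    for acts in activation_batches:                  # acts: (batch, d_model)
        total_sq_norm += acts.pow(2).sum(dim=-1).sum().item()
        count += acts.shape[0]
    mean_sq_norm = total_sq_norm / count
    return mean_sq_norm ** 0.5                       # divide activations by this scalar
```

Because the scalar is fixed, it can be folded into the saved SAE's biases and thresholds afterward, which is why no normalization is required at inference time.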

2. Global Threshold Implementation for BatchTopK SAE

  • Implemented a fixed global threshold to remove inter-input dependencies
  • Tracks the average minimum non-zero activation value during training
  • Uses this global threshold by default in the encode() method

I also track this global threshold for the TopK SAE during training, but I don't use it by default in its encode() method. Using a global threshold provides several benefits: it achieves better loss recovered than standard TopK, eliminates feature dependencies within an input, and enables running a forward pass on a limited subset of SAE latents (useful for steering, autointerp, etc.).
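A minimal sketch of the idea, with assumed names (`decay`, `sae.W_enc`, `sae.b_enc` are illustrative, not the repo's exact trainer code): during training, keep a running average of the smallest non-zero post-TopK activation, then at inference replace the batch-dependent top-k selection with that fixed threshold.

```python
import torch

class GlobalThresholdTracker:
    """Running average of the minimum non-zero activation per training step."""

    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.threshold = None

    @torch.no_grad()
    def update(self, topk_acts: torch.Tensor) -> None:
        # topk_acts: post-(Batch)TopK feature activations, zeros outside the top-k
        nonzero = topk_acts[topk_acts > 0]
        if nonzero.numel() == 0:
            return
        min_act = nonzero.min().item()
        if self.threshold is None:
            self.threshold = min_act
        else:
            self.threshold = self.decay * self.threshold + (1 - self.decay) * min_act

@torch.no_grad()
def encode_with_global_threshold(sae, x: torch.Tensor, threshold: float) -> torch.Tensor:
    # Thresholding is applied per latent, so each feature fires independently
    # of the other inputs in the batch (no inter-input dependency).
    pre_acts = torch.relu(x @ sae.W_enc + sae.b_enc)
    return pre_acts * (pre_acts > threshold)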

3. Standard SAE Improvements

For the standard SAE, W_dec is now initialized as W_enc.T. I also switched the standard and p-annealing trainers to the correct reconstruction loss, which was already used in all other trainers. The initialization change provided a major benefit and the reconstruction loss fix a minor one.
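A minimal sketch of the two fixes (assumed shapes and names; the exact reconstruction loss form, summed squared error per sample averaged over the batch, is an assumption based on "the loss used by the other trainers", not a quote of the repo's code):

```python
import torch
import torch.nn as nn

class StandardSAE(nn.Module):
    def __init__(self, d_model: int, dict_size: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.empty(d_model, dict_size))
        nn.init.kaiming_uniform_(self.W_enc)
        # Fix 1: initialize the decoder as the transpose of the encoder.
        self.W_dec = nn.Parameter(self.W_enc.data.T.clone())
        self.b_enc = nn.Parameter(torch.zeros(dict_size))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor):
        f = torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        x_hat = f @ self.W_dec + self.b_dec
        return x_hat, f

def reconstruction_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    # Fix 2 (assumed form): squared error summed over the model dimension,
    # averaged over the batch, matching the convention of the other trainers.
    return (x - x_hat).pow(2).sum(dim=-1).mean()
```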

@adamkarvonen adamkarvonen merged commit 67a7857 into main Dec 26, 2024