
@TDulka commented Nov 29, 2024

Implements a BatchTopKToJump class which trains like a BatchTopK SAE but switches to JumpReLU at inference time.

The JumpReLU thresholds are taken from the minimum nonzero latent activations observed during the last part of training (the last 10% by default). This results in slightly more active latents at inference time (around 106 with k=100 in my experiments).
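
To make the mechanism concrete, here is a minimal sketch of the idea as described above. The class and method names are mine, not the PR's, and details (encoder form, update schedule) are assumptions; the actual implementation is in the linked notebook.

```python
import torch
import torch.nn as nn

class BatchTopKToJumpSketch(nn.Module):
    """Sketch: trains with BatchTopK, infers with JumpReLU (names are illustrative)."""

    def __init__(self, d_model: int, d_hidden: int, k: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_hidden) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.W_dec = nn.Parameter(torch.randn(d_hidden, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.k = k
        # Per-latent JumpReLU thresholds, estimated near the end of training.
        self.register_buffer("thresholds", torch.full((d_hidden,), float("inf")))
        self.use_jumprelu = False  # flip to True at inference time

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        acts = torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        if self.use_jumprelu:
            # JumpReLU inference: keep activations above the per-latent threshold.
            return acts * (acts > self.thresholds)
        # BatchTopK training: keep the k * batch_size largest activations
        # across the whole batch, zero out the rest.
        flat = acts.flatten()
        topk = torch.topk(flat, self.k * acts.shape[0])
        mask = torch.zeros_like(flat)
        mask[topk.indices] = 1.0
        return (flat * mask).view_as(acts)

    @torch.no_grad()
    def update_thresholds(self, acts: torch.Tensor) -> None:
        # Call on the post-TopK activations during the last ~10% of training
        # steps: track the minimum nonzero activation seen for each latent.
        masked = torch.where(acts > 0, acts, torch.full_like(acts, float("inf")))
        self.thresholds = torch.minimum(self.thresholds, masked.min(dim=0).values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode(x) @ self.W_dec + self.b_dec
```

Since each threshold is the smallest activation that survived the BatchTopK selection late in training, any activation above it fires at inference, which slightly loosens sparsity and is consistent with the ~106 active latents at k=100 reported above.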

Here is a notebook that runs the training and some simple evaluation, comparing BatchTopKToJump run in BatchTopK mode and in JumpReLU mode against a classical JumpReLU and a classical TopK approach (the evaluation could probably be done better).
https://colab.research.google.com/drive/1GuFaBmbVvM-rQoWjgMTZAHDxl76xaE1G?usp=sharing

Here are the three wandb runs (I ran more while iterating, but kept just these three cleaner runs for comparison).
https://wandb.ai/tomasdulka/batchtopk_jumprelu

@adamkarvonen (Collaborator) commented

This was implemented with a single global threshold in this PR: #31
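
For contrast, a global threshold replaces the per-latent vector with one scalar shared across all latents. A minimal sketch of that variant under the same tracking idea as above; this is an assumption for illustration, not necessarily how #31 computes it.

```python
import torch

@torch.no_grad()
def update_global_threshold(threshold: torch.Tensor, acts: torch.Tensor) -> torch.Tensor:
    # Track a single scalar over all latents rather than one value per latent.
    nonzero = acts[acts > 0]
    if nonzero.numel() == 0:
        return threshold
    return torch.minimum(threshold, nonzero.min())
```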
