From b8bb4ae91b1b45eb60291a1922acbd6ef277f86d Mon Sep 17 00:00:00 2001
From: Adam Karvonen
Date: Thu, 21 Aug 2025 01:15:22 +0000
Subject: [PATCH 1/2] Add optional mask for high norm tokens

---
 dictionary_learning/buffer.py         | 12 ++++++++++++
 dictionary_learning/pytorch_buffer.py | 12 ++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py
index e60a6dd..0594694 100644
--- a/dictionary_learning/buffer.py
+++ b/dictionary_learning/buffer.py
@@ -15,6 +15,9 @@ class ActivationBuffer:
     """
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
+
+    remove_high_norm: remove all activations with norm greater than median norm * remove_high_norm. 10 is a good default.
+    This is useful for models like Qwen, which have random, unpredictable high-norm activation sinks that reduce training effectiveness.
     """
     def __init__(self,
                  data, # generator which yields text data
@@ -29,6 +32,7 @@ def __init__(self,
                  device='cpu', # device on which to store the activations
                  remove_bos: bool = False,
                  add_special_tokens: bool = True,
+                 remove_high_norm: int | None = None,
                  ):
 
         if io not in ['in', 'out']:
@@ -58,6 +62,7 @@ def __init__(self,
         self.device = device
         self.add_special_tokens = add_special_tokens
         self.remove_bos = remove_bos
+        self.remove_high_norm = remove_high_norm
 
         if remove_bos and self.model.tokenizer.bos_token_id is None:
             print(
@@ -155,6 +160,13 @@ def refresh(self):
                 first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
                 mask = mask & ~first_one
 
+            if self.remove_high_norm is not None:
+                # some models (like Qwen) have random high-norm activation sinks that reduce training effectiveness
+                norms_BL = hidden_states.norm(dim=-1)
+                median_norm = norms_BL.median()
+                norm_mask = norms_BL > median_norm * self.remove_high_norm
+                mask = mask & ~norm_mask
+
             hidden_states = hidden_states[mask]
 
             remaining_space = self.activation_buffer_size - current_idx
diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py
index a31d4c6..6c34a11 100644
--- a/dictionary_learning/pytorch_buffer.py
+++ b/dictionary_learning/pytorch_buffer.py
@@ -73,6 +73,9 @@ class ActivationBuffer:
     """
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
+
+    remove_high_norm: remove all activations with norm greater than median norm * remove_high_norm. 10 is a good default.
+    This is useful for models like Qwen, which have random, unpredictable high-norm activation sinks that reduce training effectiveness.
     """
 
     def __init__(
@@ -89,6 +92,7 @@ def __init__(
         device="cpu",  # device on which to store the activations
         remove_bos: bool = False,
         add_special_tokens: bool = True,
+        remove_high_norm: int | None = None,
    ):
         if io not in ["in", "out"]:
             raise ValueError("io must be either 'in' or 'out'")
@@ -120,6 +124,7 @@ def __init__(
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
         self.remove_bos = remove_bos
+        self.remove_high_norm = remove_high_norm
 
         if remove_bos and self.tokenizer.bos_token_id is None:
             print(
@@ -208,6 +213,13 @@ def refresh(self):
                 first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
                 mask = mask & ~first_one
 
+            if self.remove_high_norm is not None:
+                # some models (like Qwen) have random high-norm activation sinks that reduce training effectiveness
+                norms_BL = hidden_states.norm(dim=-1)
+                median_norm = norms_BL.median()
+                norm_mask = norms_BL > median_norm * self.remove_high_norm
+                mask = mask & ~norm_mask
+
             hidden_states = hidden_states[mask]
 
             remaining_space = self.activation_buffer_size - current_idx

From 1d2737a9ecbe1bc4d3137dec5caab8b320608217 Mon Sep 17 00:00:00 2001
From: Adam Karvonen
Date: Thu, 21 Aug 2025 01:21:00 +0000
Subject: [PATCH 2/2] Improve variable name

---
 dictionary_learning/buffer.py         | 10 +++++-----
 dictionary_learning/pytorch_buffer.py |  6 +++---
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py
index 0594694..7d3c260 100644
--- a/dictionary_learning/buffer.py
+++ b/dictionary_learning/buffer.py
@@ -16,7 +16,7 @@ class ActivationBuffer:
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
 
-    remove_high_norm: remove all activations with norm greater than median norm * remove_high_norm. 10 is a good default.
+    max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
     This is useful for models like Qwen, which have random, unpredictable high-norm activation sinks that reduce training effectiveness.
     """
     def __init__(self,
@@ -31,9 +31,9 @@ def __init__(self,
                  out_batch_size=8192, # size of batches in which to yield activations
                  device='cpu', # device on which to store the activations
                  remove_bos: bool = False,
-                 add_special_tokens: bool = True,
-                 remove_high_norm: int | None = None,
-                 ):
+                 add_special_tokens: bool = True,
+                 max_activation_norm_multiple: int | None = None,
+                 ):
 
         if io not in ['in', 'out']:
             raise ValueError("io must be either 'in' or 'out'")
@@ -62,7 +62,7 @@ def __init__(self,
         self.device = device
         self.add_special_tokens = add_special_tokens
         self.remove_bos = remove_bos
-        self.remove_high_norm = remove_high_norm
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.model.tokenizer.bos_token_id is None:
             print(
diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py
index 6c34a11..3533f83 100644
--- a/dictionary_learning/pytorch_buffer.py
+++ b/dictionary_learning/pytorch_buffer.py
@@ -74,7 +74,7 @@ class ActivationBuffer:
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
 
-    remove_high_norm: remove all activations with norm greater than median norm * remove_high_norm. 10 is a good default.
+    max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
     This is useful for models like Qwen, which have random, unpredictable high-norm activation sinks that reduce training effectiveness.
     """
 
@@ -92,7 +92,7 @@ def __init__(
         device="cpu",  # device on which to store the activations
         remove_bos: bool = False,
         add_special_tokens: bool = True,
-        remove_high_norm: int | None = None,
+        max_activation_norm_multiple: int | None = None,
     ):
         if io not in ["in", "out"]:
             raise ValueError("io must be either 'in' or 'out'")
@@ -124,7 +124,7 @@ def __init__(
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
         self.remove_bos = remove_bos
-        self.remove_high_norm = remove_high_norm
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.tokenizer.bos_token_id is None:
             print(
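A minimal, self-contained sketch of the masking step that patch 1 adds to refresh(): the batch size, sequence length, hidden size, and the injected high-norm "sink" token below are illustrative stand-ins rather than values from the patch, but the norm/median/mask arithmetic mirrors the added code. With the suggested multiple of 10, ordinary tokens pass the filter while the simulated sink token is dropped.

import torch as t

B, L, D = 4, 128, 512                      # illustrative batch size, sequence length, hidden size
hidden_states = t.randn(B, L, D)           # stand-in for the activations gathered in refresh()
hidden_states[0, 0] *= 100                 # simulate one high-norm activation sink token
mask = t.ones(B, L, dtype=t.bool)          # stand-in for the attention/BOS mask built earlier

max_activation_norm_multiple = 10          # the suggested default from the docstring

norms_BL = hidden_states.norm(dim=-1)      # per-token activation norms, shape (B, L)
median_norm = norms_BL.median()            # median norm over all tokens in the batch
norm_mask = norms_BL > median_norm * max_activation_norm_multiple
mask = mask & ~norm_mask                   # drop only the outlier-norm tokens

print(norm_mask.sum().item())              # 1 -> only the injected sink token is flagged
print(hidden_states[mask].shape)           # torch.Size([511, 512]) == (B*L - 1, D)

In the buffer itself, the same effect is obtained by constructing ActivationBuffer with max_activation_norm_multiple=10 (or remove_high_norm=10 before the rename in patch 2).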