diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py
index e60a6dd..7d3c260 100644
--- a/dictionary_learning/buffer.py
+++ b/dictionary_learning/buffer.py
@@ -15,6 +15,9 @@ class ActivationBuffer:
     """
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
+
+    max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
+    This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
     """
     def __init__(self,
                  data, # generator which yields text data
@@ -28,8 +31,9 @@ def __init__(self,
                  out_batch_size=8192, # size of batches in which to yield activations
                  device='cpu', # device on which to store the activations
                  remove_bos: bool = False,
-                 add_special_tokens: bool = True,
-                 ):
+                 add_special_tokens: bool = True,
+                 max_activation_norm_multiple: int | None = None,
+                 ):
 
         if io not in ['in', 'out']:
             raise ValueError("io must be either 'in' or 'out'")
@@ -58,6 +62,7 @@ def __init__(self,
         self.device = device
         self.add_special_tokens = add_special_tokens
         self.remove_bos = remove_bos
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.model.tokenizer.bos_token_id is None:
             print(
@@ -155,6 +160,13 @@ def refresh(self):
                 first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
                 mask = mask & ~first_one
 
+            if self.remove_high_norm is not None:
+                # some models (like Qwen) have random high norm activation sinks which reduce training effectiveness
+                norms_BL = hidden_states.norm(dim=-1)
+                median_norm = norms_BL.median()
+                norm_mask = norms_BL > median_norm * self.remove_high_norm
+                mask = mask & ~norm_mask
+
             hidden_states = hidden_states[mask]
 
             remaining_space = self.activation_buffer_size - current_idx
diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py
index a31d4c6..3533f83 100644
--- a/dictionary_learning/pytorch_buffer.py
+++ b/dictionary_learning/pytorch_buffer.py
@@ -73,6 +73,9 @@ class ActivationBuffer:
     """
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
+
+    max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
+    This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
     """
 
     def __init__(
@@ -89,6 +92,7 @@ def __init__(
         device="cpu", # device on which to store the activations
         remove_bos: bool = False,
         add_special_tokens: bool = True,
+        max_activation_norm_multiple: int | None = None,
     ):
         if io not in ["in", "out"]:
             raise ValueError("io must be either 'in' or 'out'")
@@ -120,6 +124,7 @@ def __init__(
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
         self.remove_bos = remove_bos
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.tokenizer.bos_token_id is None:
             print(
@@ -208,6 +213,13 @@ def refresh(self):
                 first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
                 mask = mask & ~first_one
 
+            if self.remove_high_norm is not None:
+                # some models (like Qwen) have random high norm activation sinks which reduce training effectiveness
+                norms_BL = hidden_states.norm(dim=-1)
+                median_norm = norms_BL.median()
+                norm_mask = norms_BL > median_norm * self.remove_high_norm
+                mask = mask & ~norm_mask
+
             hidden_states = hidden_states[mask]
 
             remaining_space = self.activation_buffer_size - current_idx