dictionary_learning/buffer.py (14 additions & 2 deletions)

@@ -15,6 +15,9 @@ class ActivationBuffer:
"""
Implements a buffer of activations. The buffer stores activations from a model,
yields them in batches, and refreshes them when the buffer is less than half full.

max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
"""
def __init__(self,
data, # generator which yields text data
@@ -28,8 +31,9 @@ def __init__(self,
                  out_batch_size=8192, # size of batches in which to yield activations
                  device='cpu', # device on which to store the activations
                  remove_bos: bool = False,
-                 add_special_tokens: bool = True,
-                 ):
+                 add_special_tokens: bool = True,
+                 max_activation_norm_multiple: int | None = None,
+                 ):
 
         if io not in ['in', 'out']:
             raise ValueError("io must be either 'in' or 'out'")
@@ -58,6 +62,7 @@ def __init__(self,
         self.device = device
         self.add_special_tokens = add_special_tokens
         self.remove_bos = remove_bos
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.model.tokenizer.bos_token_id is None:
             print(
@@ -155,6 +160,13 @@ def refresh(self):
         first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
         mask = mask & ~first_one
 
+        if self.remove_high_norm is not None:
+            # some models (like Qwen) have random high-norm activation sinks that reduce training effectiveness
+            norms_BL = hidden_states.norm(dim=-1)
+            median_norm = norms_BL.median()
+            norm_mask = norms_BL > median_norm * self.remove_high_norm
+            mask = mask & ~norm_mask
+
         hidden_states = hidden_states[mask]
 
         remaining_space = self.activation_buffer_size - current_idx
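For context, a minimal usage sketch of the new flag. Only max_activation_norm_multiple comes from this PR; the Qwen checkpoint, hook point, and d_submodule below are illustrative assumptions about the buffer's existing API, not part of the diff.

from nnsight import LanguageModel
from dictionary_learning.buffer import ActivationBuffer

# Assumed setup: any nnsight-wrapped model works; Qwen is the motivating case.
model = LanguageModel("Qwen/Qwen2.5-0.5B", device_map="cuda")
submodule = model.model.layers[6]        # hypothetical hook point

def text_stream():
    while True:
        yield "sample text for activation collection"  # any text generator works

buffer = ActivationBuffer(
    text_stream(),
    model,
    submodule,
    d_submodule=896,                     # hidden size of the assumed checkpoint
    io="out",
    out_batch_size=8192,
    device="cuda",
    max_activation_norm_multiple=10,     # drop positions with norm > 10x the median
)

acts = next(buffer)                      # a batch of filtered activations

With max_activation_norm_multiple=None (the default), the filter block is skipped and behavior is unchanged.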
dictionary_learning/pytorch_buffer.py (12 additions & 0 deletions)

@@ -73,6 +73,9 @@ class ActivationBuffer:
"""
Implements a buffer of activations. The buffer stores activations from a model,
yields them in batches, and refreshes them when the buffer is less than half full.

max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
"""

def __init__(
@@ -89,6 +92,7 @@ def __init__(
device="cpu", # device on which to store the activations
remove_bos: bool = False,
add_special_tokens: bool = True,
max_activation_norm_multiple: int | None = None,
):
if io not in ["in", "out"]:
raise ValueError("io must be either 'in' or 'out'")
@@ -120,6 +124,7 @@ def __init__(
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
         self.remove_bos = remove_bos
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.tokenizer.bos_token_id is None:
             print(
@@ -208,6 +213,13 @@ def refresh(self):
         first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
         mask = mask & ~first_one
 
+        if self.remove_high_norm is not None:
+            # some models (like Qwen) have random high-norm activation sinks that reduce training effectiveness
+            norms_BL = hidden_states.norm(dim=-1)
+            median_norm = norms_BL.median()
+            norm_mask = norms_BL > median_norm * self.remove_high_norm
+            mask = mask & ~norm_mask
+
         hidden_states = hidden_states[mask]
 
         remaining_space = self.activation_buffer_size - current_idx
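The filtering rule added to refresh() in both files can be sanity-checked in isolation. A self-contained sketch with toy tensors (the shapes and the injected outlier are illustrative, not from the PR):

import torch as t

# Toy activations: [batch, seq_len, d_model], with one injected high-norm "sink"
hidden_states = t.randn(4, 128, 512)
hidden_states[0, 0] *= 100                # simulate an outlier position
mask = t.ones(4, 128, dtype=t.bool)       # positions still considered valid

multiple = 10                             # max_activation_norm_multiple
norms_BL = hidden_states.norm(dim=-1)     # per-position L2 norm, shape [B, L]
median_norm = norms_BL.median()           # median over all positions in the batch
norm_mask = norms_BL > median_norm * multiple
mask = mask & ~norm_mask                  # drop outlier positions from the buffer

print(mask[0, 0].item())                  # False: the injected sink is filtered out
hidden_states = hidden_states[mask]       # flattened to [kept_positions, d_model]

Using the median rather than the mean keeps the threshold stable even when a batch contains several high-norm sinks, since the median is robust to the very outliers being filtered.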