Skip to content

Commit 0a15908

Browse files
committed
add logging
1 parent 3575636 commit 0a15908

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def __init__(
7777
output_dim=hidden_dim,
7878
name="token_embedding",
7979
)
80+
81+
self.rotary_embedding = SmolLM3RotaryEmbedding(
82+
hidden_size=hidden_dim,
83+
num_attention_heads=num_attention_heads,
84+
max_position_embeddings=max_position_embeddings,
85+
rope_theta=rope_theta,
86+
partial_rotary_factor=partial_rotary_factor,
87+
)
88+
8089
self.transformer_layers = []
8190
for i in range(num_layers):
8291
layer = SmolLM3DecoderLayer(
@@ -100,14 +109,6 @@ def __init__(
100109
name="sequence_output_layernorm",
101110
)
102111

103-
self.rotary_embedding = SmolLM3RotaryEmbedding(
104-
hidden_size=hidden_dim,
105-
num_attention_heads=num_attention_heads,
106-
max_position_embeddings=max_position_embeddings,
107-
rope_theta=rope_theta,
108-
partial_rotary_factor=partial_rotary_factor,
109-
)
110-
111112
# === Functional Model ===
112113
token_id_input = keras.Input(
113114
shape=(None,), dtype="int32", name="token_ids"

keras_hub/src/models/smollm3/smollm3_causal_lm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def call_with_cache(
7070
x = self.backbone.token_embedding(token_ids)
7171

7272
# Each decoder layer has a cache; we update them separately.
73-
7473
updated_cache = []
74+
position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index)
7575
for i in range(self.backbone.num_layers):
76-
position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index)
7776
current_cache = cache[:, i, ...]
77+
print(x.shape)
7878
x, next_cache = self.backbone.transformer_layers[i](
7979
x,
8080
position_embeddings=position_embeddings,
@@ -103,9 +103,8 @@ def _build_cache(self, token_ids):
103103
head_dim,
104104
]
105105
cache = ops.zeros(shape, dtype=self.compute_dtype)
106-
index = ops.convert_to_tensor(0, dtype="int32")
107106
# Seed the cache.
108-
_, hidden_states, cache = self.call_with_cache(token_ids, cache, index)
107+
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
109108
return hidden_states, cache
110109

111110
def generate_step(

0 commit comments

Comments (0)