12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -322,6 +322,18 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
GptOssBackbone as GptOssBackbone,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
GptOssCausalLM as GptOssCausalLM,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
HGNetV2Backbone as HGNetV2Backbone,
)
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -47,6 +47,9 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.llama.llama_tokenizer import (
LlamaTokenizer as LlamaTokenizer,
)
19 changes: 19 additions & 0 deletions keras_hub/src/models/gpt_oss/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
from keras_hub.src.models.gpt_oss.gpt_oss_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, GptOssBackbone)
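
Once the presets are registered against `GptOssBackbone`, the standard KerasHub loading path applies. A minimal sketch, assuming a hypothetical preset name (the real identifiers live in `gpt_oss_presets.py`, which is not shown in this diff):

import keras_hub

# Hypothetical preset name for illustration only; see gpt_oss_presets.py
# for the identifiers actually registered by this PR.
backbone = keras_hub.models.GptOssBackbone.from_preset("gpt_oss_20b_en")
causal_lm = keras_hub.models.GptOssCausalLM.from_preset("gpt_oss_20b_en")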
304 changes: 304 additions & 0 deletions keras_hub/src/models/gpt_oss/gpt_oss_attention.py
@@ -0,0 +1,304 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer


class GptOssAttention(keras.layers.Layer):
"""A cached attention layer with sliding window and sink tokens.

This layer implements the attention mechanism described in the GPT-OSS
paper. It includes grouped-query attention, rotary position embeddings,
sliding window attention, and sink tokens for improved performance on
long sequences.

Args:
num_query_heads (int): The number of query attention heads.
num_key_value_heads (int): The number of key and value attention
heads.
rope_max_wavelength (int, optional): The maximum wavelength for the
rotary position embedding. Defaults to 10000.
rope_scaling_factor (float, optional): The scaling factor for the
rotary position embedding. Defaults to 1.0.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
sliding_window (int, optional): The size of the sliding window.
Defaults to 4096.
dropout (float, optional): The dropout rate. Defaults to 0.
"""

def __init__(
self,
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
sliding_window=4096,
dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.sliding_window = sliding_window
self.dropout = dropout

self.num_key_value_groups = num_query_heads // num_key_value_heads
self.rope_max_wavelength = rope_max_wavelength

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self.rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
# u = num query heads
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
        # GPT-OSS fixes the head dimension at 64 (matching the HuggingFace
        # checkpoints) rather than deriving it from
        # hidden_dim // num_query_heads.
        self._head_dim = 64
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

# Calculate rotary dimension -
# use the largest even number <= head_dim
self._rotary_dim = (self._head_dim // 2) * 2

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
self.value_dense.build(inputs_shape)

self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self._hidden_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self.output_dense.build(
(None, None, self.num_query_heads, self._head_dim)
)

self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)

self.sinks = self.add_weight(
shape=(self.num_query_heads,),
initializer="random_normal",
dtype=self.dtype,
name="sinks",
)

self._dot_product_equation = "bquh,bkuh->buqk"
self._combine_equation = "buqk,bkuh->bquh"

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
start_index = (
cache_update_index if cache_update_index is not None else 0
)

query = self.query_dense(hidden_states)

# Compute RoPE for queries (only apply to first _rotary_dim dimensions)
if self._rotary_dim < self._head_dim:
query_rot = query[..., : self._rotary_dim]
query_rot = self.rotary_embedding_layer(
query_rot, start_index=start_index
)
query = ops.concatenate(
[query_rot, query[..., self._rotary_dim :]], axis=-1
)
else:
query = self.rotary_embedding_layer(query, start_index=start_index)

def _compute_key_value(x):
key, value = self.key_dense(x), self.value_dense(x)
# Compute RoPE for keys (only apply to first _rotary_dim dimensions)
if self._rotary_dim < self._head_dim:
key_rot = key[..., : self._rotary_dim]
key_rot = self.rotary_embedding_layer(
key_rot, start_index=start_index
)
key = ops.concatenate(
[key_rot, key[..., self._rotary_dim :]], axis=-1
)
else:
key = self.rotary_embedding_layer(key, start_index=start_index)
return key, value

if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
key, value = _compute_key_value(hidden_states)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum(self._dot_product_equation, query, key)
attention_scores = ops.multiply(
attention_scores,
ops.cast(self._inv_norm_factor, self.compute_dtype),
)

if attention_mask is not None:
            # The mask is a boolean tensor, True at positions to be masked.
            # Replace scores at those positions with a large negative value
            # (scaled to the compute dtype) so they vanish after the softmax.
if self.compute_dtype == "float32":
adder = ops.cast(-1e9, self.compute_dtype)
else:
adder = ops.cast(-1e4, self.compute_dtype)
attention_scores = ops.where(
attention_mask[:, None, :, :], adder, attention_scores
)

# Handle sink tokens by concatenating them to the logits.
b = ops.shape(query)[0]
q = ops.shape(query)[1]
sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
# attention_scores shape: [b, num_heads, q, k]
# sinks shape: [b, num_heads, q, 1]
# We need to concatenate along the last dimension
combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)

        # Subtract the row-wise max for numerical stability before softmax.
max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
max_logits = ops.stop_gradient(max_logits)
combined_logits = combined_logits - max_logits

probs = ops.softmax(combined_logits, axis=-1)

# Remove the sink probabilities before computing the output.
attention_scores = probs[..., :-1]
attention_scores = ops.cast(attention_scores, self.compute_dtype)

attention_output = ops.einsum(
self._combine_equation, attention_scores, value
)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self._kernel_initializer
),
"sliding_window": self.sliding_window,
"dropout": self.dropout,
}
)
return config
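
As a reviewer aid (not part of the diff), here is a minimal sketch of the sink-token mechanism that `_compute_attention` implements: a learnable per-head logit is appended as one extra "virtual" key column before the softmax and dropped afterwards, so each query can shed probability mass onto the sink instead of spreading it across real tokens. Shapes and values below are toy assumptions.

import numpy as np
from keras import ops

batch, heads, q_len, k_len = 1, 2, 3, 4
scores = ops.convert_to_tensor(
    np.random.randn(batch, heads, q_len, k_len).astype("float32")
)
sinks = ops.zeros((heads,))  # stands in for the learnable `sinks` weight

# Broadcast the per-head sink logit to [batch, heads, q_len, 1] and append
# it as an extra key position, mirroring the concatenation in the layer.
sink_col = ops.broadcast_to(
    ops.reshape(sinks, (1, heads, 1, 1)), (batch, heads, q_len, 1)
)
combined = ops.concatenate([scores, sink_col], axis=-1)

# Max-subtraction for numerical stability, softmax over real keys plus the
# sink, then drop the sink column before combining with the values.
combined = combined - ops.max(combined, axis=-1, keepdims=True)
probs = ops.softmax(combined, axis=-1)
attn_weights = probs[..., :-1]

# Each row now sums to less than 1; the remainder was absorbed by the sink.
print(ops.sum(attn_weights, axis=-1))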