-
Notifications
You must be signed in to change notification settings - Fork 296
Added Voxtral backbone and it's test #2394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sanskarmodi8
wants to merge
3
commits into
keras-team:master
Choose a base branch
from
sanskarmodi8:adding_voxtral/pr#1
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,352 @@ | ||
from keras import Input | ||
from keras import initializers | ||
from keras import layers | ||
from keras import mixed_precision | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
|
||
def voxtral_kernel_initializer(stddev=0.02): | ||
""" | ||
Create a TruncatedNormal initializer for VoxTral layers. | ||
|
||
Args: | ||
stddev (float): Standard deviation of the truncated normal distribution. | ||
|
||
Returns: | ||
keras.initializers.Initializer: Truncated normal initializer. | ||
""" | ||
return initializers.TruncatedNormal(stddev=stddev) | ||
|
||
|
||
class ChunkAndPad(layers.Layer): | ||
""" | ||
Pads and splits an input spectrogram into fixed-length chunks. | ||
|
||
This layer ensures the time axis of the input is divisible by | ||
`frames_per_chunk` by padding zeros, then reshapes into | ||
`(batch * n_chunks, frames_per_chunk, features)`. | ||
|
||
Args: | ||
frames_per_chunk (int): Number of frames per chunk. | ||
""" | ||
|
||
def __init__(self, frames_per_chunk, **kwargs): | ||
super().__init__(**kwargs) | ||
self.frames_per_chunk = int(frames_per_chunk) | ||
|
||
def call(self, x): | ||
"""Pad and chunk the input tensor along time dimension.""" | ||
B, T = ops.shape(x)[0], ops.shape(x)[1] | ||
pad_len = (-T) % self.frames_per_chunk | ||
x = ops.pad(x, [[0, 0], [0, pad_len], [0, 0]]) | ||
n_chunks = ops.floor_divide(T + pad_len, self.frames_per_chunk) | ||
return ops.reshape( | ||
x, [B * n_chunks, self.frames_per_chunk, ops.shape(x)[2]] | ||
) | ||
|
||
def compute_output_shape(self, input_shape): | ||
""" | ||
Compute static output shape for Keras/JAX backends. | ||
|
||
Args: | ||
input_shape (tuple): (batch, time, features). | ||
|
||
Returns: | ||
tuple: (batch * n_chunks, frames_per_chunk, features) | ||
""" | ||
batch, time, feat = input_shape | ||
if time is None: | ||
n_chunks = None | ||
else: | ||
import math | ||
|
||
n_chunks = math.ceil(time / self.frames_per_chunk) | ||
return ( | ||
None if batch is None else batch * n_chunks, | ||
self.frames_per_chunk, | ||
feat, | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update({"frames_per_chunk": self.frames_per_chunk}) | ||
return config | ||
|
||
|
||
class PositionalEmbedding(layers.Layer): | ||
""" | ||
Learnable positional embedding added to each time step in a chunk. | ||
|
||
Args: | ||
length (int): Sequence length of each chunk | ||
(frames per chunk post-conv). | ||
dim (int): Embedding dimension. | ||
""" | ||
|
||
def __init__(self, length, dim, **kwargs): | ||
super().__init__(**kwargs) | ||
self.length = int(length) | ||
self.dim = int(dim) | ||
|
||
def build(self, input_shape): | ||
"""Create the embedding weights.""" | ||
self.pos_emb = self.add_weight( | ||
name="pos_emb", | ||
shape=(self.length, self.dim), | ||
initializer=initializers.RandomNormal(stddev=0.02), | ||
trainable=True, | ||
dtype=self.compute_dtype, | ||
) | ||
super().build(input_shape) | ||
|
||
def call(self, x): | ||
""" | ||
Add the positional embedding to the input tensor. | ||
|
||
Args: | ||
x (Tensor): Input tensor of shape (batch, time, dim). | ||
|
||
Returns: | ||
Tensor: Input tensor with positional embedding added. | ||
""" | ||
return x + ops.cast(self.pos_emb[None, :, :], x.dtype) | ||
|
||
def compute_output_shape(self, input_shape): | ||
"""Return same shape as input.""" | ||
return input_shape | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update({"length": self.length, "dim": self.dim}) | ||
return config | ||
|
||
|
||
class ReassembleChunks(layers.Layer): | ||
""" | ||
Reassembles chunked outputs back into `(batch, time, hidden_dim)`. | ||
|
||
Args: | ||
frames_per_chunk (int): Frames per chunk pre-conv. | ||
postproc_chunk_len (int, optional): Chunk length after processing. | ||
""" | ||
|
||
def __init__(self, frames_per_chunk, postproc_chunk_len=None, **kwargs): | ||
super().__init__(**kwargs) | ||
self.frames_per_chunk = int(frames_per_chunk) | ||
self.postproc_chunk_len = ( | ||
None if postproc_chunk_len is None else int(postproc_chunk_len) | ||
) | ||
|
||
def call(self, processed_chunks, orig_spectrogram): | ||
""" | ||
Reassemble processed chunks into a continuous time sequence. | ||
|
||
Args: | ||
processed_chunks (Tensor): Output of transformer layers | ||
of shape (B*n_chunks, T_chunk, H). | ||
orig_spectrogram (Tensor): Original input spectrogram | ||
of shape (B, T, F). | ||
|
||
Returns: | ||
Tensor: Reassembled tensor of shape (B, T', H). | ||
""" | ||
B, T = ops.shape(orig_spectrogram)[0], ops.shape(orig_spectrogram)[1] | ||
n_chunks = ops.cast( | ||
ops.floor_divide( | ||
T + self.frames_per_chunk - 1, self.frames_per_chunk | ||
), | ||
"int32", | ||
) | ||
T_chunk, H = ( | ||
ops.shape(processed_chunks)[1], | ||
ops.shape(processed_chunks)[2], | ||
) | ||
return ops.reshape(processed_chunks, [B, n_chunks * T_chunk, H]) | ||
|
||
def compute_output_shape(self, input_shape): | ||
"""Return shape compatible with a single long sequence.""" | ||
return input_shape | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"frames_per_chunk": self.frames_per_chunk, | ||
"postproc_chunk_len": self.postproc_chunk_len, | ||
} | ||
) | ||
return config | ||
|
||
|
||
@keras_hub_export("keras_hub.models.VoxTralBackbone") | ||
class VoxTralBackbone(Backbone): | ||
""" | ||
VoxTral audio encoder + adapter backbone. | ||
|
||
This model implements the encoder portion of the VoxTral model. | ||
It takes a log-Mel spectrogram and produces a sequence of hidden states. | ||
|
||
Args: | ||
num_layers (int): Number of transformer layers. | ||
num_heads (int): Number of attention heads. | ||
hidden_dim (int): Embedding size. | ||
intermediate_dim (int): Size of feedforward network hidden layer. | ||
adapter_downsample (int): Pooling factor after adapter dense. | ||
dropout (float): Dropout probability. | ||
max_chunk_seconds (int): Chunking length in seconds. | ||
sr (int): Audio sample rate. | ||
hop_length (int): Hop length for spectrogram frames. | ||
dtype (str or mixed_precision.Policy, optional): Layer dtype. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_layers=32, | ||
num_heads=20, | ||
hidden_dim=1280, | ||
intermediate_dim=5120, | ||
adapter_downsample=4, | ||
dropout=0.1, | ||
max_chunk_seconds=30, | ||
sr=16000, | ||
hop_length=160, | ||
dtype=None, | ||
**kwargs, | ||
): | ||
"""Initialize the VoxTral backbone.""" | ||
# Store configuration | ||
self.num_layers = int(num_layers) | ||
self.num_heads = int(num_heads) | ||
self.hidden_dim = int(hidden_dim) | ||
self.intermediate_dim = int(intermediate_dim) | ||
self.adapter_downsample = int(adapter_downsample) | ||
self.dropout = float(dropout) | ||
self.max_chunk_seconds = int(max_chunk_seconds) | ||
self.sr = int(sr) | ||
self.hop_length = int(hop_length) | ||
|
||
# Frames per chunk before conv | ||
self.frames_per_chunk_preconv = int( | ||
self.max_chunk_seconds * (self.sr / self.hop_length) | ||
) | ||
self.postconv_frames_per_chunk = self.frames_per_chunk_preconv // 2 | ||
|
||
# --- Mixed precision policy --- | ||
if dtype is None: | ||
policy = mixed_precision.global_policy() | ||
elif isinstance(dtype, str): | ||
policy = mixed_precision.Policy(dtype) | ||
elif isinstance(dtype, dict): # coming from config | ||
policy = mixed_precision.Policy(dtype["config"]["name"]) | ||
else: | ||
policy = dtype # already a Policy | ||
|
||
variable_dtype = policy.variable_dtype | ||
compute_dtype = policy.compute_dtype | ||
self._policy = policy # save for get_config() | ||
|
||
# --- Layers --- | ||
self.conv_stem_1 = layers.Conv1D( | ||
filters=self.hidden_dim, | ||
kernel_size=3, | ||
strides=2, | ||
padding="same", | ||
activation="relu", | ||
kernel_initializer=voxtral_kernel_initializer(), | ||
dtype=variable_dtype, | ||
name="conv_stem_1", | ||
) | ||
self.conv_stem_2 = layers.Conv1D( | ||
filters=self.hidden_dim, | ||
kernel_size=3, | ||
strides=1, | ||
padding="same", | ||
activation="relu", | ||
kernel_initializer=voxtral_kernel_initializer(), | ||
dtype=variable_dtype, | ||
name="conv_stem_2", | ||
) | ||
|
||
self.transformer_layers = [ | ||
TransformerEncoder( | ||
num_heads=self.num_heads, | ||
intermediate_dim=self.intermediate_dim, | ||
dropout=self.dropout, | ||
name=f"transformer_layer_{i}", | ||
dtype=variable_dtype, | ||
) | ||
for i in range(self.num_layers) | ||
] | ||
|
||
self.adapter_dense = layers.Dense( | ||
self.hidden_dim, | ||
activation="relu", | ||
kernel_initializer=voxtral_kernel_initializer(), | ||
dtype=variable_dtype, | ||
name="adapter_dense", | ||
) | ||
self.adapter_pool = layers.AveragePooling1D( | ||
pool_size=self.adapter_downsample, | ||
strides=self.adapter_downsample, | ||
padding="valid", | ||
name="adapter_downsample", | ||
dtype=variable_dtype, | ||
) | ||
|
||
self.pos_emb = PositionalEmbedding( | ||
self.postconv_frames_per_chunk, | ||
self.hidden_dim, | ||
name="pos_emb", | ||
dtype=variable_dtype, | ||
) | ||
|
||
# --- Functional graph --- | ||
spectrogram_input = Input( | ||
shape=(None, 128), dtype=compute_dtype, name="spectrogram" | ||
) | ||
x = ChunkAndPad( | ||
self.frames_per_chunk_preconv, | ||
name="chunk_and_pad", | ||
dtype=compute_dtype, | ||
)(spectrogram_input) | ||
x = self.conv_stem_1(x) | ||
x = self.conv_stem_2(x) | ||
x = self.pos_emb(x) | ||
for transformer_layer in self.transformer_layers: | ||
x = transformer_layer(x) | ||
x = self.adapter_dense(x) | ||
x = self.adapter_pool(x) | ||
outputs = ReassembleChunks( | ||
self.frames_per_chunk_preconv, | ||
name="reassemble_chunks", | ||
dtype=compute_dtype, | ||
)(x, spectrogram_input) | ||
|
||
super().__init__( | ||
inputs=spectrogram_input, | ||
outputs=outputs, | ||
dtype=compute_dtype, | ||
**kwargs, | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_layers": self.num_layers, | ||
"num_heads": self.num_heads, | ||
"hidden_dim": self.hidden_dim, | ||
"intermediate_dim": self.intermediate_dim, | ||
"adapter_downsample": self.adapter_downsample, | ||
"dropout": self.dropout, | ||
"max_chunk_seconds": self.max_chunk_seconds, | ||
"sr": self.sr, | ||
"hop_length": self.hop_length, | ||
"dtype": self._policy.name, # store string | ||
} | ||
) | ||
return config |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.