
Conversation

gabe-l-hart
Collaborator

@gabe-l-hart commented Aug 22, 2025

Closes #15409

Draft Status

This PR will remain in draft until the model is fully working!

It's working!

Description

This PR adds support for the nemotronh architecture (hybrid mamba2/attention used for Nemotron Nano V2).

@github-actions bot added the python (python script changes) label on Aug 22, 2025
This is really helpful for diagnosing mismatches between the expected and
received tensors

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>
It generates tokens, just not valid ones!

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart force-pushed the gabe-l-hart/nvidia-nemotron-nano-15409 branch from df092d6 to 828176e on August 22, 2025 21:45
The `tokenizer.json`/`tokenizer_config.json` in the model are a bit
contradictory. In the config, add_bos_token is set to False, but the
tokenizer model itself has a post_processor that adds the BOS token via
type: TemplateProcessing

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>
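For reference, a minimal sketch of the kind of check that surfaces this contradiction, reading both files with the standard library (the paths and printed fields are illustrative, not part of this PR):

import json

with open("tokenizer_config.json") as f:
    print("add_bos_token in config:", json.load(f).get("add_bos_token"))

with open("tokenizer.json") as f:
    post = json.load(f).get("post_processor") or {}

# A TemplateProcessing post-processor declares the special tokens it inserts
print("post_processor type:", post.get("type"))
if post.get("type") == "TemplateProcessing":
    print("single-sequence template:", post.get("single"))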
@gabe-l-hart
Collaborator Author

Fixing the add_bos_token setting gets the tensor/tensor comparison to line up through the output of the first conv1d, but the output of the first ssm_scan still misaligns between this branch and transformers, so the next broken piece lives somewhere between the two, in the mamba2 portion.
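A minimal sketch of making this comparison numeric rather than visual, assuming the transformers-side activation has been saved with torch.save and a few leading values have been copied from the llama-eval-callback printout (the file name and values are hypothetical):

import torch

# Hypothetical dump of the transformers-side activation for the node being compared
hf = torch.load("hf_activation.pt").float().flatten()

# Leading values copied by hand from the llama-eval-callback printout of the matching node
llamacpp = torch.tensor([0.0008, -0.0061, 0.0002])

# Loose tolerance, since the printed values are rounded
print(torch.allclose(hf[: llamacpp.numel()], llamacpp, atol=1e-3))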

@CISC
Collaborator

CISC commented Aug 25, 2025

Fixing the add_bos_token setting gets the tensor/tensor comparison to line up through the output of the first conv1d

The BOS token should have been detected in the TemplateProcessing post-processor by SpecialVocab, but maybe this codepath is never taken? It's hard to tell.

@gabe-l-hart
Collaborator Author

Hmm, I see what you mean; it definitely should have been caught. I'm guessing it has to do with the (overly complex) parent class hierarchy and possibly not getting to the standard code path that uses SpecialVocab.

@gabe-l-hart
Collaborator Author

I dug more today and I've isolated the implementation issues to the SSM norm (tracked down with a bunch of print statements in modeling_nemotron_h.py). I can't find any reason why this would behave differently than other mamba2 models, but I'll keep digging tomorrow.

transformers tensors

--> ssm (y + D_residual): tensor([[[[ 7.5073e-04, -6.0881e-03,  2.2847e-04,  ...,  5.5659e-05,
           -3.9031e-04,  7.1842e-05],
          [-8.3047e-05, -9.0506e-03, -2.1056e-04,  ..., -1.8473e-04,
            6.8728e-04, -1.3800e-04],
          [-4.9432e-05,  5.0344e-03,  3.4076e-05,  ...,  1.5251e-05,
            1.2162e-04,  4.2153e-03],
          ...,
          [ 2.4585e-04, -3.0655e-03,  5.0408e-04,  ..., -2.5004e-04,
           -2.5885e-03, -1.8478e-03],
          [-1.9601e-04,  4.9391e-04, -3.3035e-04,  ...,  1.3616e-03,
            7.4775e-03, -4.9377e-04],
          [-9.5254e-04, -7.9345e-04,  1.1792e-02,  ...,  2.9217e-03,
           -1.4934e-03,  1.5367e-03]],

         [[ 6.0237e-03, -7.1595e-03,  3.2482e-03,  ...,  6.7600e-04,
           -3.9438e-03,  7.1795e-03],
          [ 1.7518e-04, -9.8712e-03, -1.5274e-03,  ..., -7.2140e-04,
           -7.7685e-04,  2.1262e-03],
          [ 4.1907e-04,  2.5693e-03, -1.7164e-03,  ...,  1.7401e-03,
            6.2931e-05,  4.2724e-03],
          ...,
          [ 7.9968e-03,  4.4843e-03, -1.1731e-02,  ..., -1.6766e-02,
            4.3176e-02,  9.8243e-03],
          [ 3.8583e-03, -1.3276e-02, -6.1258e-03,  ..., -2.0900e-02,
           -5.1029e-03, -9.8449e-03],
          [-3.4692e-04, -7.6044e-03,  6.4007e-03,  ...,  1.2949e-03,
            2.8442e-03,  2.0942e-03]],
....
--> ssm (y unpadded/reshaped): tensor([[[ 0.0008, -0.0061,  0.0002,  ...,  0.0029, -0.0015,  0.0015],
         [ 0.0060, -0.0072,  0.0032,  ...,  0.0013,  0.0028,  0.0021]]])
--> ssm y: tensor([[[ 0.0008, -0.0061,  0.0002,  ...,  0.0029, -0.0015,  0.0015],
         [ 0.0060, -0.0072,  0.0032,  ...,  0.0013,  0.0028,  0.0021]]])
--> ssm swiglu: tensor([[[-1.3809e-04, -4.4626e-04,  1.8921e-05,  ...,  5.1799e-04,
          -4.2833e-04, -3.4217e-04],
         [ 2.7305e-03, -8.4085e-04,  1.2045e-04,  ...,  4.9373e-04,
          -5.7184e-04, -5.0433e-05]]])
--> ssm rms norm variance / variance_epsilon: tensor([[[ 1.7734],
         [10.1713]]]) / 1e-05
--> ssm rms norm unweighted: tensor([[[-1.0369e-04, -3.3511e-04,  1.4209e-05,  ...,  3.8897e-04,
          -3.2164e-04, -2.5694e-04],
         [ 8.5617e-04, -2.6365e-04,  3.7768e-05,  ...,  1.5481e-04,
          -1.7930e-04, -1.5813e-05]]])

llama-eval-callback

ggml_debug:                  node_55 = (f32)        ADD(y_ssm-0 (view){80, 128, 2, 1}, node_54{80, 128, 2, 1}}) = {80, 128, 2, 1}
                                     [
                                      [
                                       [      0.0008,      -0.0061,       0.0002, ...,       0.0001,      -0.0004,       0.0001],
                                       [     -0.0001,      -0.0091,      -0.0002, ...,      -0.0002,       0.0007,      -0.0001],
                                       [     -0.0000,       0.0050,       0.0000, ...,       0.0000,       0.0001,       0.0042],
                                       ..., 
                                       [      0.0002,      -0.0031,       0.0005, ...,      -0.0003,      -0.0026,      -0.0018],
                                       [     -0.0002,       0.0005,      -0.0003, ...,       0.0014,       0.0075,      -0.0005],
                                       [     -0.0010,      -0.0008,       0.0118, ...,       0.0029,      -0.0015,       0.0015],
                                      ],
                                      [
                                       [      0.0060,      -0.0071,       0.0033, ...,       0.0007,      -0.0040,       0.0072],
                                       [      0.0002,      -0.0099,      -0.0015, ...,      -0.0007,      -0.0008,       0.0021],
                                       [      0.0004,       0.0026,      -0.0017, ...,       0.0017,       0.0001,       0.0043],
                                       ..., 
                                       [      0.0080,       0.0045,      -0.0117, ...,      -0.0168,       0.0432,       0.0098],
                                       [      0.0039,      -0.0133,      -0.0061, ...,      -0.0209,      -0.0051,      -0.0098],
                                       [     -0.0003,      -0.0076,       0.0064, ...,       0.0013,       0.0029,       0.0021],
                                      ],
                                     ]
                                     sum = 0.002696
ggml_debug:                  node_56 = (f32)     SWIGLU(zxBCdt_z (cont){80, 128, 2, 1}, node_55{80, 128, 2, 1}}) = {80, 128, 2, 1}
                                     [
                                      [
                                       [     -0.0001,      -0.0004,       0.0000, ...,       0.0000,       0.0001,       0.0000],
                                       [      0.0000,       0.0015,      -0.0000, ...,       0.0000,      -0.0002,       0.0000],
                                       [     -0.0000,       0.0002,       0.0000, ...,      -0.0000,       0.0001,      -0.0000],
                                       ..., 
                                       [      0.0000,      -0.0009,      -0.0001, ...,       0.0000,      -0.0009,       0.0004],
                                       [      0.0000,      -0.0001,      -0.0000, ...,       0.0003,       0.0018,      -0.0002],
                                       [      0.0002,      -0.0000,       0.0001, ...,       0.0005,      -0.0004,      -0.0003],
                                      ],
                                      [
                                       [      0.0027,      -0.0008,       0.0001, ...,       0.0000,      -0.0004,      -0.0015],
                                       [      0.0001,      -0.0005,      -0.0000, ...,       0.0001,       0.0001,      -0.0002],
                                       [      0.0002,      -0.0006,       0.0005, ...,      -0.0002,      -0.0000,      -0.0004],
                                       ..., 
                                       [      0.0017,       0.0001,      -0.0005, ...,      -0.0089,      -0.0000,       0.0002],
                                       [     -0.0008,       0.0018,       0.0003, ...,      -0.0214,      -0.0036,      -0.0006],
                                       [     -0.0000,      -0.0004,      -0.0006, ...,       0.0005,      -0.0006,      -0.0000],
                                      ],
                                     ]
                                     sum = -0.032029
ggml_debug:               (reshaped) = (f32)    RESHAPE(node_56{80, 128, 2, 1}, }) = {1280, 8, 2, 1}
                                     [
                                      [
                                       [     -0.0001,      -0.0004,       0.0000, ...,      -0.0000,       0.0000,       0.0000],
                                       [     -0.0025,      -0.0030,      -0.0002, ...,      -0.0015,       0.0001,      -0.0542],
                                       [      0.0008,       0.3558,       1.1437, ...,       0.0001,       0.0003,      -0.0003],
                                       ..., 
                                       [      0.1158,      -0.0184,      -0.0012, ...,       0.0189,       0.0027,       0.0000],
                                       [     -0.0386,      -2.5810,      -0.0017, ...,       0.0004,      -0.0001,       0.0000],
                                       [      0.0000,       0.0006,       0.0020, ...,       0.0005,      -0.0004,      -0.0003],
                                      ],
                                      [
                                       [      0.0027,      -0.0008,       0.0001, ...,       0.0002,       0.0000,       0.0001],
                                       [      0.0023,      -0.0009,      -0.0008, ...,      -0.0018,      -0.0021,      -0.0014],
                                       [     -0.0097,      -4.2951,      10.1652, ...,      -0.0019,       0.0013,      -0.0014],
                                       ..., 
                                       [      0.1052,       0.0139,       0.4142, ...,       0.0184,      -0.0060,       0.0002],
                                       [      0.0340,      -7.9632,      -0.0004, ...,      -0.0017,      -0.0007,      -0.0021],
                                       [      0.0085,      -0.0021,      -0.0276, ...,       0.0005,      -0.0006,      -0.0000],
                                      ],
                                     ]
                                     sum = -2.615886
ggml_debug:                   norm-0 = (f32)   RMS_NORM( (reshaped){1280, 8, 2, 1}, }) = {1280, 8, 2, 1}
                                     [
                                      [
                                       [     -0.0144,      -0.0466,       0.0020, ...,      -0.0007,       0.0001,       0.0002],
                                       [     -0.0076,      -0.0094,      -0.0007, ...,      -0.0046,       0.0002,      -0.1679],
                                       [      0.0004,       0.2001,       0.6430, ...,       0.0000,       0.0002,      -0.0002],
                                       ..., 
                                       [      0.0963,      -0.0153,      -0.0010, ...,       0.0157,       0.0022,       0.0000],
                                       [     -0.0133,      -0.8924,      -0.0006, ...,       0.0002,      -0.0000,       0.0000],
                                       [      0.0001,       0.0018,       0.0061, ...,       0.0016,      -0.0013,      -0.0010],
                                      ],
                                      [
                                       [      0.0887,      -0.0274,       0.0039, ...,       0.0072,       0.0005,       0.0017],
                                       [      0.0015,      -0.0006,      -0.0005, ...,      -0.0012,      -0.0014,      -0.0009],
                                       [     -0.0017,      -0.7639,       1.8079, ...,      -0.0003,       0.0002,      -0.0002],
                                       ..., 
                                       [      0.0240,       0.0032,       0.0947, ...,       0.0042,      -0.0014,       0.0000],
                                       [      0.0085,      -1.9846,      -0.0001, ...,      -0.0004,      -0.0002,      -0.0005],
                                       [      0.0083,      -0.0020,      -0.0268, ...,       0.0005,      -0.0006,      -0.0000],
                                      ],
                                     ]
                                     sum = -0.966779
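For reference, a minimal standalone sketch of the gated RMS norm step that the two dumps above are comparing (the MambaRMSNormGated module in the file below); group_size controls how many channels each variance is computed over, and the names and shapes here are illustrative only:

import torch
import torch.nn.functional as F

def gated_rms_norm(x, z, weight, group_size, eps=1e-5):
    # Gate first (SwiGLU-style), then RMS-normalize each group of `group_size` channels
    x = x.float() * F.silu(z.float())
    *lead, d = x.shape
    xg = x.view(*lead, d // group_size, group_size)
    xg = xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + eps)
    return weight * xg.reshape(*lead, d)

# Illustrative shapes: batch=1, seq=2, hidden=8*1280 (8 groups of 1280 channels)
x, z, w = torch.randn(1, 2, 10240), torch.randn(1, 2, 10240), torch.ones(10240)
print(gated_rms_norm(x, z, w, group_size=1280).shape)  # torch.Size([1, 2, 10240])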
(modified) modeling_nemotron_h.py
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch NemotronH model."""

import math
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from transformers.utils.import_utils import (
    is_causal_conv1d_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_mamba_2_ssm_available,
)
from .configuration_nemotron_h import NemotronHConfig


logger = logging.get_logger(__name__)


# Copied from transformers.models.mamba2.modeling_mamba2 with MAMBA2->NEMOTRONH, Mamba2->NemotronH
# For Mamba2 components, Mamba2->NemotronHMamba2
if is_mamba_2_ssm_available():
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
    mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None

try:
    #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
    from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
    FAST_RMSNORM = True
except ImportError:
    FAST_RMSNORM = False
    # raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

is_fast_path_available = all(
    (
        selective_state_update,
        mamba_chunk_scan_combined,
        mamba_split_conv1d_scan_combined,
        causal_conv1d_fn,
        causal_conv1d_update,
    )
)


_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K"
_CONFIG_FOR_DOC = "NemotronHConfig"


# Helper methods for segment sum computation


def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
    """
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
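# Example (illustrative): with pad_size=3, a [bsz=2, seq_len=5, heads=8, dim=16] input
# becomes [2, 8, 8, 16], and a 3-D [2, 5, 8] input becomes [2, 8, 8]; only dim=1 is padded.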


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
        )


def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0, to zero out elements above the diagonal
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute actual cumsum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)

    # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum
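# Example (illustrative): for per-step decay terms a = [a0, a1, a2], the output T satisfies
# T[i, j] = a[j+1] + ... + a[i] for i > j, T[i, i] = 0, and T[i, j] = -inf for i < j, so that
# exp(T) yields the causal decay factors used by the chunked scan below.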


def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    """
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    return hidden_states
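# Example (illustrative): for a batch of two sequences with attention_mask [[1, 1], [1, 0]],
# the hidden states at the padded position of the second sequence are zeroed so the conv and
# SSM states are not polluted by pad tokens; single-sequence or single-token calls are left
# untouched by the shape guards above.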

# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shapes are as follows:
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        super().__init__()
        self.dtype = dtype
        self.hybrid_override_pattern = config.hybrid_override_pattern
        self.has_previous_state = False  # only used by mamba
        #intermediate_size = config.expand * config.hidden_size
        intermediate_size = config.mamba_num_heads * config.mamba_head_dim
        ssm_state_size = config.ssm_state_size
        self.conv_kernel_size = conv_kernel_size = config.conv_kernel  # stored so callers can read cache_params.conv_kernel_size
        self.conv_states = []
        self.ssm_states = []
        self.transformer_layers = []
        for i in range(config.num_hidden_layers):
            if self.hybrid_override_pattern[i] == "M":
                # Mamba layer
                self.conv_states += [
                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
                ]
                self.ssm_states += [
                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
                ]
            else:
                # Attention or MLP layer
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
                self.transformer_layers.append(i)

        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the cache
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # take any layer that contains cache and not empty tensor
        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    # Copied from modeling_mamba2.py
    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
    ) -> torch.Tensor:
        if cache_init:
            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device)
        else:
            self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
        return self.ssm_states[layer_idx]

    def reset(self):
        # conv_states and ssm_states are lists of per-layer tensors, so zero them layer by layer
        for layer_idx in range(len(self.conv_states)):
            self.conv_states[layer_idx].zero_()
            self.ssm_states[layer_idx].zero_()

class MambaRMSNormGated(torch.nn.Module):
    def __init__(self, hidden_size, group_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    # jan28b version
    def forward(self, hidden_states, gate=None):
        if FAST_RMSNORM:
            return rmsnorm_fn(x=hidden_states,
                            weight=self.weight,
                            bias=None, # No bias
                            z=gate,
                            eps=self.variance_epsilon,
                            group_size=self.group_size,
                            norm_before_gate=False
            )
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
            print(f"--> ssm swiglu: {hidden_states}")
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        print(f"--> ssm rms norm variance / variance_epsilon: {variance} / {self.variance_epsilon}")
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        print(f"--> ssm rms norm unweighted: {hidden_states}\nSHAPE: {hidden_states.shape}")

        print(f"--> ssm rms norm weights: {self.weight}\nSHAPE: {self.weight.shape}")
        return self.weight * hidden_states.to(input_dtype)

class NemotronHMamba2Mixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: NemotronHConfig, layer_idx: int):
        super().__init__()
        self.num_heads = config.mamba_num_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.ssm_state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = config.mamba_hidden_act
        self.act = ACT2FN[config.mamba_hidden_act]

        self.layer_norm_epsilon = config.layer_norm_epsilon

        self.n_groups = config.n_groups
        self.head_dim = config.mamba_head_dim
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim,
            padding=config.conv_kernel - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.use_bias,
        )
        # selective projection used to make dt, B and C input dependant

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups)
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        print(f"--> masked/padded: {hidden_states}")
        projected_states = self.in_proj(hidden_states)
        print(f"--> in_proj: {hidden_states}")

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_mlp = (
            projected_states.shape[-1]
            - 2 * self.intermediate_size
            - 2 * self.n_groups * self.ssm_state_size
            - self.num_heads
        ) // 2

        # Single step calculations via cache
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
                [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. Convolution sequence transformation
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )
            print(f"--> conv1d (hidden_states): {hidden_states}")
            print(f"--> conv1d (B): {B}")
            print(f"--> conv1d (C): {C}")

            # 3. SSM transformation
            A = -torch.exp(self.A_log.float())  # (nheads,)
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            hidden_states = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            print(f"--> ssm_states: {hidden_states}")
            hidden_states = self.norm(hidden_states, gate)
            print(f"--> norm: {hidden_states}")

            # 4. Final linear projection
            out = self.out_proj(hidden_states)[:, None, ...]
            print(f"--> out_proj: {out}")

        # Fused calculations or step by step if no initialized cache is found
        else:
            A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            # 2-4. Fused kernel for conv1d, SSM, and the final projection
            if self.training and cache_params is None:
                out = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=None,  # was seq_idx
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=False,
                    **dt_limit_kwargs,
                )

            else:
                _, _, gate, hidden_states_B_C, dt = projected_states.split(
                    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
                )

                # 2. Convolution sequence transformation
                # Init cache
                if cache_params is not None:
                    hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                    conv_states = nn.functional.pad(
                        hidden_states_B_C_transposed,
                        (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                    )
                    cache_params.update_conv_state(
                        layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
                    )

                if self.activation not in ["silu", "swish"]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                    )
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                    ).transpose(1, 2)
                hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )

                # 3. SSM transformation
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    dt,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=None,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )

                # Init cache
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

                scan_output = scan_output.view(batch_size, seq_len, -1)

                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)

                # 4. Final linear projection
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated MLP's linear projection
        input_states = apply_mask_to_padding_states(input_states, attention_mask)
        print(f"--> masked/padded: {input_states}")
        projected_states = self.in_proj(input_states)
        print(f"--> in_proj: {projected_states}")
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
        _, _, gate, hidden_states_B_C, dt = projected_states.split(
                [d_mlp, d_mlp, self.intermediate_size,  self.conv_dim, self.num_heads], dim=-1
        )
        print(f"--> d_mlp: {d_mlp}")
        print(f"--> gate: {gate}")
        print(f"--> hidden_states_B_C: {hidden_states_B_C}\nSUM: {hidden_states_B_C.sum()}")
        print(f"--> dt: {dt}")

        # 2. Convolution sequence transformation
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)

            # We need to guarantee that anything regarding the cache is on the same device
            conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)

            hidden_states_B_C = torch.sum(
                conv_states * self.conv1d.weight.squeeze(1), dim=-1
            )
            if self.use_conv_bias:
                hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
            hidden_states_B_C = self.act(hidden_states_B_C)
        else:
            # Init cache
            if cache_params is not None:
                hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                conv_states = nn.functional.pad(
                    hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
                )
                cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)

            conv1d_input = hidden_states_B_C.transpose(1, 2)
            print(f"--> conv1d input: {conv1d_input}\nSUM: {conv1d_input.sum()}")
            pre_act = self.conv1d(conv1d_input)[..., :seq_len].transpose(1, 2)
            print(f"--> conv1d pre-activation: {pre_act}\n:SUM: {pre_act.sum()}")
            hidden_states_B_C = self.act(pre_act)

        print(f"--> conv1d pre-pad-mask (hidden_states_B_C): {hidden_states_B_C}\nSUM: {hidden_states_B_C.sum()}")

        hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
            dim=-1
        )
        print(f"--> conv1d (hidden_states): {hidden_states}")
        print(f"--> conv1d (B): {B}")
        print(f"--> conv1d (C): {C}")

        # 3. SSM transformation
        A = -torch.exp(self.A_log.float())                            # [num_heads]
        print(f"--> ssm (A): {A}\nSUM: {A.sum()}\nSHAPE: {A.shape}")
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            print("--> reading from cache")
            # We need to guarantee that anything regarding the cache is on the same device
            cache_device = cache_params.ssm_states[self.layer_idx].device

            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            dt = dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt_plus_b = dt + dt_bias.to(dt.dtype)
            print(f"--> ssm (dt + bias): {dt_plus_b}")
            dt = torch.nn.functional.softplus(dt_plus_b)
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = (dB * hidden_states[..., None]).to(device=cache_device)

            # State calculation
            cache_params.update_ssm_state(
                layer_idx=self.layer_idx,
                new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
            )

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]

            ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)  # Shape: [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt_plus_b = dt + self.dt_bias
            print(f"--> ssm (dt + bias): {dt_plus_b}")
            dt = nn.functional.softplus(dt_plus_b)
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            print(f"--> ssm (dt softplus): {dt}")
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            print(f"--> ssm (B before repeat): {B}")
            print(f"--> ssm (C before repeat): {C}")
            B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
            C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
            print(f"--> ssm (B): {B}")
            print(f"--> ssm (C): {C}")
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
            print(f"--> ssm (pad_size): {pad_size}")

            D = self.D[..., None]
            D_residual = D * pad_tensor_by_size(hidden_states, pad_size)
            print(f"--> ssm (D): {D}\nSUM: {D.sum()}\nSHAPE: {D.shape}")
            print(f"--> ssm (D_residual): {D_residual}")

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)

            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # Contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)

            # 2. Compute the state for each intra-chunk
            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

            # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
            # (middle term of factorization of off-diag blocks; A terms)
            if cache_params is not None and cache_position is not None and cache_position[0] > 0:
                previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
            else:
                previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
            decay_chunk = decay_chunk.transpose(1, 3)
            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
            states, ssm_state = new_states[:, :-1], new_states[:, -1]
            print(f"--> ssm states: {ssm_state}")

            # 4. Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])

            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            print(f"--> ssm (y + D_residual): {y}")
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)
            print(f"--> ssm (y unpadded/reshaped): {y}")

            # Init cache
            if ssm_state is not None and cache_params is not None:
                cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

        print(f"--> ssm y: {y}")
        scan_output = self.norm(y, gate)
        print(f"--> norm: {scan_output}")

        # end ssd naive

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
        print(f"--> mamba2 out: {contextualized_states}")
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        dtype = hidden_states.dtype
        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)


class NemotronHRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Weights are in float32
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)

class NemotronHBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # M: Mamba2, *: Attention, -: MLP
        self.block_type = config.layers_block_type[layer_idx]
        if self.block_type == "mamba":
            self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx)
        elif self.block_type == "attention":
            self.mixer = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        elif self.block_type == "mlp":
            self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
        else:
            raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")

    def forward(
        self,
        hidden_states,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # with (
        #     torch.cuda.stream(torch.cuda.default_stream(hidden_states.device))
        #     if torch.cuda.is_available()
        #     else nullcontext
        # ):

            print(f"[{self.layer_idx}] input: {hidden_states}\nSUM: {hidden_states.sum()}\nSHAPE: {hidden_states.shape}")

            # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
            residual = hidden_states
            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

            if self.block_type == "mamba":
                hidden_states = self.mixer(
                    hidden_states, cache_params=cache_params, cache_position=cache_position
                )
            elif self.block_type == "attention":
                hidden_states = self.mixer(
                    hidden_states, cache_position=cache_position
                )
                hidden_states = hidden_states[0]
            elif self.block_type == "mlp":
                hidden_states = self.mixer(
                    hidden_states
                )
            else:
                raise ValueError(f"Invalid block_type: {self.block_type}")

            hidden_states = residual + hidden_states
            print(f"[{self.layer_idx}] --------------------------")
            return hidden_states


# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
class NemotronHMLP(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.hidden_size = config.hidden_size
        #intermediate_size = config.expand * config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.mlp_hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class NemotronHAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        if config.head_dim is not None:
            self.head_dim = config.head_dim
        else:
            self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
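        # NOTE: num_heads * head_dim can differ from hidden_size for this architecture, so flatten with the explicit product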
        #attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
#class JambaFlashAttention2(JambaAttention):
class NemotronHFlashAttention2(NemotronHAttention):
    """
    NemotronH flash attention module. This module inherits from `NemotronHAttention` as the weights of the module stay
    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self.config, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        #attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
#class JambaSdpaAttention(JambaAttention):
class NemotronHSdpaAttention(NemotronHAttention):
    """
    NemotronH attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `NemotronHAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt
    to the SDPA API.
    """

    # Adapted from NemotronHAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        #attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


NEMOTRONH_ATTENTION_CLASSES = {
    "eager": NemotronHAttention,
    "flash_attention_2": NemotronHFlashAttention2,
    "sdpa": NemotronHSdpaAttention,
}

# Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel
class NemotronHPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NemotronHConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["NemotronHBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, NemotronHMamba2Mixer):
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt = torch.exp(
                torch.rand(self.config.mamba_num_heads)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)

            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_bias.copy_(inv_dt)
            module.dt_bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        # TODO: Check
        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(self.config.num_hidden_layers)


@dataclass
# Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH
class NemotronHOutput(ModelOutput):
    """
    Class for the NemotronH model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`HybridMambaAttentionDynamicCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
class NemotronHCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`HybridMambaAttentionDynamicCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


NEMOTRONH_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

NEMOTRONH_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings.
        cache_params (`HybridMambaAttentionDynamicCache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
"""


@add_start_docstrings(
    "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.",
    NEMOTRONH_START_DOCSTRING,
)
class NemotronHModel(NemotronHPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
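        # Remap a legacy "embedding." checkpoint key to the "embeddings." name used by this module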
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                break

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NemotronHOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, NemotronHOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # use_cache = use_cache if use_cache is not None else self.config.use_cache
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # From zamba_modeling.py
        if use_cache and cache_params is None:
            logger.warning_once(
                "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was "
                "provided, so no cache will be returned."
            )

        hidden_states = inputs_embeds

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        # Until HERE

        for layer_idx, mixer_block in enumerate(self.layers):
            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
            if mixer_block.block_type == "mamba":
                layer_mask = mamba_mask
            elif mixer_block.block_type == "attention":
                layer_mask = causal_mask
            elif mixer_block.block_type == "mlp":
                layer_mask = None
            else:
                raise ValueError(f"Invalid block_type: {self.block_type}")

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=layer_mask,
                )

            # TODO: Store attentions
            # if output_attentions:
            #     if layer_outputs[1] is not None:
            #         # append attentions only of attention layers. Mamba layers return `None` as the attention weights
            #         all_self_attns += (layer_outputs[1],)

            # TODO (Check): should it happen before the forward pass?
            # if output_hidden_states:
            #     all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return NemotronHOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
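        # the mask's key dimension must cover every cached position up to and including the newest token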
        target_length = cache_position[-1] + 1

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask


@add_start_docstrings(
    """
    The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    """,
    NEMOTRONH_START_DOCSTRING,
)
class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = NemotronHModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def set_decoder(self, decoder):
        self.model = decoder

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
        # Overwritten -- uses `cache_params` as opposed to `past_key_values`
        empty_past_kv = past_key_values is None

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        #              (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if (
                inputs_embeds is not None  # Exception 1
                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
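            # no cache was passed in, so allocate a fresh hybrid cache (attention KV + mamba conv/ssm states) for this generation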
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "logits_to_keep": self.config.num_logits_to_keep,
                "cache_position": cache_position,
            }
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NemotronHCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, NemotronHCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        nemotron_h_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = nemotron_h_outputs[0]

        # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2
        logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + nemotron_h_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return NemotronHCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=nemotron_h_outputs.cache_params,
            hidden_states=nemotron_h_outputs.hidden_states,
            attentions=nemotron_h_outputs.attentions,
        )

@gabe-l-hart
Copy link
Collaborator Author

gabe-l-hart commented Aug 26, 2025

I typed that last comment while rushing off to kid time yesterday. Here are the rest of the details:

Steps to repro with transformers

  1. Use the above modeling_nemotron_h.py in place of the out-of-the-box one in your downloaded model
  2. Run the following script (with appropriate path modifications) piped to some log file somewhere
from transformers import AutoTokenizer, AutoModelForCausalLM

prompt = "hello"
model_path = "/Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokens = tokenizer(prompt, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
res = model.generate(max_new_tokens=1, **tokens)
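
If editing the modeling file gets tedious, forward hooks are another way to capture the same intermediates for the side-by-side comparison. A minimal sketch reusing model and tokens from the script above (the "mixer" name filter and the dump format are just illustrative):

import torch

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # attention mixers return tuples; mamba mixers return a plain tensor
        out = output[0] if isinstance(output, tuple) else output
        captured[name] = out.detach().float().cpu()
    return hook

for name, module in model.named_modules():
    if name.endswith("mixer"):
        module.register_forward_hook(make_hook(name))

with torch.no_grad():
    model(**tokens)

for name, tensor in captured.items():
    print(name, tensor.shape, tensor.flatten()[:8])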

Steps to repro with llama-eval-callback

  1. Apply the following diff to llama-model.cpp (not committing it because it's super hacky-debuggy)
llama-model-debug.patch
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 286fb99f5..48c2b2a0d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -11297,8 +11297,10 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
         // split the above in three
         ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
+        z = ggml_set_name(z, "zxBCdt_z");
         ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
         ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
+        dt = ggml_set_name(dt, "zxBCdt_dt");
 
         // conv
         {
@@ -11342,6 +11344,9 @@ struct llm_graph_context_mamba : public llm_graph_context {
             dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
 
             ggml_tensor * A = model.layers[il].ssm_a;
+            //DEBUG
+            A = ggml_scale(ctx0, A, 1.0);
+            cb(A, "ssm_A", il);
 
             // use the states and the indices provided by build_recurrent_state
             // (this is necessary in order to properly use the states before they are overwritten,
@@ -11361,18 +11366,25 @@ struct llm_graph_context_mamba : public llm_graph_context {
                 ggml_cpy(ctx0,
                     ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
                     ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+            cb(y_ssm, "y_ssm", il);
 
             ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
 
             // TODO: skip computing output earlier for unused tokens
 
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+            //DEBUG
+            ggml_tensor * D = ggml_scale(ctx0, model.layers[il].ssm_d, 1.0);
+            cb(D, "D", il);
+
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
             y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
             // grouped RMS norm
             if (model.layers[il].ssm_norm) {
                 y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+                //DEBUG
+                ggml_tensor * ssm_norm_weight = ggml_scale(ctx0, model.layers[il].ssm_norm, 1.0);
+                y = build_norm(y, ssm_norm_weight, NULL, LLM_NORM_RMS, il);
             }
 
             y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
  2. Run llama-eval-callback
./bin/llama-eval-callback -m ~/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/NVIDIA-Nemotron-Nano-9B-v2-F16.gguf -p "hello" -n 1 -ngl 0 -t 1 2>&1 | tee hello-tensors.log

NOTE 1: I've found that in order to get input tensors (either weights or request input tensors) to print un-modified in the trace output, the simplest way is to put an artificial ggml_scale(ctx0, <tensor>, 1.0) in the graph and then look for the output of that node.

NOTE 2: In order for this side-by-side to be valid, the prompts must be identical (including case and the addition of the BOS token)
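
A quick way to double-check what the transformers side actually feeds the model (so the llama.cpp tokenization can be matched exactly, BOS included):

print(tokenizer("hello")["input_ids"])  # the BOS id should show up first if the post_processor adds it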

@gabe-l-hart gabe-l-hart force-pushed the gabe-l-hart/nvidia-nemotron-nano-15409 branch from e391dc3 to 9a9de40 Compare August 26, 2025 15:36
@gabe-l-hart
Copy link
Collaborator Author

Ok, I think I've figured out why this one is behaving differently than others. From what I can tell, all other mamba2 models we've tested (hybrid or otherwise) set num_groups to 1.

This masks a difference in how the transformers implementations of mamba2 perform the post-SSM norm and how it's currently implemented in llama.cpp. In llama.cpp, the norm is done over {d_inner / n_group, n_group, n_seq_tokens, n_seqs} (here). In the common transformers implementation that gets copied everywhere, it's implemented over {d_inner, n_seqs} (eg here). This results in the mean being calculated very differently and thus a very different norm output!
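
To make the difference concrete, here is a minimal sketch (torch assumed, shapes reduced) of the two formulations. With n_group > 1 the mean of squares is taken over different spans, so the outputs diverge; with n_group == 1 the two collapse to the same computation:

import torch

torch.manual_seed(0)
d_inner, n_group = 8, 2
y = torch.randn(d_inner)
w = torch.ones(d_inner)
eps = 1e-5

def rmsnorm(x, weight):
    # mean of squares is computed over the last dimension only
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

flat    = rmsnorm(y, w)                                               # transformers: norm over all d_inner values
grouped = rmsnorm(y.view(n_group, -1), w.view(n_group, -1)).view(-1)  # llama.cpp: norm per group of d_inner / n_group

print(torch.allclose(flat, grouped))  # False in general when n_group > 1; True when n_group == 1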

@gabe-l-hart
Copy link
Collaborator Author

gabe-l-hart commented Aug 26, 2025

Confirmed! By making the following change here, I now see the mamba2 layer outputs lining up exactly:

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c77b28e26..cc7881ab9 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -11371,8 +11371,12 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
             // grouped RMS norm
             if (model.layers[il].ssm_norm) {
-                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+                // y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+                // y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+                //DEBUG
+                y = ggml_reshape_2d(ctx0, y, d_inner, n_seq_tokens * n_seqs);
+                ggml_tensor * ssm_norm_1d = ggml_reshape_1d(ctx0, model.layers[il].ssm_norm, d_inner);
+                y = build_norm(y, ssm_norm_1d, NULL, LLM_NORM_RMS, il);
             }
 
             y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

NOTE: Full generation is still garbage, so something else is still broken. Baby steps!

@gabe-l-hart
Copy link
Collaborator Author

I've also confirmed that this patch does not adversely affect mamba2-370m or granite-4.0-tiny-preview, so I think this should be a safe change to make for models with n_groups == 1.

@compilade
Copy link
Collaborator

From what I can tell, all other mamba2 models we've tested (hybrid or otherwise) set num_groups to 1

@gabe-l-hart
Mamba-Codestral-7B-v0.1 has n_groups set to 8, though, and I remember testing it at the time.

https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/blob/main/config.json#L18

Does this model handle groups differently?

@gabe-l-hart
Copy link
Collaborator Author

@compilade Thanks for pointing that out! No idea right now, but I'll take a look.

@gabe-l-hart
Copy link
Collaborator Author

I see that Mamba-Codestral-7B-v0.1 uses Mamba2ForCausalLM and that the modeling file for that explicitly flattens the hidden states before running the norm (here and here). Is it possible that this was a bug that ultimately didn't make much difference to the output? My results are still garbage, so I suspect there's something else significant broken in the non-recurrent layers; maybe if I fixed that, this difference would not show up?

@gabe-l-hart
Copy link
Collaborator Author

I just downloaded and tested https://huggingface.co/gabriellarson/Mamba-Codestral-7B-v0.1-GGUF (F16) and a sniff test prompt produced correct tokens with the change to flatten before the norm.

@gabe-l-hart
Copy link
Collaborator Author

With the latest changes, the tensor values stay close through the entire prefill. There is definitely some precision drift, and the decoded output tokens still seem to be broken, so I'm not clear yet if it's caused by the drift or by something else missing in the implementation.

@gabe-l-hart
Copy link
Collaborator Author

More interesting info: I tried running my same dummy prompts with the modified modeling_nemotron_h.py that I posted earlier which enabled CPU inference. I got bad results (not the same bad results that I'm getting here, but bad results). I then got things up and running on my CUDA box and ran the same prompt and got good results, so it seems that something in the CPU code is broken for the modeling_nemotron_h.py posted with the model. This isn't totally surprising since it explicitly states that it requires CUDA, but the fact that all CUDA-specific code paths also have CPU code paths (copied from other models most likely) made me think it should work. It looks like I'll need to dig further on the delta between the CUDA version and the CPU version.

@gabe-l-hart
Copy link
Collaborator Author

It looks like the difference between CUDA/CPU is (again) in the gated RMS norm. This makes me wonder if the original non-flattened implementation is correct, @compilade, and the flattening is just happening in the optimized kernels on the transformers side.

Use GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; fallback to max(64, n_embd/16).
@gabe-l-hart
Copy link
Collaborator Author

I've also verified that the model gives valid outputs when quantized with MXFP4_MOE and Q4_K_M. Interestingly, the results are slightly different (not unexpected given the quantization). Also interestingly, the Q4_K_M struggled to terminate for my stock sample prompt ("tell me a story about a developer and their dog"). It concluded the classic story about Alex and Max (it is shocking how many models choose Alex and Max for this prompt), but then emitted \n output: (which seems to be its preferred prefix for responding), gave a totally fresh response, and kept generating.

@jwjohns
Copy link
Contributor

jwjohns commented Aug 28, 2025

@gabe-l-hart I am getting my naming commit set up now. Apologies for the delay.

@gabe-l-hart
Copy link
Collaborator Author

No problem at all! No rush on my end.

jwjohns and others added 2 commits August 28, 2025 14:39
- Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py
- Change architecture string from 'nemotronh' to 'nemotron_h' in all files
- Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H
- Update class name llm_build_nemotronh to llm_build_nemotron_h
- Consistent naming with underscore convention (nemotron_h vs nemotronh)
@gabe-l-hart
Copy link
Collaborator Author

All contributor changes are now merged, so it should be ready for final review @ggerganov (or others)

Copy link
Member

@ggerganov ggerganov left a comment

Nice work!

@ggerganov
Copy link
Member

I've also verified that the model gives valid outputs when quantized with MXFP4_MOE

MXFP4 quantization is only supposed to work when the BF16 weights are already upscaled from existing MXFP4 quantization. So not recommended to use it for anything else.

@gabe-l-hart
Copy link
Collaborator Author

MXFP4 quantization is only supposed to work when the BF16 weights are already upscaled from existing MXFP4 quantization. So not recommended to use it for anything else.

Got it, that's great to know. Anecdotally, I've been playing with it as a general quantization scheme and it seems to do pretty well, though the resulting size compression seems to be very different depending on the model architecture (same size as Q4_K_M for Granite4 which uses MoE, but almost 2x size for Granite3 which is dense).

@jacekpoplawski
Copy link
Contributor

I assume it adds support for NVIDIA-Nemotron-Nano-9B-v2, but does that mean it will also help with adding support for older Nemotron-H models, like Nemotron-H-47B-Base-8K in the future?

@gabe-l-hart
Copy link
Collaborator Author

@jacekpoplawski That's a good question. I haven't tested any of the older ones. Any chance there's a small-ish version of V1 you can point me at to download and test? Theoretically, assuming the architecture has not changed, this should support them all, but the devil may be in the details of how the hparams are used.

@gabe-l-hart
Copy link
Collaborator Author

I see https://huggingface.co/nvidia/Nemotron-H-8B-Reasoning-128K. I'll see if I can pull it down and test it.

@jacekpoplawski
Copy link
Contributor

@jacekpoplawski That's a good question. I haven't tested any of the older ones. Any chance there's a small-ish version of V1 you can point me at to download and test? Theoretically, assuming the architecture has not changed, this should support them all, but the devil may be in the details of how the hparams are used.

A few months ago, after hybrid/mamba support was added, I wanted to work on support for https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K but when I chatted with ChatGPT, it told me that some parts of the model were still hard to implement 🙂 I wonder if this is handled now by your changes.

@gabe-l-hart
Copy link
Collaborator Author

@jacekpoplawski There are some errors in conversion that I'll need to poke through.

@gabe-l-hart
Copy link
Collaborator Author

Good news! It was just a couple of mis-converted hparams. With those fixed, https://huggingface.co/nvidia/Nemotron-H-8B-Reasoning-128K converts and runs cleanly.

@DominguesM
Copy link
Contributor

Good news! It was just a couple of mis-converted hparams. With those fixed, https://huggingface.co/nvidia/Nemotron-H-8B-Reasoning-128K converts and runs cleanly.

Nice, I was testing this too. I used it as a fallback when head_dim is missing:

self.head_dim = self.hparams.get("head_dim")
if self.head_dim is None:
    self.head_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]

@gabe-l-hart gabe-l-hart merged commit e8d99dd into ggml-org:master Aug 29, 2025
131 of 134 checks passed
@gabe-l-hart gabe-l-hart deleted the gabe-l-hart/nvidia-nemotron-nano-15409 branch August 29, 2025 00:39
@jwjohns
Copy link
Contributor

jwjohns commented Aug 29, 2025

I assume it adds support for NVIDIA-Nemotron-Nano-9B-v2, but does that mean it will also help with adding support for older Nemotron-H models, like Nemotron-H-47B-Base-8K in the future?

I was already exploring this to see, I’ll keep you posted

qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Aug 30, 2025
* feat: Add NEMOTRONH to python arch enum

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add NEMOTRONH to c++ arch enum

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add NEMOTRONH to llama-arch layer map

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: First pass at conversion for nemotronh

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add a verbose log for each tensor loaded

This is really helpful for diagnosing mismatches between the expected and
received tensors

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: First (broken) pass at nemotronh model architecture

It generates tokens, just not valid ones!

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Explicitly enable add_bos_token during conversion

The `tokenizer.json`/`tokenizer_config.json` in the model are a bit
contradictory. In the config, add_bos_token is set to False, but the
tokenizer model itself has a post_processor that adds the BOS token via
type: TemplateProcessing

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Only allocate attention cache for attention layers (not non-recurrent)

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Move residual add to after every block

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Use the correct norm tensor for the MLP blocks

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

* Nemotron-H: MLP gate cleanup (pass NULL for unused gate)

This model does not use a gate in MLP blocks; pass NULLs for gate tensors to make intent clear and avoid unused-pointer noise.

* SSM: respect ssm_dt_rank for dt_dim when provided

Use GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; fallback to max(64, n_embd/16).

* fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage)

* Rename nemotronh to nemotron_h for consistency

- Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py
- Change architecture string from 'nemotronh' to 'nemotron_h' in all files
- Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H
- Update class name llm_build_nemotronh to llm_build_nemotron_h
- Consistent naming with underscore convention (nemotron_h vs nemotronh)

* feat: Support conversion for older NemotronH models

https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Maicon Domingues <[email protected]>
Co-authored-by: weatherman <[email protected]>
@jacekpoplawski
Copy link
Contributor

jacekpoplawski commented Aug 30, 2025

I was able to load https://huggingface.co/bartowski/nvidia_Nemotron-H-47B-Reasoning-128K-GGUF

EDIT: I have some issues with flash attention, but it looks like they are not related to this specific model

@jwjohns
Copy link
Contributor

jwjohns commented Aug 31, 2025

@jacekpoplawski any specific error? I can fire it up today and see what’s occurring.

@jacekpoplawski
Copy link
Contributor

@jacekpoplawski any specific error? I can fire it up today and see what’s occurring.

I thought there was an issue with Nemotron, but it was introduced by another PR #15434

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Oct 6, 2025
Labels
model Model specific python python script changes

Successfully merging this pull request may close these issues.

Feature Request: Support for NVidia Nemotron Nano v2
7 participants