Merged
20 commits
17fa9d5
feat: Add NEMOTRONH to python arch enum
gabe-l-hart Aug 21, 2025
36c88f7
feat: Add NEMOTRONH to c++ arch enum
gabe-l-hart Aug 21, 2025
62e66c6
feat: Add NEMOTRONH to llama-arch layer map
gabe-l-hart Aug 21, 2025
abe1e89
feat: First pass at conversion for nemotronh
gabe-l-hart Aug 21, 2025
c25c149
feat: Add a verbose log for each tensor loaded
gabe-l-hart Aug 21, 2025
828176e
feat: First (broken) pass at nemotronh model architecture
gabe-l-hart Aug 21, 2025
3191a8d
fix: Explicitly enable add_bos_token during conversion
gabe-l-hart Aug 25, 2025
37c42c9
fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers
gabe-l-hart Aug 26, 2025
9a9de40
fix: Only allocate attention cache for attention layers (not non-recu…
gabe-l-hart Aug 26, 2025
9d4e0d7
fix: Move residual add to after every block
gabe-l-hart Aug 26, 2025
cb03b4f
Merge remote-tracking branch 'origin/master' into gabe-l-hart/nvidia-…
gabe-l-hart Aug 28, 2025
3310f91
fix: Use the correct norm tensor for the MLP blocks
gabe-l-hart Aug 28, 2025
3132915
Merge remote-tracking branch 'origin/master' into gabe-l-hart/nvidia-…
gabe-l-hart Aug 28, 2025
ab53234
Nemotron-H: MLP gate cleanup (pass NULL for unused gate)
DominguesM Aug 28, 2025
b3304da
SSM: respect ssm_dt_rank for dt_dim when provided
DominguesM Aug 28, 2025
4223a1f
fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage)
DominguesM Aug 28, 2025
7503535
Merge pull request #3 from DominguesM/nvidia-nemotron-nano-v2
gabe-l-hart Aug 28, 2025
f2165dd
Rename nemotronh to nemotron_h for consistency
jwjohns Aug 28, 2025
3732916
Merge pull request #4 from jwjohns/nemotron-h-naming-update
gabe-l-hart Aug 28, 2025
19f1dc6
feat: Support conversion for older NemotronH models
gabe-l-hart Aug 28, 2025
63 changes: 58 additions & 5 deletions convert_hf_to_gguf.py
@@ -7546,9 +7546,13 @@ def __init__(self, *args, **kwargs):
]

# n_group and d_inner are used during reshape_tensors for mamba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model
# NOTE: Explicitly include the hparam prefix for d_model to
# disambiguate with the top-level head_dim
# NOTE 2: If needed for future models, this can be isolated in a method
# to separate the prefix setting and the keys used
self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups", "num_groups"])
self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model

def get_attn_layers(self):
# Explicit list of layer type names
@@ -7609,12 +7613,12 @@ def set_gguf_parameters(self):

## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))

## Attention params ##
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@@ -7641,6 +7645,55 @@ def set_vocab(self):
Mamba2Model.set_vocab(self)


@ModelBase.register("NemotronHForCausalLM")
class NemotronHModel(GraniteHybridModel):
"""Hybrid mamba2/attention model from NVIDIA"""
model_arch = gguf.MODEL_ARCH.NEMOTRON_H

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Save the top-level head_dim for later
self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
assert self.head_dim is not None, "Could not find the attention head dim in config"

# Don't use expand to calculate d_inner
self.d_inner = self.find_hparam(["num_heads"]) * self.d_model

# Update the ssm / attn / mlp layers
# M: Mamba2, *: Attention, -: MLP
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]

def get_attn_layers(self):
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]

def set_gguf_parameters(self):
super().set_gguf_parameters()

self.gguf_writer.add_key_length(self.head_dim)
self.gguf_writer.add_value_length(self.head_dim)

# Set feed_forward_length
# NOTE: This will trigger an override warning. This is preferable to
# duplicating all the parent logic
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
self.gguf_writer.add_feed_forward_length([
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
])

def set_vocab(self):
super().set_vocab()

# The tokenizer _does_ add a BOS token (via post_processor type
# TemplateProcessing) but does not set add_bos_token to true in the
# config, so we need to explicitly override it here.
self.gguf_writer.add_add_bos_token(True)


@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
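
The NemotronHModel converter above derives its per-layer wiring from the hybrid_override_pattern string in the HF config: each character selects the block type for one layer. A minimal standalone sketch of that parsing, and of the per-layer feed_forward_length array written by set_gguf_parameters, follows; the pattern string and n_ff value are invented for illustration.

# Sketch of the hybrid_override_pattern handling in NemotronHModel above.
# The pattern and n_ff are made-up example values, not taken from a real config.
hybrid_override_pattern = "MMM*-MMM*-"   # M: Mamba2, *: attention, -: MLP
block_count = len(hybrid_override_pattern)
n_ff = 21504                             # hypothetical intermediate_size

ssm_layers  = [i for i, c in enumerate(hybrid_override_pattern) if c == "M"]
attn_layers = [i for i, c in enumerate(hybrid_override_pattern) if c == "*"]
mlp_layers  = [i for i, c in enumerate(hybrid_override_pattern) if c == "-"]

# feed_forward_length is written per layer: the real size for MLP blocks,
# 0 for Mamba2 and attention blocks (mirroring set_gguf_parameters above).
feed_forward_length = [n_ff if i in mlp_layers else 0 for i in range(block_count)]

print(ssm_layers)           # [0, 1, 2, 5, 6, 7]
print(attn_layers)          # [3, 8]
print(mlp_layers)           # [4, 9]
print(feed_forward_length)  # [0, 0, 0, 0, 21504, 0, 0, 0, 0, 21504]
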
21 changes: 21 additions & 0 deletions gguf-py/gguf/constants.py
@@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
NEMOTRON_H = auto()
EXAONE = auto()
EXAONE4 = auto()
GRANITE = auto()
@@ -700,6 +701,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
@@ -2297,6 +2299,25 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.NEMOTRON_H: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
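
The constants.py additions above register both the architecture name and its tensor whitelist. Both can be read back through gguf-py's existing lookup tables (MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES); a quick sanity-check sketch, assuming the gguf-py from this branch is installed:

# Read back the NEMOTRON_H registration via gguf-py's lookup tables.
import gguf

arch = gguf.MODEL_ARCH.NEMOTRON_H
print(gguf.MODEL_ARCH_NAMES[arch])          # nemotron_h

# Per-architecture tensor whitelist used when writing the GGUF file
for t in gguf.MODEL_TENSORS[arch]:
    print(f"{t.name:12s} -> {gguf.TENSOR_NAMES[t]}")
# e.g. SSM_IN       -> blk.{bid}.ssm_in
#      ATTN_Q       -> blk.{bid}.attn_q
#      FFN_UP       -> blk.{bid}.ffn_up
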
6 changes: 6 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -191,6 +191,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.q_proj", # llama4
"model.transformer.blocks.{bid}.q_proj", # llada
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.q_proj", # nemotron-h
),

# Attention key
@@ -209,6 +210,7 @@
"model.layers.{bid}.self_attn.k_proj", # llama4
"model.transformer.blocks.{bid}.k_proj", # llada
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.k_proj", # nemotron-h
),

# Attention value
@@ -226,6 +228,7 @@
"model.layers.{bid}.self_attn.v_proj", # llama4
"model.transformer.blocks.{bid}.v_proj", # llada
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.v_proj", # nemotron-h
),

# Attention output
@@ -260,6 +263,7 @@
"transformer_encoder.{bid}.wo", # neobert
"model.transformer.blocks.{bid}.attn_out", # llada
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
),

# Attention output norm
@@ -387,6 +391,7 @@
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
"model.transformer.blocks.{bid}.up_proj", # llada
"layers.{bid}.mlp.up_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -480,6 +485,7 @@
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
"model.transformer.blocks.{bid}.ff_out", # llada
"layers.{bid}.mlp.down_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
),

MODEL_TENSOR.FFN_DOWN_EXP: (
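
The tensor_mapping.py entries above let the converter translate Nemotron-H's "backbone.layers.{bid}.mixer.*" attention and MLP tensor names into their GGUF equivalents. A small sketch of that lookup, assuming gguf-py's get_tensor_name_map / TensorNameMap.get_name API and the mappings from this branch; the block count is arbitrary:

# Resolve a Nemotron-H HF tensor name to its GGUF base name.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NEMOTRON_H, 56)  # 56 = arbitrary block count

hf_name = "backbone.layers.3.mixer.q_proj.weight"
print(tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# expected: blk.3.attn_q  (the converter re-appends the matched suffix itself)
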
27 changes: 27 additions & 0 deletions src/llama-arch.cpp
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
@@ -1550,6 +1551,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NEMOTRON_H,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
// mamba(2) ssm layers
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
// attention layers
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
// dense FFN
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE,
{
@@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
case LLM_ARCH_NEMOTRON_H:
return true;
default:
return false;
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -73,6 +73,7 @@ enum llm_arch {
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
1 change: 1 addition & 0 deletions src/llama-model-loader.cpp
@@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
}

struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));

if (cur == NULL) {