diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py
index e49e37a6e67..177bae4ad97 100644
--- a/tensorrt_llm/_torch/models/__init__.py
+++ b/tensorrt_llm/_torch/models/__init__.py
@@ -7,6 +7,7 @@
from .modeling_exaone4 import Exaone4ForCausalLM
from .modeling_gemma3 import Gemma3ForCausalLM
from .modeling_gemma3vl import Gemma3VLM
+from .modeling_glm import Glm4MoeForCausalLM
from .modeling_gpt_oss import GptOssForCausalLM
from .modeling_hunyuan_dense import HunYuanDenseV1ForCausalLM
from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM
@@ -70,6 +71,7 @@
"Qwen3NextForCausalLM",
"GptOssForCausalLM",
"SeedOssForCausalLM",
+ "Glm4MoeForCausalLM",
]
if transformers.__version__ >= "4.45.1":
diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py
new file mode 100644
index 00000000000..be300bcf080
--- /dev/null
+++ b/tensorrt_llm/_torch/models/modeling_glm.py
@@ -0,0 +1,913 @@
+import math
+import os
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from tqdm import tqdm
+from transformers import PretrainedConfig
+
+from tensorrt_llm._ipc_utils import can_access_peer
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
+from tensorrt_llm.quantization.utils.fp8_utils import (
+ resmooth_to_fp8_e8m0,
+ transform_sf_into_required_layout,
+)
+
+from ..attention_backend import AttentionMetadata
+from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..distributed import (
+ AllReduce,
+ AllReduceFusionOp,
+ AllReduceParams,
+ MoEAllReduce,
+ MoEAllReduceParams,
+)
+from ..model_config import ModelConfig
+from ..modules.decoder_layer import DecoderLayer
+from ..modules.embedding import Embedding
+from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
+from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import Linear, TensorParallelMode
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
+from ..modules.qk_norm_attention import QKNormRoPEAttention
+from ..modules.rms_norm import RMSNorm
+from ..speculative import SpecMetadata
+from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
+from .modeling_deepseekv3 import DeepseekV3Gate, DeepseekV3MTPHead, moe_reduce_add_shared_output
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, EagerFusionConfig, _load_weights_impl, register_auto_model
+
+
+class Glm4Attention(QKNormRoPEAttention):
+ def __init__(
+ self,
+ model_config: ModelConfig[PretrainedConfig],
+ layer_idx: Optional[int] = None,
+ ):
+ config = model_config.pretrained_config
+ pos_embd_params = PositionalEmbeddingParams(
+ type=PositionEmbeddingType.yarn,
+ rope=RopeParams.from_config(config),
+ )
+
+ super().__init__(
+ hidden_size=config.hidden_size,
+ num_attention_heads=config.num_attention_heads,
+ num_key_value_heads=config.num_key_value_heads,
+ max_position_embeddings=config.max_position_embeddings,
+ bias=config.attention_bias,
+ pos_embd_params=pos_embd_params,
+ fuse_qk_norm_rope=False,
+ layer_idx=layer_idx,
+ dtype=config.torch_dtype,
+ dense_bias=False,
+ config=model_config,
+ )
+
+
+class Glm4MoE(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_experts: int,
+ top_k: int,
+ hidden_size: int,
+ intermediate_size: int,
+ shared_expert_intermediate_size: int,
+ aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
+ dtype: Optional[torch.dtype] = None,
+ model_config: ModelConfig = ModelConfig(),
+ override_quant_config: Optional[QuantConfig] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ from ..distributed import AllReduce
+
+ super().__init__()
+ config = model_config.pretrained_config
+ self.top_k = top_k
+ self.use_dp = model_config.mapping.enable_attention_dp
+ self.gate = DeepseekV3Gate(
+ hidden_size,
+ num_experts,
+ top_k=top_k,
+ n_group=config.n_group,
+ topk_group=config.topk_group,
+ routed_scaling_factor=config.routed_scaling_factor,
+ dtype=dtype,
+ fuse_routing_kernel=False,
+ apply_routing=False,
+ moe_backend=model_config.moe_backend,
+ )
+ self.experts = create_moe(
+ num_experts=num_experts,
+ routing_method=self.gate.routing_method,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ dtype=dtype,
+            reduce_results=False,  # In both low-latency and attention-DP modes, FusedMoE skips the in-op all-reduce.
+ model_config=model_config,
+ override_quant_config=override_quant_config,
+ aux_stream_dict=aux_stream_dict,
+ layer_idx=layer_idx,
+ weight_loading_mode=MoEWeightLoadingMode.VANILLA,
+ )
+
+ self.mapping = model_config.mapping
+
+ # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
+ block_size = 1
+ if (
+ model_config.quant_config
+ and model_config.quant_config.quant_algo
+ and model_config.quant_config.group_size is not None
+ ):
+ block_size = model_config.quant_config.group_size
+
+ shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size(
+ shared_expert_intermediate_size, block_size
+ )
+
+ self.shared_experts = GatedMLP(
+ hidden_size=hidden_size,
+ intermediate_size=shared_expert_intermediate_size,
+ bias=False,
+ dtype=dtype,
+ config=model_config,
+ overridden_tp_size=shared_tp_size,
+ reduce_output=False,
+ )
+
+ self.allreduce = AllReduce(
+ mapping=model_config.mapping, strategy=model_config.allreduce_strategy
+ )
+ self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+ self.event_dict = {key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeShared]}
+
+ def _compute_shared_expert_tp_size(
+ self, intermediate_size: int, block_size: int
+ ) -> tuple[int, float | None]:
+ """
+        For GLM4, the TP size of the shared-expert MLP is capped by intermediate_size // block_size.
+        For example, when intermediate_size is 2048 and the block scale size is 128,
+        TP sizes are limited to {1, 2, 4, 8, 16} because 2048 / 128 = 16.
+
+ Args:
+ intermediate_size (int): MLP intermediate size.
+ block_size (int): The quantization block scale size. For NVFP4, it's 16.
+
+ Returns:
+ tuple[int, float | None]: A tuple containing (shared_tp_size, shared_output_scale).
+ - shared_tp_size: The computed TP size.
+ - shared_output_scale: The output scale factor, or None if not needed.
+ """
+
+ assert intermediate_size % block_size == 0, (
+ "intermediate_size must be divisible by block_size."
+ )
+
+ shared_output_scale = None
+ # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128.
+ if self.use_dp:
+ # If using attention DP, the shared experts also use DP instead of TP.
+ shared_tp_size = 1
+ else:
+ # Due to the restriction of block scale size (i.e., 128),
+ # the supported TP sizes only include 1, 2, 4, 8, and 16.
+ # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes.
+ shared_tp_size = math.gcd(
+ intermediate_size // block_size,
+ self.mapping.tp_size,
+ )
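+            # For instance (hypothetical values): intermediate_size=2048, block_size=128,
+            # tp_size=32 -> gcd(16, 32) = 16, so the shared experts run with TP=16 and their
+            # output is scaled by 16 / 32 = 0.5 before the final all-reduce.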
+ # If shared_tp_size has been overridden, the output of shared experts needs to be
+ # scaled down accordingly before all-reduce.
+ if shared_tp_size != self.mapping.tp_size:
+ shared_output_scale = shared_tp_size / self.mapping.tp_size
+
+ return shared_tp_size, shared_output_scale
+
+ @staticmethod
+ def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
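+        # Mixed-quantization checkpoints may provide a per-module quant_config_dict
+        # (keyed by names like "model.layers.<idx>.mlp.experts"); fall back to the global
+        # quant_config when no per-layer entry exists.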
+ if getattr(model_config, "quant_config_dict", None) is None:
+ return model_config.quant_config
+ return model_config.quant_config_dict.get(
+ f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config
+ )
+
+ def compute_routed_output(
+ self, hidden_states, hidden_states_fp4, all_rank_num_tokens, do_finalize
+ ):
+ # max-throughput
+ use_dp_padding = False
+ # Add DP padding on SM120 for context comm performance
+        # TODO: Move this model-agnostic part to MoE
+ if self.use_dp and self.mapping.tp_size > 1 and get_sm_version() == 120:
+ use_dp_padding = True
+ hidden_states = torch.nn.functional.pad(
+ hidden_states, (0, 0, 0, max(all_rank_num_tokens) - hidden_states.shape[0])
+ )
+
+ router_logits = self.gate(hidden_states)
+
+ routed_output = self.experts(
+ hidden_states_fp4 if hidden_states_fp4 is not None else hidden_states,
+ router_logits,
+ do_finalize=do_finalize,
+ output_dtype=hidden_states.dtype,
+ all_rank_num_tokens=all_rank_num_tokens,
+ use_dp_padding=use_dp_padding,
+ )
+
+ return routed_output
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
+ all_rank_num_tokens: Optional[list[int]] = None,
+ final_all_reduce_params: Optional[AllReduceParams] = None,
+ do_finalize: Optional[bool] = True,
+ ) -> torch.Tensor:
+ if not do_finalize:
+ assert not self.use_dp
+
+ def _compute_shared_output():
+ shared_output = self.shared_experts(
+ hidden_states_fp4 if hidden_states_fp4 is not None else hidden_states
+ )
+ if self.shared_output_scale is not None:
+ shared_output *= self.shared_output_scale
+ return shared_output
+
+ def _compute_routed_output():
+ routed_output = self.compute_routed_output(
+ hidden_states, hidden_states_fp4, all_rank_num_tokens, do_finalize
+ )
+ return routed_output
+
+ # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames
+
+ routed_output, shared_output = maybe_execute_in_parallel(
+ _compute_routed_output,
+ _compute_shared_output,
+ self.event_dict[EventType.Main],
+ self.event_dict[EventType.MoeShared],
+ self.aux_stream,
+ )
+
+ if not do_finalize:
+ return [shared_output, *routed_output]
+ else:
+ if routed_output.dim() == 3:
+ assert shared_output.numel() * self.top_k == routed_output.numel(), (
+ "unmatched tensor shape"
+ )
+ final_hidden_states = moe_reduce_add_shared_output(routed_output, shared_output)
+ else:
+ assert shared_output.size() == routed_output.size(), "unmatched tensor shape"
+ final_hidden_states = shared_output + routed_output
+
+ if not self.use_dp and self.mapping.tp_size > 1:
+ final_hidden_states = self.allreduce(
+ final_hidden_states, all_reduce_params=final_all_reduce_params
+ )
+
+ return final_hidden_states
+
+
+class Glm4DecoderLayer(DecoderLayer):
+ def __init__(
+ self,
+ model_config: ModelConfig[PretrainedConfig],
+ layer_idx: int,
+ aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
+ is_separate_draft_engine: bool = False,
+ ):
+ super().__init__()
+ self.model_config = model_config
+ self.config = model_config.pretrained_config
+ config = self.config
+
+ self.hidden_size = config.hidden_size
+ self.moe_intermediate_size = config.moe_intermediate_size
+ self.num_experts = config.n_routed_experts
+ self.num_shared_experts = config.n_shared_experts
+ self.top_k = config.num_experts_per_tok
+
+ self.mapping = model_config.mapping
+ mapping = self.mapping
+ layer_idx_for_attention = layer_idx
+ if is_separate_draft_engine:
+            # KVCacheManager only supports 1 layer for the separate draft engine
+ layer_idx_for_attention = layer_idx - model_config.pretrained_config.num_hidden_layers
+
+ self.self_attn = Glm4Attention(
+ model_config,
+ layer_idx=layer_idx_for_attention,
+ )
+ self.enable_attention_dp = mapping.enable_attention_dp
+
+ self.mlp_tp_size = mapping.tp_size
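+        # Peer-to-peer access across local GPUs is required for the fused MoE all-reduce
+        # (MoEAllReduce) path taken when do_finalize is False in forward_MoE.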
+ self.is_p2p_supported = can_access_peer(mapping)
+
+ self.fusion_config = EagerFusionConfig()
+ self.enable_fusion = os.environ.get("TRTLLM_GLM_EAGER_FUSION_DISABLED", "0") == "0"
+ self.enable_fusion &= not self.enable_attention_dp
+
+ # FIXME: incompatible with mixed quantization mode
+ quant_config = self._get_decoder_layer_quant_config(model_config, layer_idx)
+ self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
+ assert quant_config.quant_algo is not QuantAlgo.MIXED_PRECISION, (
+ "MIXED_PRECISION is ambiguous"
+ )
+
+ has_tp = mapping.has_tp()
+ self.allreduce = AllReduce(
+ mapping=model_config.mapping,
+ strategy=model_config.allreduce_strategy,
+ dtype=config.torch_dtype,
+ )
+ self.moe_allreduce = MoEAllReduce(self.mapping)
+
+ if config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace:
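+            # With eager fusion enabled, the residual-add + RMSNorm around the MoE is fused
+            # into the tensor-parallel all-reduce (see forward_MoE).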
+ self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+ self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION
+
+ self.mlp = Glm4MoE(
+ num_experts=self.num_experts,
+ top_k=self.top_k,
+ hidden_size=self.hidden_size,
+ intermediate_size=self.moe_intermediate_size,
+ shared_expert_intermediate_size=self.moe_intermediate_size
+ * self.num_shared_experts,
+ dtype=config.torch_dtype,
+ model_config=model_config,
+ override_quant_config=quant_config,
+ aux_stream_dict=aux_stream_dict,
+ layer_idx=layer_idx,
+ )
+ else:
+ block_size = 1
+ if quant_config and quant_config.quant_algo and quant_config.group_size is not None:
+ block_size = quant_config.group_size
+ self.mlp_tp_size = self._compute_mlp_tp_size(config.intermediate_size, block_size)
+
+ has_mlp_tp = self.mlp_tp_size > 1
+ self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4
+ self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp
+
+ self.mlp = GatedMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ bias=False,
+ dtype=config.torch_dtype,
+ config=model_config,
+ overridden_tp_size=self.mlp_tp_size,
+ reduce_output=True,
+ )
+
+ self.input_layernorm = RMSNorm(
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+ )
+
+ self.disable_attn_allreduce = (
+ self.fusion_config.PRE_MOE_FUSION
+ or self.fusion_config.PRE_MLP_FUSION
+ or self.mapping.tp_size == 1
+ or self.enable_attention_dp
+ )
+
+ self.post_attention_layernorm = RMSNorm(
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+ )
+ self.layer_idx = layer_idx
+ self.next_layer_layernorm: RMSNorm = None
+
+ def _get_decoder_layer_quant_config(
+ self, model_config: ModelConfig[PretrainedConfig], layer_idx: int
+ ):
+ """
+ The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM
+ moe_backend only supports fp8/fp4 quantization, we need to override
+ the quant_config for the MTP layer.
+ """
+ quant_config = model_config.quant_config
+
+ layer_name = f"model.layers.{layer_idx}"
+ if quant_config.is_module_excluded_from_quantization(layer_name):
+ return QuantConfig(
+ quant_algo=None,
+ kv_cache_quant_algo=quant_config.kv_cache_quant_algo,
+ )
+ else:
+ return model_config.quant_config
+
+ def _compute_mlp_tp_size(self, intermediate_size: int, block_size: int) -> int:
+ """
+        For GLM4, the MLP TP size is limited by intermediate_size // block_size and,
+        when it would exceed gpus_per_node, is reduced via gcd to stay within a single
+        node and avoid costly inter-node allreduce.
+
+ Args:
+ intermediate_size (int): MLP intermediate size.
+ block_size (int): The quantization block scale size. For NVFP4, it's 16.
+
+ Returns:
+ int: The computed tp_size.
+ """
+
+ assert intermediate_size % block_size == 0, (
+ "intermediate_size must be divisible by block_size."
+ )
+ if self.enable_attention_dp:
+ # If using attention DP, the MLP also uses DP instead of TP.
+ mlp_tp_size = 1
+ else:
+ # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes.
+ tp = math.gcd(
+ intermediate_size // block_size,
+ self.mapping.tp_size,
+ )
+
+ if tp > self.mapping.gpus_per_node:
+ mlp_tp_size = math.gcd(
+ tp,
+ self.mapping.gpus_per_node,
+ ) # Avoid costly inter-node TP
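+                # For instance (hypothetical values): tp=16 spread across two 8-GPU nodes
+                # gives gcd(16, 8) = 8, keeping the MLP tensor-parallel group within one node.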
+ else:
+ mlp_tp_size = tp
+ return mlp_tp_size
+
+ def forward(
+ self,
+ position_ids: torch.IntTensor,
+ hidden_states: torch.Tensor,
+ attn_metadata: AttentionMetadata,
+ residual: torch.Tensor,
+ spec_metadata: Optional[SpecMetadata] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states = self.self_attn(
+ position_ids=position_ids,
+ hidden_states=hidden_states,
+ attn_metadata=attn_metadata,
+ all_reduce_params=AllReduceParams(enable_allreduce=not (self.disable_attn_allreduce)),
+ **kwargs,
+ )
+ if isinstance(self.mlp, Glm4MoE):
+ if spec_metadata is not None and spec_metadata.is_layer_capture(self.layer_idx):
+ self.fusion_config.POST_MOE_FUSION = False
+ return self.forward_MoE(
+ hidden_states=hidden_states,
+ attn_metadata=attn_metadata,
+ residual=residual,
+ spec_metadata=spec_metadata,
+ )
+ else:
+ if spec_metadata is not None and spec_metadata.is_layer_capture(self.layer_idx):
+ self.fusion_config.POST_MLP_FUSION = False
+ assert isinstance(self.mlp, GatedMLP)
+ return self.forward_mlp(
+ hidden_states=hidden_states,
+ residual=residual,
+ spec_metadata=spec_metadata,
+ )
+
+ def forward_MoE(
+ self,
+ hidden_states: torch.Tensor,
+ attn_metadata: AttentionMetadata,
+ residual: torch.Tensor,
+ spec_metadata: Optional[SpecMetadata] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
+ return self.mlp(
+ hidden_states,
+ hidden_states_fp4,
+ all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+ final_all_reduce_params=AllReduceParams(
+ enable_allreduce=not (
+ self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
+ )
+ ),
+ do_finalize=do_finalize,
+ )
+
+ if self.fusion_config.PRE_MOE_FUSION:
+ # moe_backend can be either CUTLASS or TRTLLM here
+ # TODO: unify the two min-latency MoE backends by enabling quant fusion
+ hidden_states, residual = self.allreduce(
+ hidden_states,
+ all_reduce_params=AllReduceParams(
+ fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+ residual=residual,
+ norm_weight=self.post_attention_layernorm.weight,
+ eps=self.post_attention_layernorm.variance_epsilon,
+ trigger_completion_at_end=False,
+ ),
+ )
+ else:
+ # No fusion
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+        # Note: this fusion pattern is currently only supported with the single-node TRTLLM nvfp4 backend
+ do_finalize = self.mapping.is_multi_node() or (
+ not (
+ hidden_states.shape[0] <= self.moe_allreduce.max_token
+ and self.fusion_config.POST_MOE_FUSION
+ and self.model_config.moe_backend == "TRTLLM"
+ and self.mlp.experts.has_nvfp4
+ and self.is_p2p_supported
+ )
+ )
+
+ hidden_states = _run_MoE(hidden_states, hidden_states_fp4=None, do_finalize=do_finalize)
+
+ if self.fusion_config.POST_MOE_FUSION:
+ if do_finalize:
+ hidden_states, residual = self.allreduce(
+ hidden_states,
+ all_reduce_params=AllReduceParams(
+ fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+ residual=residual,
+ norm_weight=self.next_layer_layernorm.weight,
+ eps=self.next_layer_layernorm.variance_epsilon,
+ trigger_completion_at_end=False,
+ ),
+ )
+ else:
+ assert len(hidden_states) == 4, "hidden_states must have 4 elements"
+
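+                # do_finalize=False: the MoE backend returned un-finalized expert outputs;
+                # the fused MoE all-reduce below combines them with the shared-expert output
+                # and applies the next layer's residual-add + RMSNorm.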
+ shared_output = hidden_states[0]
+ fc2_output = hidden_states[1]
+ expert_scale_factor = hidden_states[2]
+ expanded_idx_to_permuted_idx = hidden_states[3]
+
+ moe_all_reduce_params = MoEAllReduceParams(
+ expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+ expert_scale_factor=expert_scale_factor,
+ shared_expert_output=shared_output,
+ residual=residual,
+ norm_weight=self.next_layer_layernorm.weight,
+ eps=self.next_layer_layernorm.variance_epsilon,
+ is_cutlass_min_latency=False,
+ )
+ hidden_states, residual = self.moe_allreduce(
+ fc2_output, all_reduce_params=moe_all_reduce_params
+ )
+ else:
+ if spec_metadata is not None and spec_metadata.is_layer_capture(self.layer_idx):
+ spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)
+ if self.next_layer_layernorm is not None:
+ hidden_states, residual = self.next_layer_layernorm(hidden_states, residual)
+
+ return hidden_states, residual
+
+ def forward_mlp(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ spec_metadata: Optional[SpecMetadata] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.fusion_config.PRE_MLP_FUSION:
+ act_fp4, act_sf, residual = self.allreduce(
+ hidden_states,
+ all_reduce_params=AllReduceParams(
+ fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
+ residual=residual,
+ norm_weight=self.post_attention_layernorm.weight,
+ scale=self.mlp.gate_up_proj.input_scale,
+ eps=self.post_attention_layernorm.variance_epsilon,
+ ),
+ )
+ hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+ else:
+ # No fusion
+            # We need the two-shot allreduce here to avoid modifying the attention logic
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+ hidden_states = self.mlp(
+ hidden_states,
+ final_all_reduce_params=AllReduceParams(
+ enable_allreduce=not (self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)
+ ),
+ )
+
+ if self.fusion_config.POST_MLP_FUSION:
+ hidden_states, residual = self.allreduce(
+ hidden_states,
+ all_reduce_params=AllReduceParams(
+ fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+ residual=residual,
+ norm_weight=self.next_layer_layernorm.weight,
+ eps=self.next_layer_layernorm.variance_epsilon,
+ ),
+ )
+ else:
+ if spec_metadata is not None and spec_metadata.is_layer_capture(self.layer_idx):
+ spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)
+ if self.next_layer_layernorm is not None:
+ hidden_states, residual = self.next_layer_layernorm(hidden_states, residual)
+
+ return hidden_states, residual
+
+
+class Glm4MTP(Glm4DecoderLayer):
+ def __init__(
+ self,
+ model_config: ModelConfig[PretrainedConfig],
+ layer_idx: int,
+ aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
+ is_separate_draft_engine: bool = False,
+ ):
+ super().__init__(model_config, layer_idx, aux_stream_dict, is_separate_draft_engine)
+ config = model_config.pretrained_config
+ self.hidden_dim = config.hidden_size
+ self.moe_intermediate_size = config.moe_intermediate_size
+ self.num_experts = config.n_routed_experts
+ self.num_shared_experts = config.n_shared_experts
+ self.top_k = config.num_experts_per_tok
+
+ self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+ self.event_dict = {key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeShared]}
+
+ self.enorm = RMSNorm(
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+ )
+
+ self.hnorm = RMSNorm(
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+ )
+ if model_config.mapping.enable_attention_dp:
+ self.eh_proj = Linear(
+ config.hidden_size * 2,
+ config.hidden_size,
+ bias=False,
+ dtype=config.torch_dtype,
+ skip_create_weights_in_init=model_config.skip_create_weights_in_init,
+ )
+ else:
+ self.eh_proj = Linear(
+ config.hidden_size * 2,
+ config.hidden_size,
+ bias=False,
+ dtype=config.torch_dtype,
+ tensor_parallel_mode=TensorParallelMode.ROW,
+ mapping=model_config.mapping,
+ reduce_output=True,
+ skip_create_weights_in_init=model_config.skip_create_weights_in_init,
+ )
+
+ self.shared_head = DeepseekV3MTPHead(model_config)
+
+ def forward(
+ self,
+ input_ids: torch.IntTensor,
+ position_ids: torch.IntTensor,
+ hidden_states: torch.Tensor,
+ embed_tokens: Embedding,
+ attn_metadata: AttentionMetadata,
+ all_rank_num_tokens: Optional[List[int]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  # embedding
+
+ def norm_hidden():
+ return self.hnorm(hidden_states)
+
+ inputs_embeds, hidden_states = maybe_execute_in_parallel(
+ norm_embeds,
+ norm_hidden,
+ self.event_dict[EventType.Main],
+ self.event_dict[EventType.MoeShared],
+ self.aux_stream,
+ )
+ hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+ # Split hidden_states columnwise based on TP
+ tp_size = self.model_config.mapping.tp_size
+ tp_rank = self.model_config.mapping.tp_rank
+
+ if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
+ hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
+ hidden_states = self.eh_proj(hidden_states)
+
+ # Input layer norm
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states = self.self_attn(
+ position_ids=position_ids,
+ hidden_states=hidden_states,
+ attn_metadata=attn_metadata,
+ all_reduce_params=AllReduceParams(enable_allreduce=not (self.disable_attn_allreduce)),
+ **kwargs,
+ )
+
+        # The MTP layer always uses the sparse MoE
+ if self.fusion_config.PRE_MOE_FUSION:
+ hidden_states, residual = self.allreduce(
+ hidden_states,
+ all_reduce_params=AllReduceParams(
+ fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+ residual=residual,
+ norm_weight=self.post_attention_layernorm.weight,
+ eps=self.post_attention_layernorm.variance_epsilon,
+ ),
+ )
+ else:
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+ # MoE
+ hidden_states = self.mlp(
+ hidden_states,
+ all_rank_num_tokens=all_rank_num_tokens,
+ final_all_reduce_params=AllReduceParams(
+ enable_allreduce=not (
+ self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
+ )
+ ),
+ )
+
+ if self.fusion_config.POST_MOE_FUSION:
+ hidden_states, residual = self.allreduce(
+ hidden_states,
+ all_reduce_params=AllReduceParams(
+ fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+ residual=residual,
+ norm_weight=self.shared_head.norm.weight,
+ eps=self.shared_head.norm.variance_epsilon,
+ ),
+ )
+ else:
+ hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class Glm4Model(DecoderModel):
+ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+ super().__init__(model_config)
+ config = model_config.pretrained_config
+ self.vocab_size = config.vocab_size
+ self.num_hidden_layers = config.num_hidden_layers
+ aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
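+        # Attention and shared-expert overlap share one auxiliary stream; MoE chunking
+        # overlap and the MoE load balancer each get their own.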
+ self.aux_stream_dict = {
+ AuxStreamType.Attention: aux_stream_list[0],
+ AuxStreamType.MoeShared: aux_stream_list[0],
+ AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
+ AuxStreamType.MoeBalancer: aux_stream_list[2],
+ }
+
+ self.embed_tokens = Embedding(
+ config.vocab_size,
+ config.hidden_size,
+ dtype=config.torch_dtype,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Glm4DecoderLayer(model_config, layer_idx, self.aux_stream_dict)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = RMSNorm(
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+ )
+
+ def forward(
+ self,
+ attn_metadata: AttentionMetadata,
+ input_ids: Optional[torch.IntTensor] = None,
+ position_ids: Optional[torch.IntTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ spec_metadata: Optional[SpecMetadata] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+ residual = None
+
+ for decoder_layer in self.layers[: self.num_hidden_layers]:
+ hidden_states, residual = decoder_layer(
+ position_ids=position_ids,
+ hidden_states=hidden_states,
+ attn_metadata=attn_metadata,
+ residual=residual,
+ spec_metadata=spec_metadata,
+ )
+
+ return hidden_states
+
+
+@register_auto_model("Glm4MoeForCausalLM")
+class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig]):
+ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+ super().__init__(model=Glm4Model(model_config), model_config=model_config)
+
+ self.model_nextn = 0
+ if (
+ model_config.spec_config is not None
+ and model_config.spec_config.spec_dec_mode.is_mtp_one_model()
+ ):
+ model_nextn = model_config.spec_config.num_nextn_predict_layers
+ ckpt_nextn = self.config.num_nextn_predict_layers
+ self.num_hidden_layers = self.config.num_hidden_layers
+            assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
+ if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
+ pass
+ else:
+                # Extend the QuantConfig exclude list so it also covers the duplicated MTP layers
+ if model_config.quant_config.exclude_modules is not None:
+ extend_exclude_modules = []
+ for model_mtp_idx in range(
+ self.num_hidden_layers, self.num_hidden_layers + model_nextn
+ ):
+ ckpt_mtp_idx = (
+ model_mtp_idx - self.num_hidden_layers
+ ) % ckpt_nextn + self.num_hidden_layers
+ model_prefix = f"model.layers.{model_mtp_idx}"
+ ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
+ for exclude_module in model_config.quant_config.exclude_modules:
+ if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
+ extend_exclude_modules.append(
+ exclude_module.replace(ckpt_prefix, model_prefix)
+ )
+ self.model_config.quant_config.exclude_modules.extend(extend_exclude_modules)
+ self.model.layers.extend(self.draft_model.mtp_layers)
+ self.epilogue.extend(self.draft_model.mtp_layers)
+ self.epilogue.append(self.spec_worker)
+
+ def forward(
+ self,
+ attn_metadata: AttentionMetadata,
+ input_ids: torch.IntTensor = None,
+ position_ids: Optional[torch.IntTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ spec_metadata: Optional[SpecMetadata] = None,
+ return_context_logits: bool = False,
+ **kwargs,
+ ) -> torch.Tensor:
+ return super().forward(
+ attn_metadata=attn_metadata,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ spec_metadata=spec_metadata,
+ return_context_logits=return_context_logits,
+ **kwargs,
+ )
+
+ def load_weights(self, weights: Dict):
+        # Routed-expert weights use HF names like
+        # "model.layers.91.mlp.experts.75.up_proj.weight_scale_2"; remap up/gate/down_proj
+        # to the fused-MoE w3/w1/w2 naming while leaving shared_experts untouched.
+ _load_weights_impl(
+ self,
+ weights,
+ params_map={
+ r"(?!.*shared_experts)(?=.*experts?)(.*?)up_proj(.*)": r"\1w3\2",
+ r"(?!.*shared_experts)(?=.*experts?)(.*?)down_proj(.*)": r"\1w2\2",
+ r"(?!.*shared_experts)(?=.*experts?)(.*?)gate_proj(.*)": r"\1w1\2",
+ },
+ )
+
+ def post_load_weights(self):
+ all_named_modules = dict(self.model.named_modules())
+ for name, module in tqdm(all_named_modules.items(), desc="Post loading weights"):
+ if len(module._parameters) <= 0 or name.startswith("draft_model"):
+ continue
+ else:
+ if (
+ self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales()
+ and is_sm_100f()
+ and hasattr(module, "weight_scale")
+ ):
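+                    # On SM100-family GPUs, re-smooth FP8 block-scale weights to e8m0 scales
+                    # and rearrange the scales into the layout expected by the
+                    # (1, 128, 128) block-scale GEMM kernels.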
+ weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, module.weight_scale)
+                    transformed_scale = transform_sf_into_required_layout(
+ weight_scale,
+ mn=weight.shape[0],
+ k=weight.shape[1],
+ recipe=(1, 128, 128),
+ is_sfa=False,
+ )
+ module.weight = nn.Parameter(weight, requires_grad=False)
+                    module.weight_scale = nn.Parameter(transformed_scale, requires_grad=False)
+
+ for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
+ if idx == self.config.num_hidden_layers - 1:
+ layer.next_layer_layernorm = self.model.norm
+ else:
+ layer.next_layer_layernorm = self.model.layers[idx + 1].input_layernorm
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index 31d52791f6b..1aece8ab21f 100755
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -351,7 +351,18 @@ def __init__(
):
super().__init__()
# Import here to avoid circular import
- from .modeling_deepseekv3 import DeepseekV3MTP
+ model_type = model_config.pretrained_config.model_type
+ mtp_layer = None
+ match model_type:
+ case "glm4_moe":
+ from .modeling_glm import Glm4MTP
+ mtp_layer = Glm4MTP
+ case "deepseek_v3" | "deepseek_v32":
+ from .modeling_deepseekv3 import DeepseekV3MTP
+ mtp_layer = DeepseekV3MTP
+ case _:
+ raise ValueError(
+ f"Model type {model_type} not supported for MTP")
spec_dec_mode = model_config.spec_config.spec_dec_mode
assert spec_dec_mode.is_mtp_one_model()
@@ -362,8 +373,8 @@ def __init__(
model_config.spec_config.num_nextn_predict_layers // mtp_num_layers)
self.mtp_layers = nn.ModuleList([
- DeepseekV3MTP(model_config, layer_idx + start_layer_idx,
- model.aux_stream_dict)
+ mtp_layer(model_config, layer_idx + start_layer_idx,
+ model.aux_stream_dict)
for layer_idx in range(mtp_num_layers)
])
self.lm_head = lm_head
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index d4314e48bff..b9b1310cbe9 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -837,7 +837,7 @@ def sample_and_accept_draft_tokens(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
ctx_input_ids = input_ids[:attn_metadata.num_ctx_tokens]
ctx_is_think = (ctx_input_ids ==
- self.spec_config.BEGIN_THINKING_PHASE_TOKEN).int()
+ self.spec_config.begin_thinking_phase_token).int()
ctx_is_think_cumsum = torch.cumsum(ctx_is_think, dim=0)
ctx_last_cumsum = ctx_is_think_cumsum[
last_tokens_idx[:num_contexts]]
@@ -863,8 +863,8 @@ def sample_and_accept_draft_tokens(
mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens,
mtp_num_modules, batch_size, num_contexts,
self.spec_config.relaxed_topk, self.spec_config.relaxed_delta,
- self.spec_config.BEGIN_THINKING_PHASE_TOKEN,
- self.spec_config.END_THINKING_PHASE_TOKEN)
+ self.spec_config.begin_thinking_phase_token,
+ self.spec_config.end_thinking_phase_token)
# Strict acceptance
else:
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index ddd3f5b6164..cc7ed8c6b22 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -817,12 +817,11 @@ class MTPDecodingConfig(DecodingBaseConfig):
# Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
num_nextn_predict_layers_from_model_config: int = 1
- # TODO: Hard code for DeepSeek R1
    # When encountering <think>, start the thinking phase.
    # When encountering </think>, end the thinking phase.
    # [<think>thinking phase</think>] [real output]
- BEGIN_THINKING_PHASE_TOKEN: int = 128798
- END_THINKING_PHASE_TOKEN: int = 128799
+ begin_thinking_phase_token: int = 128798
+ end_thinking_phase_token: int = 128799
def __init__(self, **kwargs):
super().__init__(**kwargs)
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index 3b4047e476a..70d5d0643cd 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -260,3 +260,7 @@ LGAI-EXAONE/EXAONE-4.0-32B:
- accuracy: 88.36
ByteDance-Seed/Seed-OSS-36B-Instruct:
- accuracy: 90.8
+zai-org/GLM-4.6:
+  - accuracy: 81.3
+  - quant_algo: NVFP4
+    accuracy: 91.0
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index fe95c4cc093..7653293972b 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2501,6 +2501,77 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
task.evaluate(llm)
+@skip_pre_blackwell
+class TestGLM4_6(LlmapiAccuracyTestHarness):
+ MODEL_NAME = "zai-org/GLM-4.6"
+ MODEL_PATH = f"{llm_models_root()}/GLM-4.6"
+
+ @pytest.mark.timeout(14400)
+ @pytest.mark.skip_less_device_memory(80000)
+ @pytest.mark.skip_less_device(4)
+ @parametrize_with_ids("mtp_nextn", [0, 2])
+ @parametrize_with_ids("overlap_scheduler", [False, True])
+ @parametrize_with_ids("tp_size, ep_size", [(4, 4), (4, 1)])
+ @parametrize_with_ids("max_batch_size, moe_backend", [(4, "CUTLASS")])
+ def test_bfloat16_4gpus(self, tp_size, ep_size, mtp_nextn,
+ overlap_scheduler, max_batch_size, moe_backend):
+ pytorch_config = dict(
+ disable_overlap_scheduler=not overlap_scheduler,
+ moe_config=MoeConfig(backend=moe_backend),
+ )
+ kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.50)
+
+ mtp_config = None
+ if mtp_nextn > 0:
+ mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
+ with LLM(self.MODEL_PATH,
+ max_batch_size=max_batch_size,
+ tensor_parallel_size=tp_size,
+ moe_expert_parallel_size=ep_size,
+ kv_cache_config=kv_cache_config,
+ enable_chunked_prefill=True,
+ max_num_tokens=512,
+ **pytorch_config,
+ speculative_config=mtp_config) as llm:
+ task = GSM8K(self.MODEL_NAME)
+ task.evaluate(llm)
+
+ @pytest.mark.skip_less_device(4)
+ @pytest.mark.parametrize(
+ "tp_size,pp_size,mtp_nextn,fp8kv,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
+ [pytest.param(4, 1, 2, True, True, True, True, 16, "CUTLASS")],
+ ids=["throughput"])
+ def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, fp8kv,
+ cuda_graph, overlap_scheduler, chunked_prefill,
+ max_batch_size, moe_backend):
+
+ kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+ pytorch_config = dict(
+ disable_overlap_scheduler=not overlap_scheduler,
+ cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+ moe_config=MoeConfig(backend=moe_backend))
+
+ if fp8kv:
+ kv_cache_config.dtype = "fp8"
+
+ mtp_config = None
+ if mtp_nextn > 0:
+ mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+ with LLM(f"{llm_models_root()}/GLM-4.6/GLM-4.6-FP4",
+ max_batch_size=max_batch_size,
+ tensor_parallel_size=tp_size,
+ pipeline_parallel_size=pp_size,
+ kv_cache_config=kv_cache_config,
+ **pytorch_config,
+ speculative_config=mtp_config,
+ enable_chunked_prefill=chunked_prefill) as llm:
+
+ assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+ task = GSM8K(self.MODEL_NAME)
+ task.evaluate(llm)
+
+
@pytest.mark.timeout(7200)
@pytest.mark.skip_less_device_memory(100000)
class TestKimiK2(LlmapiAccuracyTestHarness):
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index d7a8f0d82f6..bdb8f85b891 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -501,6 +501,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baselin
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True]