22 changes: 13 additions & 9 deletions cpp/tensorrt_llm/thop/moeOp.cpp
@@ -435,7 +435,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);

auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
auto const quant_params
= getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
kernels::MoeMinLatencyParams min_latency_params{};

// TODO: support lora in the future
@@ -613,7 +614,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);

auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
auto const quant_params
= getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);

// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
@@ -859,8 +861,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
}

kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
int64_t const inter_size, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales) const
int64_t const inter_size, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
ActivationType base_activation_type) const
{
int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
if (isFp8Quant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for fp8 quantization");
@@ -925,12 +929,12 @@ class FusedMoeRunner : public torch::CustomClassHolder
&& fc1_weight_block.sizes()[1]
== TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
* 2
* expand_ratio
&& fc1_weight_block.sizes()[2] * FP8_PER_INT32
* TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
== TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
"fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
"fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
"block_scale_vector_size)");
TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
@@ -978,12 +982,12 @@ class FusedMoeRunner : public torch::CustomClassHolder
&& fc1_weight_block.sizes()[1]
== TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
* 2
* expand_ratio
&& fc1_weight_block.sizes()[2] * FP8_PER_INT32
* TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
== TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
"fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
"fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
"block_scale_vector_size)");
TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank
@@ -1044,12 +1048,12 @@ class FusedMoeRunner : public torch::CustomClassHolder
&& fc1_weight_block.sizes()[1]
== TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)
* 2
* expand_ratio
&& fc1_weight_block.sizes()[2] * FP8_PER_INT32
* TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
== TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4),
"fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
"fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
"block_scale_vector_size)");
TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
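The `expand_ratio` introduced above generalizes the fc1 block-scale shape checks: gated activations (e.g. SwiGLU/GeGLU) fuse the gate and up projections, so fc1 produces `2 * inter_size` rows, while a plain activation produces only `inter_size`. Below is a minimal Python sketch of the shape relation the `TORCH_CHECK`s enforce; the alignment and block-size defaults are illustrative assumptions standing in for the `TmaWarpSpecializedGroupedGemmInput` constants, not the library's actual values.

```python
def align_up(x: int, alignment: int) -> int:
    """Round x up to the next multiple of alignment."""
    return (x + alignment - 1) // alignment * alignment


def expected_fc1_block_shape(num_experts_on_rank: int,
                             hidden_size: int,
                             inter_size: int,
                             gated_activation: bool,
                             n_alignment: int = 128,            # assumed stand-in for MinNDimAlignment*
                             k_alignment: int = 128,            # assumed stand-in for MinKDimAlignment*
                             fp8_per_int32: int = 4,            # FP8_PER_INT32
                             block_scale_vector_size: int = 32  # assumed block-scale vector size
                             ) -> tuple:
    """Expected (experts, N, K) shape of the packed fc1 block-scale tensor."""
    # Gated activations fuse gate and up projections, doubling the fc1 output rows.
    expand_ratio = 2 if gated_activation else 1
    n = align_up(inter_size, n_alignment) * expand_ratio
    k = align_up(hidden_size, k_alignment) // fp8_per_int32 // block_scale_vector_size
    return (num_experts_on_rank, n, k)
```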
@@ -2,8 +2,8 @@

from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \
HfWeightMapper
from tensorrt_llm._torch.models.modeling_nemotron_h import split
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm._torch.utils import split


@register_mapper("HF", "NemotronHForCausalLM")
@@ -34,7 +34,8 @@ def preprocess_weights(self, weights: dict) -> dict:
if "A_log" in key:
key = key.replace("A_log", "A")

if "_scale" in key:
if ("mixer.in_proj" in key
or "mixer.out_proj" in key) and "_scale" in key:
new_weights[key] = weights[name]
elif "A" in key:
w = split(weights[name], tp_size, tp_rank)
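The tightened condition above only passes through quantization scales that belong to the Mamba mixer's in_proj/out_proj; expert scales are handled in the new MoE branch further down. A small sketch of the filter, using hypothetical checkpoint key names purely for illustration:

```python
def is_mixer_proj_scale(key: str) -> bool:
    """Only in_proj/out_proj quantization scales are copied through unchanged;
    MoE expert scales are remapped in the mixer.experts branch instead."""
    return ("mixer.in_proj" in key or "mixer.out_proj" in key) and "_scale" in key


# Hypothetical key names, for illustration only.
assert is_mixer_proj_scale("backbone.layers.0.mixer.in_proj.weight_scale")
assert not is_mixer_proj_scale("backbone.layers.0.mixer.experts.0.up_proj.weight_scale")
```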
@@ -94,6 +95,39 @@ def preprocess_weights(self, weights: dict) -> dict:
elif "mixer.norm.weight" in key:
w = split(weights[name], tp_size, tp_rank)
new_weights[key] = w
# Remap MoE expert weights.
elif "mixer.experts." in key:
if self.config.moe_backend == 'VANILLA':
new_weights[key] = weights[name]
else:
if "up_proj" in key:
w1_key = key.replace("up_proj", "w1")
w3_key = key.replace("up_proj", "w3")
# No need to split input_scale and weight_scale_2 since they are scalars for fp8 and nvfp4 models.
if "input_scale" in key or "weight_scale_2" in key:
new_weights[w3_key] = weights[name]
new_weights[w1_key] = weights[name]
elif "weight_scale" in key:
# NVFP4 case.
if weights[name].shape:
new_weights[w3_key] = weights[
name][:weights[name].shape[0] // 2]
new_weights[w1_key] = weights[name][
weights[name].shape[0] // 2:]
# FP8 case.
else:
new_weights[w3_key] = weights[name]
new_weights[w1_key] = weights[name]
else:
new_weights[w3_key] = weights[name][:weights[name].
shape[0] // 2]
new_weights[w1_key] = weights[name][weights[name].
shape[0] // 2:]
elif "down_proj" in key:
key = key.replace("down_proj", "w2")
new_weights[key] = weights[name]
else:
raise ValueError(f"Unknown MoE weight: {key}")
else:
new_weights[key] = weights[name]

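The MoE branch above splits each fused up_proj tensor into w3 (first half along dim 0) and w1 (second half), and simply duplicates scalar scales for both halves. A condensed sketch of that remapping, with a hypothetical helper name and the same key substitution as the mapper:

```python
import torch


def split_fused_up_proj(name: str, tensor: torch.Tensor) -> dict:
    """Sketch of the up_proj remapping: w3 takes the first half along dim 0,
    w1 the second half; scalar scales are reused for both halves."""
    w1_key = name.replace("up_proj", "w1")
    w3_key = name.replace("up_proj", "w3")
    if "input_scale" in name or "weight_scale_2" in name:
        # Scalar for fp8 and nvfp4 checkpoints: duplicate for both halves.
        return {w3_key: tensor, w1_key: tensor}
    if "weight_scale" in name and not tensor.shape:
        # FP8 weight scale is also a scalar.
        return {w3_key: tensor, w1_key: tensor}
    # NVFP4 weight scales and the weights themselves are split along dim 0.
    half = tensor.shape[0] // 2
    return {w3_key: tensor[:half], w1_key: tensor[half:]}
```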
@@ -6,8 +6,8 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
Qwen2MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_nemotron_h import split
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm._torch.utils import split
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM

