Skip to content

Commit 77ccb30

Browse files
committed
[None][feat] Reference PyTorch implementation for Nemotron H
Signed-off-by: William Zhang <[email protected]>
1 parent b51258a commit 77ccb30

File tree

5 files changed

+686
-21
lines changed

5 files changed

+686
-21
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
# TODO: When getting rid of the nemotron H patches, import `modeling_nemotron_h` here to ensure the
2+
# custom model implementation is registered.
13
from . import hf, patches
24
from .factory import *

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
110110
"use_cache": False,
111111
}
112112

113+
# The below maps from an entry in a model's config dict's `model_type` to the alternative
114+
# `AutoModelForCausalLM` we would like to use.
115+
_custom_model_mapping: Dict[str, Type[AutoModelForCausalLM]] = {}
116+
113117
def __init__(self, *args, **kwargs):
114118
super().__init__(*args, **kwargs)
115119
self._quant_config_reader: QuantConfigReader | None = None
@@ -205,14 +209,25 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
205209
"""Build the model on the desired device."""
206210
model_config, unused_kwargs = self._get_model_config()
207211

212+
model_type = getattr(model_config, "model_type", "")
213+
custom_model_cls = self._custom_model_mapping.get(model_type, None)
208214
with (init_empty_weights if device == "meta" else nullcontext)():
209-
model = self.automodel_cls.from_config(
210-
model_config,
211-
**{
212-
"trust_remote_code": True,
213-
**unused_kwargs,
214-
},
215-
)
215+
if custom_model_cls is not None:
216+
# `_from_config` has some behavior we would like to use where possible. It is
217+
# defined in the `PreTrainedModel` mixin.
218+
if hasattr(custom_model_cls, "_from_config"):
219+
model = custom_model_cls._from_config(model_config, **unused_kwargs)
220+
else:
221+
model = custom_model_cls(model_config, **unused_kwargs)
222+
else:
223+
model = self.automodel_cls.from_config(
224+
model_config,
225+
**{
226+
"trust_remote_code": True,
227+
**unused_kwargs,
228+
},
229+
)
230+
216231
if device == "meta":
217232
# post-init --> this must be called explicitly for HF models the way we initialize them
218233
# since this "gets lost" with the init_empty_weights context manager.
@@ -475,6 +490,23 @@ def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> No
475490
def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]:
476491
return [FullModelExportInfo()]
477492

493+
@classmethod
494+
def register_custom_model_cls(
495+
cls, model_type: str, custom_model_cls: Type[AutoModelForCausalLM]
496+
) -> None:
497+
"""Register a custom model implementation.
498+
499+
This is useful when the default `AutoModelForCausalLM` is not the one we want to use. For
500+
example, when the model's code is in a HuggingFace repo that is out of date, or has
501+
dependencies that TensorRT-LLM does not have, etc.
502+
503+
Args:
504+
model_type: This should be the value for the `model_type` field in the model's config.
505+
custom_model_cls: The `AutoModelForCausalLM` implementation that should be used for
506+
`model_type`.
507+
"""
508+
cls._custom_model_mapping[model_type] = custom_model_cls
509+
478510

479511
class _StateDictParamNameConverter:
480512
"""Helper class for applying param name conversions to a state dict.

0 commit comments

Comments (0)