Commit 4d5fa97

[None][feat] Reference pytorch implementation for nemotron H
Signed-off-by: William Zhang <[email protected]>
1 parent 9b2abb8 commit 4d5fa97

6 files changed (+717, -49 lines)

tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 18 additions & 5 deletions

@@ -93,18 +93,31 @@ def _deduplicate_params_and_buffers(gm: fx.GraphModule) -> None:
 
 def _add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> None:
     """Adds back the state dict load hooks stripped away during export."""
-    hooks = {
+    pre_hooks = {
         k: mod._load_state_dict_pre_hooks
         for k, mod in model.named_modules()
         if mod._load_state_dict_pre_hooks
     }
 
     for mod_name, mod in gm.named_modules():
-        if mod_name in hooks:
-            for hook in hooks.pop(mod_name).values():
+        if mod_name in pre_hooks:
+            for hook in pre_hooks.pop(mod_name).values():
                 mod._register_load_state_dict_pre_hook(hook.hook, with_module=hook.with_module)
-    assert not (bool(hooks)), f"""Mismatch in names of exported and source modules with hooks.
-        The following module names were not found in exported module {list(hooks.keys())}"""
+    assert not (bool(pre_hooks)), f"""Mismatch in names of exported and source modules with hooks.
+        The following module names were not found in exported module {list(pre_hooks.keys())}"""
+
+    post_hooks = {
+        k: mod._load_state_dict_post_hooks
+        for k, mod in model.named_modules()
+        if mod._load_state_dict_post_hooks
+    }
+
+    for mod_name, mod in gm.named_modules():
+        if mod_name in post_hooks:
+            for hook in post_hooks.pop(mod_name).values():
+                mod.register_load_state_dict_post_hook(hook)
+    assert not (bool(post_hooks)), f"""Mismatch in names of exported and source modules with hooks.
+        The following module names were not found in exported module {list(post_hooks.keys())}"""
 
 
 def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None:
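
For context on what the new half of this function handles: load-state-dict post hooks live in `nn.Module._load_state_dict_post_hooks` and run after `load_state_dict()` completes, but export produces a fresh `GraphModule` whose submodules carry none of them, so they must be copied over by matching submodule names. A minimal sketch of the mechanism, with illustrative module and hook names that are not part of this commit:

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def report_keys(module, incompatible_keys):
    # Runs after module.load_state_dict(); may inspect or clear the
    # missing/unexpected key lists.
    print("missing:", incompatible_keys.missing_keys)


model = Block()
model.register_load_state_dict_post_hook(report_keys)
assert model._load_state_dict_post_hooks  # the dict the new code collects from

# The exported module starts with no such hooks; _add_missing_load_hooks is
# what re-attaches them on the submodules with matching names.
gm = torch.export.export(model, (torch.randn(1, 4),)).module()
assert not gm._load_state_dict_post_hooks

Note the asymmetry the diff preserves: pre hooks are stored in wrapped form (hence `hook.hook` and `hook.with_module`), while post hooks are stored as plain callables and can be re-registered directly through the public `register_load_state_dict_post_hook`.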
tensorrt_llm/_torch/auto_deploy/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,2 +1,4 @@
+# TODO: When getting rid of the nemotron H patches, import `modeling_nemotron_h` here to ensure the
+# custom model implementation is registered.
 from . import hf, patches
 from .factory import *
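
The TODO above relies on the register-on-import pattern: a module-level call runs exactly once, when the module is first imported, so importing `modeling_nemotron_h` from this `__init__` would be enough to populate the factory's registry. A hypothetical sketch of how that module's tail could look (the file name, class name, and "nemotron_h" key are assumptions, not contents of this commit):

# modeling_nemotron_h.py (hypothetical reference implementation)
from transformers import PreTrainedModel

from .hf import AutoModelForCausalLMFactory


class NemotronHForCausalLM(PreTrainedModel):  # assumed class name
    ...


# Module-level side effect: the first import registers this implementation for
# any config whose `model_type` field equals "nemotron_h" (assumed key).
AutoModelForCausalLMFactory.register_custom_model_cls(
    "nemotron_h", NemotronHForCausalLM
)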

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 39 additions & 7 deletions

@@ -110,6 +110,10 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         "use_cache": False,
     }
 
+    # Maps a model config's `model_type` value to the alternative
+    # `AutoModelForCausalLM` implementation to use in place of the default.
+    _custom_model_mapping: Dict[str, Type[AutoModelForCausalLM]] = {}
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._quant_config_reader: QuantConfigReader | None = None
@@ -212,14 +216,25 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         """Build the model on the desired device."""
         model_config, unused_kwargs = self._get_model_config()
 
+        model_type = getattr(model_config, "model_type", "")
+        custom_model_cls = self._custom_model_mapping.get(model_type, None)
         with (init_empty_weights if device == "meta" else nullcontext)():
-            model = self.automodel_cls.from_config(
-                model_config,
-                **{
-                    "trust_remote_code": True,
-                    **unused_kwargs,
-                },
-            )
+            if custom_model_cls is not None:
+                # Prefer `_from_config` (defined on the `PreTrainedModel` mixin): it
+                # applies setup, e.g. dtype handling, that direct construction skips.
+                if hasattr(custom_model_cls, "_from_config"):
+                    model = custom_model_cls._from_config(model_config, **unused_kwargs)
+                else:
+                    model = custom_model_cls(model_config, **unused_kwargs)
+            else:
+                model = self.automodel_cls.from_config(
+                    model_config,
+                    **{
+                        "trust_remote_code": True,
+                        **unused_kwargs,
+                    },
+                )
+
         if device == "meta":
             # post-init --> this must be called explicitly for HF models the way we initialize them
             # since this "gets lost" with the init_empty_weights context manager.
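
For reference, `_from_config` is the config-only construction path defined on transformers' `PreTrainedModel`; unlike calling the class directly, it honors kwargs such as `torch_dtype`, which is why the new branch prefers it when present. A rough sketch of the two paths, using `gpt2` purely as a stand-in config:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, GPT2LMHeadModel

config = AutoConfig.from_pretrained("gpt2")  # stand-in; any CausalLM config works

# Default path, taken when no custom class is registered for config.model_type:
default_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# Custom-class path: `_from_config` applies dtype handling that a bare
# GPT2LMHeadModel(config) call would skip.
custom_model = GPT2LMHeadModel._from_config(config, torch_dtype=torch.bfloat16)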
@@ -482,6 +497,23 @@ def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> None:
     def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]:
         return [FullModelExportInfo()]
 
+    @classmethod
+    def register_custom_model_cls(
+        cls, model_type: str, custom_model_cls: Type[AutoModelForCausalLM]
+    ) -> None:
+        """Register a custom model implementation.
+
+        This is useful when the default `AutoModelForCausalLM` is not the one we want to use,
+        for example when the model's code lives in a HuggingFace repo that is out of date or
+        has dependencies that TensorRT-LLM does not provide.
+
+        Args:
+            model_type: The value of the `model_type` field in the model's config.
+            custom_model_cls: The `AutoModelForCausalLM` implementation that should be used
+                for `model_type`.
+        """
+        cls._custom_model_mapping[model_type] = custom_model_cls
+
 
 class _StateDictParamNameConverter:
     """Helper class for applying param name conversions to a state dict.
