@@ -110,6 +110,10 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
110110 "use_cache" : False ,
111111 }
112112
113+ # Maps the `model_type` value found in a model's config to the custom `AutoModelForCausalLM`
114+ # implementation that should be built instead of the one `automodel_cls` would resolve.
115+ _custom_model_mapping : Dict [str , Type [AutoModelForCausalLM ]] = {}
116+
113117 def __init__ (self , * args , ** kwargs ):
114118 super ().__init__ (* args , ** kwargs )
115119 self ._quant_config_reader : QuantConfigReader | None = None
@@ -212,14 +216,25 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
212216 """Build the model on the desired device."""
213217 model_config , unused_kwargs = self ._get_model_config ()
214218
219+ model_type = getattr (model_config , "model_type" , "" )
220+ custom_model_cls = self ._custom_model_mapping .get (model_type , None )
215221 with (init_empty_weights if device == "meta" else nullcontext )():
216- model = self .automodel_cls .from_config (
217- model_config ,
218- ** {
219- "trust_remote_code" : True ,
220- ** unused_kwargs ,
221- },
222- )
222+ if custom_model_cls is not None :
223+ # `_from_config` has some behavior we would like to use where possible. It is
224+ # defined in the `PreTrainedModel` mixin.
225+ if hasattr (custom_model_cls , "_from_config" ):
226+ model = custom_model_cls ._from_config (model_config , ** unused_kwargs )
227+ else :
228+ model = custom_model_cls (model_config , ** unused_kwargs )
229+ else :
230+ model = self .automodel_cls .from_config (
231+ model_config ,
232+ ** {
233+ "trust_remote_code" : True ,
234+ ** unused_kwargs ,
235+ },
236+ )
237+
223238 if device == "meta" :
224239 # post-init --> this must be called explicitly for HF models the way we initialize them
225240 # since this "gets lost" with the init_empty_weights context manager.
@@ -482,6 +497,23 @@ def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> No
def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]:
    """Return the export descriptors to use for ``model``.

    Args:
        model: The module to be exported. It is not inspected here.

    Returns:
        A new single-element list holding a freshly constructed
        ``FullModelExportInfo``.
    """
    export_infos = [FullModelExportInfo()]
    return export_infos
484499
@classmethod
def register_custom_model_cls(
    cls, model_type: str, custom_model_cls: Type[AutoModelForCausalLM]
) -> None:
    """Associate a `model_type` with a custom model implementation.

    Use this when the default `AutoModelForCausalLM` resolution is not the one we want, e.g.
    because the model's code lives in a HuggingFace repo that is out of date, or because it has
    dependencies that TensorRT-LLM does not have.

    Registering the same `model_type` again replaces the previous entry.

    NOTE(review): the mapping is reached through `cls`, so unless a subclass defines its own
    `_custom_model_mapping`, the entry is written to the dict it inherits and is therefore
    visible to every class sharing that dict.

    Args:
        model_type: The value of the `model_type` field in the model's config.
        custom_model_cls: The `AutoModelForCausalLM` implementation to use for `model_type`.
    """
    registry = cls._custom_model_mapping
    registry[model_type] = custom_model_cls
516+
485517
486518class _StateDictParamNameConverter :
487519 """Helper class for applying param name conversions to a state dict.
0 commit comments