1616import json
1717import logging
1818import os .path
19+ import urllib .request
20+ from json import JSONDecodeError
21+ from urllib .error import HTTPError , URLError
1922from enum import Enum
2023from typing import Optional , Union , Dict , Any
2124
@@ -134,10 +137,10 @@ def _read_existing_serving_properties(directory: str):
134137
135138def _get_model_config_properties_from_s3 (model_s3_uri : str ):
136139 """Placeholder docstring"""
140+
137141 s3_files = s3 .S3Downloader .list (model_s3_uri )
138- valid_config_files = ["config.json" , "model_index.json" ]
139142 model_config = None
140- for config in valid_config_files :
143+ for config in defaults . VALID_MODEL_CONFIG_FILES :
141144 config_file = os .path .join (model_s3_uri , config )
142145 if config_file in s3_files :
143146 model_config = json .loads (s3 .S3Downloader .read_file (config_file ))
@@ -151,26 +154,53 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
151154 return model_config
152155
153156
def _get_model_config_properties_from_hf(model_id: str, timeout: float = 30.0):
    """Fetch a model's configuration JSON from the HuggingFace Hub.

    Each file name in ``defaults.VALID_MODEL_CONFIG_FILES`` is tried in order
    against ``https://huggingface.co/<model_id>/raw/main/``. The first file
    that downloads and parses as JSON is returned; files that fail to load
    are logged as warnings and skipped.

    Args:
        model_id (str): The HuggingFace Hub model id used to build the
            download URL.
        timeout (float): Seconds to wait on each HTTP request before giving
            up on that candidate file (default: 30.0).

    Returns:
        The parsed JSON content of the first config file that could be read.

    Raises:
        ValueError: If none of the candidate config files could be read, or
            the one that was read parsed to an empty value.
    """

    config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/"
    model_config = None
    for config in defaults.VALID_MODEL_CONFIG_FILES:
        config_file_url = config_url_prefix + config
        try:
            # Always bound the request: without a timeout an unresponsive
            # connection would block model construction indefinitely.
            with urllib.request.urlopen(config_file_url, timeout=timeout) as response:
                model_config = json.load(response)
                break
        except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
            # Best effort: a missing or unreadable candidate is not fatal on
            # its own, so log it and move on to the next file name.
            logger.warning(
                "Exception encountered while trying to read config file %s. Details: %s",
                config_file_url,
                e,
            )
    if not model_config:
        raise ValueError(
            f"Did not find a config.json or model_index.json file in huggingface hub for "
            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
            f"Diffusion Models) for this model in the huggingface hub"
        )
    return model_config
181+
182+
154183class DJLModel (FrameworkModel ):
155184 """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
156185
157186 def __new__ (
158187 cls ,
159- model_s3_uri : str ,
188+ model_id : str ,
160189 * args ,
161190 ** kwargs ,
162191 ): # pylint: disable=W0613
163192 """Create a specific subclass of DJLModel for a given engine"""
164193
165- if not model_s3_uri .startswith ("s3://" ):
166- raise ValueError ("DJLModel only supports loading model artifacts from s3" )
167- if model_s3_uri .endswith ("tar.gz" ):
194+ if model_id .endswith ("tar.gz" ):
168195 raise ValueError (
169196 "DJLModel does not support model artifacts in tar.gz format."
170197 "Please store the model in uncompressed format and provide the s3 uri of the "
171198 "containing folder"
172199 )
173- model_config = _get_model_config_properties_from_s3 (model_s3_uri )
200+ if model_id .startswith ("s3://" ):
201+ model_config = _get_model_config_properties_from_s3 (model_id )
202+ else :
203+ model_config = _get_model_config_properties_from_hf (model_id )
174204 if model_config .get ("_class_name" ) == "StableDiffusionPipeline" :
175205 model_type = defaults .STABLE_DIFFUSION_MODEL_TYPE
176206 num_heads = 0
@@ -196,7 +226,7 @@ def __new__(
196226
197227 def __init__ (
198228 self ,
199- model_s3_uri : str ,
229+ model_id : str ,
200230 role : str ,
201231 djl_version : Optional [str ] = None ,
202232 task : Optional [str ] = None ,
@@ -216,8 +246,9 @@ def __init__(
216246 """Initialize a DJLModel.
217247
218248 Args:
219- model_s3_uri (str): The Amazon S3 location containing the uncompressed model
220- artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
249+ model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
250+ containing the uncompressed model artifacts (i.e. not a tar.gz file).
251+ The model artifacts are expected to be in HuggingFace pre-trained model
221252 format (i.e. model should be loadable from the huggingface transformers
222253 from_pretrained api, and should also include tokenizer configs if applicable).
223254 role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +316,13 @@ def __init__(
285316 if kwargs .get ("model_data" ):
286317 logger .warning (
287318 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
288- "You only need to set model_S3_uri and ensure it points to uncompressed model "
289- "artifacts."
319+ "You only need to set model_id and ensure it points to uncompressed model "
320+ "artifacts in s3, or a valid HuggingFace Hub model_id ."
290321 )
291322 super (DJLModel , self ).__init__ (
292323 None , image_uri , role , entry_point , predictor_cls = predictor_cls , ** kwargs
293324 )
294- self .model_s3_uri = model_s3_uri
325+ self .model_id = model_id
295326 self .djl_version = djl_version
296327 self .task = task
297328 self .data_type = data_type
@@ -529,7 +560,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
529560 serving_properties = {}
530561 serving_properties ["engine" ] = self .engine .value [0 ] # pylint: disable=E1101
531562 serving_properties ["option.entryPoint" ] = self .engine .value [1 ] # pylint: disable=E1101
532- serving_properties ["option.s3url" ] = self .model_s3_uri
563+ if self .model_id .startswith ("s3://" ):
564+ serving_properties ["option.s3url" ] = self .model_id
565+ else :
566+ serving_properties ["option.model_id" ] = self .model_id
533567 if self .number_of_partitions :
534568 serving_properties ["option.tensor_parallel_degree" ] = self .number_of_partitions
535569 if self .entry_point :
@@ -593,7 +627,7 @@ class DeepSpeedModel(DJLModel):
593627
594628 def __init__ (
595629 self ,
596- model_s3_uri : str ,
630+ model_id : str ,
597631 role : str ,
598632 tensor_parallel_degree : Optional [int ] = None ,
599633 max_tokens : Optional [int ] = None ,
@@ -606,11 +640,11 @@ def __init__(
606640 """Initialize a DeepSpeedModel
607641
608642 Args:
609- model_s3_uri (str): The Amazon S3 location containing the uncompressed model
610- artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
643+ model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
644+ containing the uncompressed model artifacts (i.e. not a tar.gz file).
645+ The model artifacts are expected to be in HuggingFace pre-trained model
611646 format (i.e. model should be loadable from the huggingface transformers
612- from_pretrained
613- api, and should also include tokenizer configs if applicable).
647+ from_pretrained api, and should also include tokenizer configs if applicable).
614648 role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
615649 SageMaker training jobs and APIs that create Amazon SageMaker
616650 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -647,7 +681,7 @@ def __init__(
647681 """
648682
649683 super (DeepSpeedModel , self ).__init__ (
650- model_s3_uri ,
684+ model_id ,
651685 role ,
652686 ** kwargs ,
653687 )
@@ -710,7 +744,7 @@ class HuggingFaceAccelerateModel(DJLModel):
710744
711745 def __init__ (
712746 self ,
713- model_s3_uri : str ,
747+ model_id : str ,
714748 role : str ,
715749 number_of_partitions : Optional [int ] = None ,
716750 device_id : Optional [int ] = None ,
@@ -722,11 +756,11 @@ def __init__(
722756 """Initialize a HuggingFaceAccelerateModel.
723757
724758 Args:
725- model_s3_uri (str): The Amazon S3 location containing the uncompressed model
726- artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
759+ model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
760+ containing the uncompressed model artifacts (i.e. not a tar.gz file).
761+ The model artifacts are expected to be in HuggingFace pre-trained model
727762 format (i.e. model should be loadable from the huggingface transformers
728- from_pretrained
729- method).
763+ from_pretrained api, and should also include tokenizer configs if applicable).
730764 role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
731765 SageMaker training jobs and APIs that create Amazon SageMaker
732766 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -760,7 +794,7 @@ def __init__(
760794 """
761795
762796 super (HuggingFaceAccelerateModel , self ).__init__ (
763- model_s3_uri ,
797+ model_id ,
764798 role ,
765799 number_of_partitions = number_of_partitions ,
766800 ** kwargs ,
0 commit comments