1+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4+ # may not use this file except in compliance with the License. A copy of
5+ # the License is located at
6+ #
7+ # http://aws.amazon.com/apache2.0/
8+ #
9+ # or in the "license" file accompanying this file. This file is
10+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+ # ANY KIND, either express or implied. See the License for the specific
12+ # language governing permissions and limitations under the License.
113"""Holds mixin logic to support deployment of Model ID"""
214from __future__ import absolute_import
315import logging
2638 _get_admissible_dtypes ,
2739)
2840from sagemaker .serve .utils .local_hardware import _get_nb_instance , _get_ram_usage_mb
29- from sagemaker .serve .model_server .djl_serving .prepare import prepare_for_djl_serving
41+ from sagemaker .serve .model_server .djl_serving .prepare import (
42+ prepare_for_djl_serving ,
43+ _create_dir_structure ,
44+ )
3045from sagemaker .serve .utils .predictors import DjlLocalModePredictor
3146from sagemaker .serve .utils .types import ModelServer , _DjlEngine
3247from sagemaker .serve .mode .function_pointers import Mode
4055
4156logger = logging .getLogger (__name__ )
4257
43- _JUMP_START_HUGGING_FACE_PREFIX = "huggingface"
4458# Match JumpStart DJL entrypoint format
4559_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
# Sub-folder name joined onto ``self.model_path`` to form the code directory
# (see ``_create_djl_model``).
4660_CODE_FOLDER = "code"
@@ -86,16 +100,9 @@ def _prepare_for_mode(self):
86100 def _get_client_translators (self ):
87101 """No-op placeholder: the method has no body beyond this docstring and returns None."""
88102
89- def _validate_model_server (self ):
103+ def _is_djl (self ):
90104 """Return whether ``self.model_server`` equals ``ModelServer.DJL_SERVING``."""
91- if self .model_server != ModelServer .DJL_SERVING :
92- messaging = (
93- "HuggingFace Model ID support on model server: "
94- f"{ self .model_server } is not currently supported. "
95- f"Defaulting to { ModelServer .DJL_SERVING } "
96- )
97- logger .warning (messaging )
98- self .model_server = ModelServer .DJL_SERVING
105+ return self .model_server == ModelServer .DJL_SERVING
99106
100107 def _validate_djl_serving_sample_data (self ):
101108 """Placeholder docstring"""
@@ -112,12 +119,6 @@ def _validate_djl_serving_sample_data(self):
112119 ):
113120 raise ValueError (_INVALID_SAMPLE_DATA_EX )
114121
115- def _is_jumpstart_model_id (self ) -> bool :
116- """Return whether ``self.model`` starts with ``_JUMP_START_HUGGING_FACE_PREFIX``."""
117- # this will potentially extend in the future so leave like this
118- # for now, only hf jumpstart model ids will be considered
119- return self .model .startswith (_JUMP_START_HUGGING_FACE_PREFIX )
120-
121122 def _create_djl_model (self ) -> Type [Model ]:
122123 """Placeholder docstring"""
123124 code_dir = str (Path (self .model_path ).joinpath (_CODE_FOLDER ))
@@ -211,9 +212,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
211212 ram_usage_after = _get_ram_usage_mb ()
212213
213214 self .ram_usage_model_load = max (ram_usage_after - ram_usage_before , 0 )
214- logger .info (
215- "RAM used to load the %s locally was %s MB" , self .model , self .ram_usage_model_load
216- )
217215
218216 return predictor
219217
@@ -237,7 +235,8 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
237235 self .pysdk_model .env ["TRANSFORMERS_CACHE" ] = "/tmp"
238236 self .pysdk_model .env ["HUGGINGFACE_HUB_CACHE" ] = "/tmp"
239237
240- kwargs ["endpoint_logging" ] = True
238+ if "endpoint_logging" not in kwargs :
239+ kwargs ["endpoint_logging" ] = True
241240 if self .nb_instance_type and "instance_type" not in kwargs :
242241 kwargs .update ({"instance_type" : self .nb_instance_type })
243242
@@ -253,6 +252,7 @@ def _build_for_hf_djl(self):
253252 """Placeholder docstring"""
254253 self .overwrite_props_from_file = True
255254 self .nb_instance_type = _get_nb_instance ()
255+ _create_dir_structure (self .model_path )
256256 self .engine , self .hf_model_config = _auto_detect_engine (
257257 self .model , self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
258258 )
@@ -463,7 +463,6 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
463463
464464 def _build_for_djl (self ):
465465 """Placeholder docstring"""
466- self ._validate_model_server ()
467466 self ._validate_djl_serving_sample_data ()
468467 self .secret_key = None
469468
0 commit comments