|
14 | 14 | from __future__ import absolute_import |
15 | 15 |
|
16 | 16 | import logging |
| 17 | +from typing import Optional, Union, List, Dict |
17 | 18 |
|
18 | 19 | import sagemaker |
19 | | -from sagemaker import image_uris |
| 20 | +from sagemaker import image_uris, ModelMetrics |
20 | 21 | from sagemaker.deserializers import JSONDeserializer |
| 22 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
21 | 23 | from sagemaker.fw_utils import ( |
22 | 24 | model_code_key_prefix, |
23 | 25 | validate_version_or_image_args, |
24 | 26 | ) |
| 27 | +from sagemaker.metadata_properties import MetadataProperties |
25 | 28 | from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME |
26 | 29 | from sagemaker.predictor import Predictor |
27 | 30 | from sagemaker.serializers import JSONSerializer |
28 | 31 | from sagemaker.session import Session |
| 32 | +from sagemaker.utils import to_string |
| 33 | +from sagemaker.workflow.entities import PipelineVariable |
29 | 34 |
|
30 | 35 | logger = logging.getLogger("sagemaker") |
31 | 36 |
|
@@ -92,16 +97,16 @@ class HuggingFaceModel(FrameworkModel): |
92 | 97 |
|
93 | 98 | def __init__( |
94 | 99 | self, |
95 | | - role, |
96 | | - model_data=None, |
97 | | - entry_point=None, |
98 | | - transformers_version=None, |
99 | | - tensorflow_version=None, |
100 | | - pytorch_version=None, |
101 | | - py_version=None, |
102 | | - image_uri=None, |
103 | | - predictor_cls=HuggingFacePredictor, |
104 | | - model_server_workers=None, |
| 100 | + role: str, |
| 101 | + model_data: Optional[Union[str, PipelineVariable]] = None, |
| 102 | + entry_point: Optional[str] = None, |
| 103 | + transformers_version: Optional[str] = None, |
| 104 | + tensorflow_version: Optional[str] = None, |
| 105 | + pytorch_version: Optional[str] = None, |
| 106 | + py_version: Optional[str] = None, |
| 107 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 108 | + predictor_cls: callable = HuggingFacePredictor, |
| 109 | + model_server_workers: Optional[Union[int, PipelineVariable]] = None, |
105 | 110 | **kwargs, |
106 | 111 | ): |
107 | 112 | """Initialize a HuggingFaceModel. |
@@ -291,21 +296,21 @@ def deploy( |
291 | 296 |
|
292 | 297 | def register( |
293 | 298 | self, |
294 | | - content_types, |
295 | | - response_types, |
296 | | - inference_instances=None, |
297 | | - transform_instances=None, |
298 | | - model_package_name=None, |
299 | | - model_package_group_name=None, |
300 | | - image_uri=None, |
301 | | - model_metrics=None, |
302 | | - metadata_properties=None, |
303 | | - marketplace_cert=False, |
304 | | - approval_status=None, |
305 | | - description=None, |
306 | | - drift_check_baselines=None, |
307 | | - customer_metadata_properties=None, |
308 | | - domain=None, |
| 299 | + content_types: List[Union[str, PipelineVariable]], |
| 300 | + response_types: List[Union[str, PipelineVariable]], |
| 301 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 302 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 303 | + model_package_name: Optional[Union[str, PipelineVariable]] = None, |
| 304 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 305 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 306 | + model_metrics: Optional[ModelMetrics] = None, |
| 307 | + metadata_properties: Optional[MetadataProperties] = None, |
| 308 | + marketplace_cert: bool = False, |
| 309 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 310 | + description: Optional[str] = None, |
| 311 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 312 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 313 | + domain: Optional[Union[str, PipelineVariable]] = None, |
309 | 314 | ): |
310 | 315 | """Creates a model package for creating SageMaker models or listing on Marketplace. |
311 | 316 |
|
@@ -409,7 +414,9 @@ def prepare_container_def( |
409 | 414 | deploy_env.update(self._script_mode_env_vars()) |
410 | 415 |
|
411 | 416 | if self.model_server_workers: |
412 | | - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) |
| 417 | + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( |
| 418 | + self.model_server_workers |
| 419 | + ) |
413 | 420 | return sagemaker.container_def( |
414 | 421 | deploy_image, self.repacked_model_data or self.model_data, deploy_env |
415 | 422 | ) |
|
0 commit comments