|
19 | 19 | import os |
20 | 20 | import re |
21 | 21 | import copy |
22 | | -from typing import List, Dict |
| 22 | +from typing import List, Dict, Optional, Union |
23 | 23 |
|
24 | 24 | import sagemaker |
25 | 25 | from sagemaker import ( |
|
30 | 30 | utils, |
31 | 31 | git_utils, |
32 | 32 | ) |
| 33 | +from sagemaker.session import Session |
| 34 | +from sagemaker.model_metrics import ModelMetrics |
33 | 35 | from sagemaker.deprecations import removed_kwargs |
| 36 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
| 37 | +from sagemaker.metadata_properties import MetadataProperties |
34 | 38 | from sagemaker.predictor import PredictorBase |
35 | 39 | from sagemaker.serverless import ServerlessInferenceConfig |
36 | 40 | from sagemaker.transformer import Transformer |
37 | 41 | from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model |
38 | | -from sagemaker.utils import unique_name_from_base |
| 42 | +from sagemaker.utils import unique_name_from_base, to_string |
39 | 43 | from sagemaker.async_inference import AsyncInferenceConfig |
40 | 44 | from sagemaker.predictor_async import AsyncPredictor |
41 | 45 | from sagemaker.workflow import is_pipeline_variable |
| 46 | +from sagemaker.workflow.entities import PipelineVariable |
42 | 47 | from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession |
43 | 48 |
|
44 | 49 | LOGGER = logging.getLogger("sagemaker") |
@@ -78,23 +83,23 @@ class Model(ModelBase): |
78 | 83 |
|
79 | 84 | def __init__( |
80 | 85 | self, |
81 | | - image_uri, |
82 | | - model_data=None, |
83 | | - role=None, |
84 | | - predictor_cls=None, |
85 | | - env=None, |
86 | | - name=None, |
87 | | - vpc_config=None, |
88 | | - sagemaker_session=None, |
89 | | - enable_network_isolation=False, |
90 | | - model_kms_key=None, |
91 | | - image_config=None, |
92 | | - source_dir=None, |
93 | | - code_location=None, |
94 | | - entry_point=None, |
95 | | - container_log_level=logging.INFO, |
96 | | - dependencies=None, |
97 | | - git_config=None, |
| 86 | + image_uri: Union[str, PipelineVariable], |
| 87 | + model_data: Optional[Union[str, PipelineVariable]] = None, |
| 88 | + role: Optional[str] = None, |
| 89 | + predictor_cls: Optional[callable] = None, |
| 90 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 91 | + name: Optional[str] = None, |
| 92 | + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, |
| 93 | + sagemaker_session: Optional[Session] = None, |
| 94 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 95 | + model_kms_key: Optional[str] = None, |
| 96 | + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 97 | + source_dir: Optional[str] = None, |
| 98 | + code_location: Optional[str] = None, |
| 99 | + entry_point: Optional[str] = None, |
| 100 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 101 | + dependencies: Optional[List[str]] = None, |
| 102 | + git_config: Optional[Dict[str, str]] = None, |
98 | 103 | ): |
99 | 104 | """Initialize an SageMaker ``Model``. |
100 | 105 |
|
@@ -294,22 +299,22 @@ def __init__( |
294 | 299 | @runnable_by_pipeline |
295 | 300 | def register( |
296 | 301 | self, |
297 | | - content_types, |
298 | | - response_types, |
299 | | - inference_instances=None, |
300 | | - transform_instances=None, |
301 | | - model_package_name=None, |
302 | | - model_package_group_name=None, |
303 | | - image_uri=None, |
304 | | - model_metrics=None, |
305 | | - metadata_properties=None, |
306 | | - marketplace_cert=False, |
307 | | - approval_status=None, |
308 | | - description=None, |
309 | | - drift_check_baselines=None, |
310 | | - customer_metadata_properties=None, |
311 | | - validation_specification=None, |
312 | | - domain=None, |
| 302 | + content_types: List[Union[str, PipelineVariable]], |
| 303 | + response_types: List[Union[str, PipelineVariable]], |
| 304 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 305 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 306 | + model_package_name: Optional[Union[str, PipelineVariable]] = None, |
| 307 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 308 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 309 | + model_metrics: Optional[ModelMetrics] = None, |
| 310 | + metadata_properties: Optional[MetadataProperties] = None, |
| 311 | + marketplace_cert: bool = False, |
| 312 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 313 | + description: Optional[str] = None, |
| 314 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 315 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 316 | + validation_specification: Optional[Union[str, PipelineVariable]] = None, |
| 317 | + domain: Optional[Union[str, PipelineVariable]] = None, |
313 | 318 | ): |
314 | 319 | """Creates a model package for creating SageMaker models or listing on Marketplace. |
315 | 320 |
|
@@ -385,10 +390,10 @@ def register( |
385 | 390 | @runnable_by_pipeline |
386 | 391 | def create( |
387 | 392 | self, |
388 | | - instance_type: str = None, |
389 | | - accelerator_type: str = None, |
390 | | - serverless_inference_config: ServerlessInferenceConfig = None, |
391 | | - tags: List[Dict[str, str]] = None, |
| 393 | + instance_type: Optional[str] = None, |
| 394 | + accelerator_type: Optional[str] = None, |
| 395 | + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, |
| 396 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
392 | 397 | ): |
393 | 398 | """Create a SageMaker Model Entity |
394 | 399 |
|
@@ -570,7 +575,7 @@ def _script_mode_env_vars(self): |
570 | 575 | return { |
571 | 576 | SCRIPT_PARAM_NAME.upper(): script_name or str(), |
572 | 577 | DIR_PARAM_NAME.upper(): dir_name or str(), |
573 | | - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), |
| 578 | + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), |
574 | 579 | SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, |
575 | 580 | } |
576 | 581 |
|
@@ -1239,19 +1244,19 @@ class FrameworkModel(Model): |
1239 | 1244 |
|
1240 | 1245 | def __init__( |
1241 | 1246 | self, |
1242 | | - model_data, |
1243 | | - image_uri, |
1244 | | - role, |
1245 | | - entry_point, |
1246 | | - source_dir=None, |
1247 | | - predictor_cls=None, |
1248 | | - env=None, |
1249 | | - name=None, |
1250 | | - container_log_level=logging.INFO, |
1251 | | - code_location=None, |
1252 | | - sagemaker_session=None, |
1253 | | - dependencies=None, |
1254 | | - git_config=None, |
| 1247 | + model_data: Union[str, PipelineVariable], |
| 1248 | + image_uri: Union[str, PipelineVariable], |
| 1249 | + role: str, |
| 1250 | + entry_point: str, |
| 1251 | + source_dir: Optional[str] = None, |
| 1252 | + predictor_cls: Optional[callable] = None, |
| 1253 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 1254 | + name: Optional[str] = None, |
| 1255 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 1256 | + code_location: Optional[str] = None, |
| 1257 | + sagemaker_session: Optional[Session] = None, |
| 1258 | + dependencies: Optional[List[str]] = None, |
| 1259 | + git_config: Optional[Dict[str, str]] = None, |
1255 | 1260 | **kwargs, |
1256 | 1261 | ): |
1257 | 1262 | """Initialize a ``FrameworkModel``. |
|
0 commit comments