diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 3f22e22d5d..1986febaaf 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -14,19 +14,25 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, python_deprecation_warning, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.chainer import defaults from sagemaker.deserializers import NumpyDeserializer from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -75,14 +81,14 @@ class ChainerModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - image_uri=None, - framework_version=None, - py_version=None, - predictor_cls=ChainerPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + image_uri: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + predictor_cls: callable = ChainerPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an ChainerModel. @@ -142,27 +148,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances, - transform_instances, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -218,6 +224,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(ChainerModel, self).register( content_types, response_types, @@ -236,7 +244,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -282,7 +290,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def(deploy_image, self.model_data, deploy_env) def serving_image_uri( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 1ab122b2e0..dee102999b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -76,6 +76,7 @@ build_dict, get_config_value, name_from_base, + to_string, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -1947,10 +1948,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config): current_hyperparameters = estimator.hyperparameters() if current_hyperparameters is not None: - hyperparameters = { - str(k): (v.to_string() if is_pipeline_variable(v) else str(v)) - for (k, v) in current_hyperparameters.items() - } + hyperparameters = {str(k): to_string(v) for (k, v) in current_hyperparameters.items()} train_args = config.copy() train_args["input_mode"] = estimator.input_mode diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 04af57b566..6f810dc5e2 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -14,18 +14,24 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import JSONDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.session import Session +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -100,16 +106,16 @@ class HuggingFaceModel(FrameworkModel): def __init__( self, - role, - model_data=None, - entry_point=None, - transformers_version=None, - tensorflow_version=None, - pytorch_version=None, - py_version=None, - image_uri=None, - predictor_cls=HuggingFacePredictor, - model_server_workers=None, + role: str, + model_data: Optional[Union[str, PipelineVariable]] = None, + entry_point: Optional[str] = None, + transformers_version: Optional[str] = None, + tensorflow_version: Optional[str] = None, + pytorch_version: Optional[str] = None, + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = HuggingFacePredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): """Initialize a HuggingFaceModel. @@ -299,27 +305,27 @@ def deploy( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -377,6 +383,13 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = ( + framework + or fetch_framework_and_framework_version( + self.tensorflow_version, self.pytorch_version + )[0] + ).upper() return super(HuggingFaceModel, self).register( content_types, response_types, @@ -395,12 +408,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=( - framework - or fetch_framework_and_framework_version( - self.tensorflow_version, self.pytorch_version - )[0] - ).upper(), + framework=framework, framework_version=framework_version or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[ 1 @@ -449,7 +457,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env ) diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index a3cd17cd8c..d90a5ca76f 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import os +from typing import Union, Optional from six.moves.urllib.parse import urlparse @@ -22,6 +23,8 @@ from sagemaker.deprecations import removed_kwargs from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg +from sagemaker.workflow.entities import PipelineVariable MULTI_MODEL_CONTAINER_MODE = "MultiModel" @@ -34,12 +37,12 @@ class MultiDataModel(Model): def __init__( self, - name, - model_data_prefix, - model=None, - image_uri=None, - role=None, - sagemaker_session=None, + name: str, + model_data_prefix: str, + model: Optional[Model] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + role: Optional[str] = None, + sagemaker_session: Optional[Session] = None, **kwargs, ): """Initialize a ``MultiDataModel``. @@ -106,6 +109,7 @@ def __init__( # Set the ``Model`` parameters if the model parameter is not specified if not self.model: + pop_out_unused_kwarg("model_data", kwargs, self.model_data_prefix) super(MultiDataModel, self).__init__( image_uri, self.model_data_prefix, @@ -115,7 +119,9 @@ def __init__( **kwargs, ) - def prepare_container_def(self, instance_type=None, accelerator_type=None): + def prepare_container_def( + self, instance_type=None, accelerator_type=None, serverless_inference_config=None + ): """Return a container definition set. Definition set includes MultiModel mode, model data and other parameters diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 4aaf6a8acc..f2e18c009e 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -14,21 +14,27 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, List, Dict import packaging.version import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import JSONDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, python_deprecation_warning, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.mxnet import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -77,14 +83,14 @@ class MXNetModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version=None, - py_version=None, - image_uri=None, - predictor_cls=MXNetPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: str = _LOWEST_MMS_VERSION, + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = MXNetPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an MXNetModel. @@ -102,7 +108,7 @@ def __init__( hosting. If ``source_dir`` is specified, then ``entry_point`` must point to a file located at the root of ``source_dir``. framework_version (str): MXNet version you want to use for executing - your model training code. Defaults to ``None``. Required unless + your model training code. Defaults to ``1.4.0``. Required unless ``image_uri`` is provided. py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless @@ -144,27 +150,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -220,6 +226,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(MXNetModel, self).register( content_types, response_types, @@ -238,7 +246,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -286,7 +294,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env ) diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index 79bbc62da2..b44e6f9ef2 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -16,8 +16,8 @@ import json from typing import Union -from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import to_string class ParameterRange(object): @@ -78,12 +78,8 @@ def as_tuning_range(self, name): """ return { "Name": name, - "MinValue": str(self.min_value) - if not is_pipeline_variable(self.min_value) - else self.min_value.to_string(), - "MaxValue": str(self.max_value) - if not is_pipeline_variable(self.max_value) - else self.max_value.to_string(), + "MinValue": to_string(self.min_value), + "MaxValue": to_string(self.max_value), "ScalingType": self.scaling_type, } @@ -117,7 +113,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called This input will be converted into a list of strings. """ values = values if isinstance(values, list) else [values] - self.values = [str(v) if not is_pipeline_variable(v) else v.to_string() for v in values] + self.values = [to_string(v) for v in values] def as_tuning_range(self, name): """Represent the parameter range as a dictionary. diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 5047e6351a..f7c1bded9a 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -13,10 +13,10 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Dict +from typing import Optional, Dict, List, Union import sagemaker -from sagemaker import ModelMetrics +from sagemaker import ModelMetrics, Model from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties from sagemaker.session import Session @@ -25,6 +25,7 @@ update_container_with_inference_params, ) from sagemaker.transformer import Transformer +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -36,13 +37,13 @@ class PipelineModel(object): def __init__( self, - models, - role, - predictor_cls=None, - name=None, - vpc_config=None, - sagemaker_session=None, - enable_network_isolation=False, + models: List[Model], + role: str, + predictor_cls: Optional[callable] = None, + name: Optional[str] = None, + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, + sagemaker_session: Optional[Session] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, ): """Initialize a SageMaker `Model` instance. @@ -267,27 +268,27 @@ def _create_sagemaker_pipeline_model(self, instance_type): @runnable_by_pipeline def register( self, - content_types: list, - response_types: list, - inference_instances: Optional[list] = None, - transform_instances: Optional[list] = None, - model_package_name: Optional[str] = None, - model_package_group_name: Optional[str] = None, - image_uri: Optional[str] = None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, model_metrics: Optional[ModelMetrics] = None, metadata_properties: Optional[MetadataProperties] = None, marketplace_cert: bool = False, - approval_status: Optional[str] = None, + approval_status: Optional[Union[str, PipelineVariable]] = None, description: Optional[str] = None, drift_check_baselines: Optional[DriftCheckBaselines] = None, - customer_metadata_properties: Optional[Dict[str, str]] = None, - domain: Optional[str] = None, - sample_payload_url: Optional[str] = None, - task: Optional[str] = None, - framework: Optional[str] = None, - framework_version: Optional[str] = None, - nearest_model_name: Optional[str] = None, - data_input_configuration: Optional[str] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -345,7 +346,7 @@ def register( framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, - container_def=container_def, + container_list=container_def, ) else: container_def = [ diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index fcbfd1da84..a16fc4d5e2 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,20 +14,27 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict + import packaging.version import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import NumpyDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, python_deprecation_warning, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.pytorch import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -77,14 +84,14 @@ class PyTorchModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version=None, - py_version=None, - image_uri=None, - predictor_cls=PyTorchPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: str = "1.3", + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = PyTorchPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize a PyTorchModel. @@ -102,7 +109,7 @@ def __init__( hosting. If ``source_dir`` is specified, then ``entry_point`` must point to a file located at the root of ``source_dir``. framework_version (str): PyTorch version you want to use for - executing your model training code. Defaults to None. Required + executing your model training code. Defaults to 1.3. Required unless ``image_uri`` is provided. py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless @@ -145,27 +152,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -221,6 +228,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(PyTorchModel, self).register( content_types, response_types, @@ -239,7 +248,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -285,7 +294,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env ) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 70ea22908e..5bb469991a 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -14,15 +14,21 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import NumpyDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer from sagemaker.sklearn import defaults +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -71,14 +77,14 @@ class SKLearnModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version=None, - py_version="py3", - image_uri=None, - predictor_cls=SKLearnPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: Optional[str] = None, + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = SKLearnPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an SKLearnModel. @@ -139,27 +145,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -215,6 +221,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(SKLearnModel, self).register( content_types, response_types, @@ -233,7 +241,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -274,7 +282,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) model_data_uri = ( self.repacked_model_data if self.enable_network_isolation() else self.model_data ) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 401ae04b23..82885995b7 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -14,14 +14,18 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, List, Dict import sagemaker -from sagemaker import image_uris, s3 +from sagemaker import image_uris, s3, ModelMetrics from sagemaker.deserializers import JSONDeserializer from sagemaker.deprecations import removed_kwargs +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import PipelineSession logger = logging.getLogger(__name__) @@ -126,13 +130,13 @@ class TensorFlowModel(sagemaker.model.FrameworkModel): def __init__( self, - model_data, - role, - entry_point=None, - image_uri=None, - framework_version=None, - container_log_level=None, - predictor_cls=TensorFlowPredictor, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[str] = None, + container_log_level: Optional[int] = None, + predictor_cls: callable = TensorFlowPredictor, **kwargs, ): """Initialize a Model. @@ -193,27 +197,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -269,6 +273,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(TensorFlowModel, self).register( content_types, response_types, @@ -287,7 +293,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 58c875f8d9..0440cee3b8 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -44,8 +44,7 @@ from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.session import Session -from sagemaker.utils import base_from_name, base_name_from_image, name_from_base -from sagemaker.workflow import is_pipeline_variable +from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, to_string AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { @@ -414,8 +413,7 @@ def _prepare_static_hyperparameters( """Prepare static hyperparameters for one estimator before tuning.""" # Remove any hyperparameter that will be tuned static_hyperparameters = { - str(k): str(v) if not is_pipeline_variable(v) else v.to_string() - for (k, v) in estimator.hyperparameters().items() + str(k): to_string(v) for (k, v) in estimator.hyperparameters().items() } for hyperparameter_name in hyperparameter_ranges.keys(): static_hyperparameters.pop(hyperparameter_name, None) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 4365d22f2d..d71b8e1433 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -845,3 +845,14 @@ def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str warn_msg += " and further overridden with {}.".format(override_val) logging.warning(warn_msg) kwargs.pop(arg_name) + + +def to_string(obj: object): + """Convert an object to string + + This helper function handles converting PipelineVariable object to string as well + + Args: + obj (object): The object to be converted + """ + return obj.to_string() if is_pipeline_variable(obj) else str(obj) diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 6e56230234..5279c07c50 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -14,14 +14,20 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import CSVDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor from sagemaker.serializers import LibSVMSerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost.defaults import XGBOOST_NAME from sagemaker.xgboost.utils import validate_py_version, validate_framework_version @@ -70,14 +76,14 @@ class XGBoostModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version, - image_uri=None, - py_version="py3", - predictor_cls=XGBoostPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: str, + image_uri: Optional[Union[str, PipelineVariable]] = None, + py_version: str = "py3", + predictor_cls: callable = XGBoostPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an XGBoostModel. @@ -126,27 +132,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -202,6 +208,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(XGBoostModel, self).register( content_types, response_types, @@ -220,7 +228,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -259,7 +267,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) model_data = ( self.repacked_model_data if self.enable_network_isolation() else self.model_data ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 5302e21fb8..b0b5045b94 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,7 +30,8 @@ import sagemaker from sagemaker.session_settings import SessionSettings from tests.unit.sagemaker.workflow.helpers import CustomStep -from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.parameters import ParameterString, ParameterInteger + BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission" @@ -773,3 +774,16 @@ def test_pop_out_unused_kwarg(): kwargs = dict(arg1=1, arg2=2) sagemaker.utils.pop_out_unused_kwarg("arg3", kwargs) assert len(kwargs) == 2 + + +def test_to_string(): + var = 1 + assert sagemaker.utils.to_string(var) == "1" + + var = ParameterInteger(name="MyInt") + assert sagemaker.utils.to_string(var).expr == { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.MyInt"}], + }, + }