diff --git a/.gitignore b/.gitignore index 84e184aa92..1b6b4ca1cf 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ venv/ *.swp .docker/ env/ -.vscode/ \ No newline at end of file +.vscode/ +.python-version \ No newline at end of file diff --git a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst index 9071d05145..908621ea1c 100644 --- a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst +++ b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst @@ -82,6 +82,12 @@ Pipeline .. autoclass:: sagemaker.workflow.pipeline._PipelineExecution :members: +Parallelism Configuration +------------------------- + +.. autoclass:: sagemaker.workflow.parallelism_config.ParallelismConfiguration + :members: + Pipeline Experiment Config -------------------------- diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 0829e25f4b..006cc4846c 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -290,11 +290,15 @@ def __init__( probability_threshold (float): An optional value for binary prediction tasks in which the model returns a probability, to indicate the threshold to convert the prediction to a boolean value. Default is 0.5. - label_headers (list): List of label values - one for each score of the ``probability``. + label_headers (list[str]): List of headers, each for a predicted score in model output. + For bias analysis, it is used to extract the label value with the highest score as + predicted label. For explainability job, It is used to beautify the analysis report + by replacing placeholders like "label0". 
""" self.label = label self.probability = probability self.probability_threshold = probability_threshold + self.label_headers = label_headers if probability_threshold is not None: try: float(probability_threshold) @@ -1060,10 +1064,10 @@ def run_explainability( explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list): Config of the specific explainability method or a list of ExplainabilityConfig objects. Currently, SHAP and PDP are the two methods supported. - model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the - model output for the predicted scores to be explained. This is not required if the - model output is a single score. Alternatively, an instance of - ModelPredictedLabelConfig can be provided. + model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): + Index or JSONPath to locate the predicted scores in the model output. This is not + required if the model output is a single score. Alternatively, it can be an instance + of ModelPredictedLabelConfig to provide more parameters like label_headers. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). 
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index d603188f74..cf039fa010 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2343,6 +2343,7 @@ def _stage_user_code_in_s3(self): dependencies=self.dependencies, kms_key=kms_key, s3_resource=self.sagemaker_session.s3_resource, + settings=self.sagemaker_session.settings, ) def _model_source_dir(self): diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 87b94711ae..79b9e803d7 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -19,8 +19,10 @@ import shutil import tempfile from collections import namedtuple +from typing import Optional import sagemaker.image_uris +from sagemaker.session_settings import SessionSettings import sagemaker.utils from sagemaker.deprecations import renamed_warning @@ -73,7 +75,20 @@ "2.6.0", "2.6.2", ], - "pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1", "1.9", "1.9.0", "1.9.1"], + "pytorch": [ + "1.6", + "1.6.0", + "1.7", + "1.7.1", + "1.8", + "1.8.0", + "1.8.1", + "1.9", + "1.9.0", + "1.9.1", + "1.10", + "1.10.0", + ], } SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] @@ -203,6 +218,7 @@ def tar_and_upload_dir( dependencies=None, kms_key=None, s3_resource=None, + settings: Optional[SessionSettings] = None, ): """Package source files and upload a compress tar file to S3. @@ -230,6 +246,9 @@ def tar_and_upload_dir( s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource for S3 connections, can be used to customize the configuration, e.g. set the endpoint URL (default: None). + settings (sagemaker.session_settings.SessionSettings): Optional. The settings + of the SageMaker ``Session``, can be used to override the default encryption + behavior (default: None). Returns: sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name. 
@@ -241,6 +260,7 @@ def tar_and_upload_dir( dependencies = dependencies or [] key = "%s/sourcedir.tar.gz" % s3_key_prefix tmp = tempfile.mkdtemp() + encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts try: source_files = _list_files_to_compress(script, directory) + dependencies @@ -250,6 +270,10 @@ def tar_and_upload_dir( if kms_key: extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key} + elif encrypt_artifact: + # encrypt the tarball at rest in S3 with the default AWS managed KMS key for S3 + # see https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax + extra_args = {"ServerSideEncryption": "aws:kms"} else: extra_args = None diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index a64a710692..9c96858efe 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -63,7 +63,8 @@ "1.6": "1.6.0", "1.7": "1.7.1", "1.8": "1.8.1", - "1.9": "1.9.1" + "1.9": "1.9.1", + "1.10": "1.10.0" }, "versions": { "0.4.0": { @@ -500,6 +501,39 @@ "us-west-2": "763104351884" }, "repository": "pytorch-inference" + }, + "1.10.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + 
"us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" } } }, @@ -519,7 +553,8 @@ "1.6": "1.6.0", "1.7": "1.7.1", "1.8": "1.8.1", - "1.9": "1.9.1" + "1.9": "1.9.1", + "1.10": "1.10.0" }, "versions": { "0.4.0": { @@ -957,6 +992,39 @@ "us-west-2": "763104351884" }, "repository": "pytorch-training" + }, + "1.10.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" } } } diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index f2d1bf8c14..033e838137 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -83,10 +83,11 @@ def __init__( self._session = sagemaker_session def to_lineage_object(self): - """Convert the ``Vertex`` object to its corresponding ``Artifact`` or ``Context`` object.""" + """Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object.""" from sagemaker.lineage.artifact import Artifact, ModelArtifact from sagemaker.lineage.context import Context, EndpointContext from sagemaker.lineage.artifact import DatasetArtifact + from 
sagemaker.lineage.action import Action if self.lineage_entity == LineageEntityEnum.CONTEXT.value: resource_name = get_resource_name_from_arn(self.arn) @@ -103,6 +104,9 @@ def to_lineage_object(self): return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + if self.lineage_entity == LineageEntityEnum.ACTION.value: + return Action.load(action_name=self.arn.split("/")[1], sagemaker_session=self._session) + raise ValueError("Vertex cannot be converted to a lineage object.") @@ -208,6 +212,44 @@ def _convert_api_response(self, response) -> LineageQueryResult: return converted + def _collapse_cross_account_artifacts(self, query_response): + """Collapse the duplicate vertices and edges for cross-account.""" + for edge in query_response.edges: + if ( + "artifact" in edge.source_arn + and "artifact" in edge.destination_arn + and edge.source_arn.split("/")[1] == edge.destination_arn.split("/")[1] + and edge.source_arn != edge.destination_arn + ): + edge_source_arn = edge.source_arn + edge_destination_arn = edge.destination_arn + self._update_cross_account_edge( + edges=query_response.edges, + arn=edge_source_arn, + duplicate_arn=edge_destination_arn, + ) + self._update_cross_account_vertex( + query_response=query_response, duplicate_arn=edge_destination_arn + ) + + # remove the duplicate edges from cross account + new_edge = [e for e in query_response.edges if not e.source_arn == e.destination_arn] + query_response.edges = new_edge + + return query_response + + def _update_cross_account_edge(self, edges, arn, duplicate_arn): + """Replace the duplicate arn with arn in edges list.""" + for idx, e in enumerate(edges): + if e.destination_arn == duplicate_arn: + edges[idx].destination_arn = arn + elif e.source_arn == duplicate_arn: + edges[idx].source_arn = arn + + def _update_cross_account_vertex(self, query_response, duplicate_arn): + """Remove the vertex with duplicate arn 
in the vertices list.""" + query_response.vertices = [v for v in query_response.vertices if not v.arn == duplicate_arn] + def query( self, start_arns: List[str], @@ -235,5 +277,7 @@ def query( Filters=query_filter._to_request_dict() if query_filter else {}, MaxDepth=max_depth, ) + query_response = self._convert_api_response(query_response) + query_response = self._collapse_cross_account_artifacts(query_response) - return self._convert_api_response(query_response) + return query_response diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5af5539a96..830bb50dab 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1131,6 +1131,7 @@ def _upload_code(self, key_prefix, repack=False): script=self.entry_point, directory=self.source_dir, dependencies=self.dependencies, + settings=self.sagemaker_session.settings, ) if repack and self.model_data is not None and self.entry_point is not None: diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 10da0bf6c9..09de7b5c05 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -26,7 +26,7 @@ from sagemaker import image_uris, s3 from sagemaker.session import Session from sagemaker.utils import name_from_base -from sagemaker.clarify import SageMakerClarifyProcessor +from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig _LOGGER = logging.getLogger(__name__) @@ -833,9 +833,10 @@ def suggest_baseline( specific explainability method. Currently, only SHAP is supported. model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its endpoint to be created. - model_scores (int or str): Index or JSONPath location in the model output for the - predicted scores to be explained. This is not required if the model output is - a single score. 
+ model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): + Index or JSONPath to locate the predicted scores in the model output. This is not + required if the model output is a single score. Alternatively, it can be an instance + of ModelPredictedLabelConfig to provide more parameters like label_headers. wait (bool): Whether the call should wait until the job completes (default: False). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is True (default: False). @@ -865,14 +866,24 @@ def suggest_baseline( headers = copy.deepcopy(data_config.headers) if headers and data_config.label in headers: headers.remove(data_config.label) + if model_scores is None: + inference_attribute = None + label_headers = None + elif isinstance(model_scores, ModelPredictedLabelConfig): + inference_attribute = str(model_scores.label) + label_headers = model_scores.label_headers + else: + inference_attribute = str(model_scores) + label_headers = None self.latest_baselining_job_config = ClarifyBaseliningConfig( analysis_config=ExplainabilityAnalysisConfig( explainability_config=explainability_config, model_config=model_config, headers=headers, + label_headers=label_headers, ), features_attribute=data_config.features, - inference_attribute=model_scores if model_scores is None else str(model_scores), + inference_attribute=inference_attribute, ) self.latest_baselining_job_name = baselining_job_name self.latest_baselining_job = ClarifyBaseliningJob( @@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None): class ExplainabilityAnalysisConfig: """Analysis configuration for ModelExplainabilityMonitor.""" - def __init__(self, explainability_config, model_config, headers=None): + def __init__(self, explainability_config, model_config, headers=None, label_headers=None): """Creates an analysis config dictionary. 
Args: @@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None): model_config (sagemaker.clarify.ModelConfig): Config object related to bias configurations. headers (list[str]): A list of feature names (without label) of model/endpint input. + label_headers (list[str]): List of headers, each for a predicted score in model output. + It is used to beautify the analysis report by replacing placeholders like "label0". + """ + predictor_config = model_config.get_predictor_config() self.analysis_config = { "methods": explainability_config.get_explainability_config(), - "predictor": model_config.get_predictor_config(), + "predictor": predictor_config, } if headers is not None: self.analysis_config["headers"] = headers + if label_headers is not None: + predictor_config["label_headers"] = label_headers def _to_dict(self): """Generates a request dictionary using the parameters provided to the class.""" diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 189c9cb308..56f008be84 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -42,6 +42,7 @@ sts_regional_endpoint, ) from sagemaker import exceptions +from sagemaker.session_settings import SessionSettings LOGGER = logging.getLogger("sagemaker") @@ -85,6 +86,7 @@ def __init__( sagemaker_runtime_client=None, sagemaker_featurestore_runtime_client=None, default_bucket=None, + settings=SessionSettings(), ): """Initialize a SageMaker ``Session``. @@ -110,6 +112,8 @@ def __init__( If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". Example: "sagemaker-my-custom-bucket". + settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional + parameters to apply to the session. 
""" self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -117,6 +121,7 @@ def __init__( self.s3_client = None self.config = None self.lambda_client = None + self.settings = settings self._initialize( boto_session=boto_session, diff --git a/src/sagemaker/session_settings.py b/src/sagemaker/session_settings.py new file mode 100644 index 0000000000..53ff9a9f0d --- /dev/null +++ b/src/sagemaker/session_settings.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Defines classes to parametrize a SageMaker ``Session``.""" + +from __future__ import absolute_import + + +class SessionSettings(object): + """Optional container class for settings to apply to a SageMaker session.""" + + def __init__(self, encrypt_repacked_artifacts=True) -> None: + """Initialize the ``SessionSettings`` of a SageMaker ``Session``. + + Args: + encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts + at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key + is not provided (Default: True). 
+ """ + self._encrypt_repacked_artifacts = encrypt_repacked_artifacts + + @property + def encrypt_repacked_artifacts(self) -> bool: + """Return True if repacked artifacts at rest in S3 should be encrypted by default.""" + return self._encrypt_repacked_artifacts diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 4409c0b954..5c617b0155 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -29,6 +29,7 @@ from six.moves.urllib import parse from sagemaker import deprecations +from sagemaker.session_settings import SessionSettings ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" @@ -429,8 +430,15 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): bucket, key = url.netloc, url.path.lstrip("/") new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri)) + settings = ( + sagemaker_session.settings if sagemaker_session is not None else SessionSettings() + ) + encrypt_artifact = settings.encrypt_repacked_artifacts + if kms_key: extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key} + elif encrypt_artifact: + extra_args = {"ServerSideEncryption": "aws:kms"} else: extra_args = None sagemaker_session.boto_session.resource( diff --git a/src/sagemaker/workflow/parallelism_config.py b/src/sagemaker/workflow/parallelism_config.py new file mode 100644 index 0000000000..72c180517a --- /dev/null +++ b/src/sagemaker/workflow/parallelism_config.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""Pipeline Parallelism Configuration""" +from __future__ import absolute_import +from sagemaker.workflow.entities import RequestType + + +class ParallelismConfiguration: + """Parallelism config for SageMaker pipeline.""" + + def __init__(self, max_parallel_execution_steps: int): + """Create a ParallelismConfiguration. + + Args: + max_parallel_execution_steps (int): + max number of steps which could be parallelized + """ + self.max_parallel_execution_steps = max_parallel_execution_steps + + def to_request(self) -> RequestType: + """Returns: the request structure.""" + return { + "MaxParallelExecutionSteps": self.max_parallel_execution_steps, + } diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 4982c6f5fd..606ba38bc2 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -22,6 +22,7 @@ import botocore from botocore.exceptions import ClientError +from sagemaker import s3 from sagemaker._studio import _append_project_tags from sagemaker.session import Session from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep @@ -34,6 +35,7 @@ from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.parameters import Parameter from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig +from sagemaker.workflow.parallelism_config import ParallelismConfiguration from sagemaker.workflow.properties import Properties from sagemaker.workflow.steps import Step from sagemaker.workflow.step_collections import StepCollection @@ -94,6 +96,7 @@ def create( role_arn: str, description: str = None, tags: List[Dict[str, str]] = None, + parallelism_config: ParallelismConfiguration = None, ) -> Dict[str, Any]: """Creates a Pipeline in the Pipelines service. @@ -102,37 +105,62 @@ def create( role_arn (str): The role arn that is assumed by pipelines to create step artifacts. description (str): A description of the pipeline. 
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as tags. + parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration + that is applied to each of the executions of the pipeline. It takes precedence + over the parallelism configuration of the parent pipeline. Returns: A response dict from the service. """ tags = _append_project_tags(tags) - - kwargs = self._create_args(role_arn, description) + kwargs = self._create_args(role_arn, description, parallelism_config) update_args( kwargs, Tags=tags, ) return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs) - def _create_args(self, role_arn: str, description: str): + def _create_args( + self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration + ): """Constructs the keyword argument dict for a create_pipeline call. Args: role_arn (str): The role arn that is assumed by pipelines to create step artifacts. description (str): A description of the pipeline. + parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration + that is applied to each of the executions of the pipeline. It takes precedence + over the parallelism configuration of the parent pipeline. Returns: A keyword argument dict for calling create_pipeline. """ + pipeline_definition = self.definition() kwargs = dict( PipelineName=self.name, - PipelineDefinition=self.definition(), RoleArn=role_arn, ) + + # If pipeline definition is large, upload to S3 bucket and + # provide PipelineDefinitionS3Location to request instead. 
+ if len(pipeline_definition.encode("utf-8")) < 1024 * 100: + kwargs["PipelineDefinition"] = pipeline_definition + else: + desired_s3_uri = s3.s3_path_join( + "s3://", self.sagemaker_session.default_bucket(), self.name + ) + s3.S3Uploader.upload_string_as_file_body( + body=pipeline_definition, + desired_s3_uri=desired_s3_uri, + sagemaker_session=self.sagemaker_session, + ) + kwargs["PipelineDefinitionS3Location"] = { + "Bucket": self.sagemaker_session.default_bucket(), + "ObjectKey": self.name, + } + update_args( - kwargs, - PipelineDescription=description, + kwargs, PipelineDescription=description, ParallelismConfiguration=parallelism_config ) return kwargs @@ -146,17 +174,25 @@ def describe(self) -> Dict[str, Any]: """ return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name) - def update(self, role_arn: str, description: str = None) -> Dict[str, Any]: + def update( + self, + role_arn: str, + description: str = None, + parallelism_config: ParallelismConfiguration = None, + ) -> Dict[str, Any]: """Updates a Pipeline in the Workflow service. Args: role_arn (str): The role arn that is assumed by pipelines to create step artifacts. description (str): A description of the pipeline. + parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration + that is applied to each of the executions of the pipeline. It takes precedence + over the parallelism configuration of the parent pipeline. Returns: A response dict from the service. """ - kwargs = self._create_args(role_arn, description) + kwargs = self._create_args(role_arn, description, parallelism_config) return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs) def upsert( @@ -164,6 +200,7 @@ def upsert( role_arn: str, description: str = None, tags: List[Dict[str, str]] = None, + parallelism_config: ParallelismConfiguration = None, ) -> Dict[str, Any]: """Creates a pipeline or updates it, if it already exists. 
@@ -172,12 +209,14 @@ def upsert( description (str): A description of the pipeline. tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as tags. + parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration that + is applied to each of the executions of the pipeline. Returns: response dict from service """ try: - response = self.create(role_arn, description, tags) + response = self.create(role_arn, description, tags, parallelism_config) except ClientError as e: error = e.response["Error"] if ( @@ -215,6 +254,7 @@ def start( parameters: Dict[str, Union[str, bool, int, float]] = None, execution_display_name: str = None, execution_description: str = None, + parallelism_config: ParallelismConfiguration = None, ): """Starts a Pipeline execution in the Workflow service. @@ -223,6 +263,9 @@ def start( parameters (Dict[str, Union[str, bool, int, float]]): values to override pipeline parameters. execution_display_name (str): The display name of the pipeline execution. execution_description (str): A description of the execution. + parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration + that is applied to each of the executions of the pipeline. It takes precedence + over the parallelism configuration of the parent pipeline. Returns: A `_PipelineExecution` instance, if successful. 
@@ -245,6 +288,7 @@ def start( PipelineParameters=format_start_parameters(parameters), PipelineExecutionDescription=execution_description, PipelineExecutionDisplayName=execution_display_name, + ParallelismConfiguration=parallelism_config, ) response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs) return _PipelineExecution( diff --git a/tests/integ/sagemaker/lineage/test_dataset_artifact.py b/tests/integ/sagemaker/lineage/test_dataset_artifact.py index 4b1d35aa16..be03a85e86 100644 --- a/tests/integ/sagemaker/lineage/test_dataset_artifact.py +++ b/tests/integ/sagemaker/lineage/test_dataset_artifact.py @@ -12,11 +12,9 @@ # language governing permissions and limitations under the License. """This module contains code to test SageMaker ``DatasetArtifact``""" from __future__ import absolute_import -from tests.integ.sagemaker.lineage.helpers import traverse_graph_forward def test_trained_models( - sagemaker_session, dataset_artifact_associated_models, trial_component_obj, model_artifact_obj1, @@ -31,20 +29,9 @@ def test_trained_models( def test_endpoint_contexts( static_dataset_artifact, - sagemaker_session, ): contexts_from_query = static_dataset_artifact.endpoint_contexts() - associations_from_api = traverse_graph_forward( - static_dataset_artifact.artifact_arn, sagemaker_session=sagemaker_session - ) - assert len(contexts_from_query) > 0 for context in contexts_from_query: - # assert that the contexts from the query - # appear in the association list from the lineage API - assert any( - x - for x in associations_from_api - if x["DestinationArn"] == context.context_arn and x["DestinationType"] == "Endpoint" - ) + assert context.context_type == "Endpoint" diff --git a/tests/integ/sagemaker/lineage/test_endpoint_context.py b/tests/integ/sagemaker/lineage/test_endpoint_context.py index d3b0c225bd..07cc48142d 100644 --- a/tests/integ/sagemaker/lineage/test_endpoint_context.py +++ b/tests/integ/sagemaker/lineage/test_endpoint_context.py @@ 
-12,15 +12,9 @@ # language governing permissions and limitations under the License. """This module contains code to test SageMaker ``Contexts``""" from __future__ import absolute_import -from tests.integ.sagemaker.lineage.helpers import traverse_graph_back -def test_model( - endpoint_context_associate_with_model, - model_obj, - endpoint_action_obj, - sagemaker_session, -): +def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj): model_list = endpoint_context_associate_with_model.models() for model in model_list: assert model.source_arn == endpoint_action_obj.action_arn @@ -29,25 +23,12 @@ def test_model( assert model.destination_type == "Model" -def test_dataset_artifacts( - static_endpoint_context, - sagemaker_session, -): +def test_dataset_artifacts(static_endpoint_context): artifacts_from_query = static_endpoint_context.dataset_artifacts() - associations_from_api = traverse_graph_back( - static_endpoint_context.context_arn, sagemaker_session=sagemaker_session - ) - assert len(artifacts_from_query) > 0 for artifact in artifacts_from_query: - # assert that the artifacts from the query - # appear in the association list from the lineage API - assert any( - x - for x in associations_from_api - if x["SourceArn"] == artifact.artifact_arn and x["SourceType"] == "DataSet" - ) + assert artifact.artifact_type == "DataSet" def test_training_job_arns( diff --git a/tests/integ/sagemaker/lineage/test_model_artifact.py b/tests/integ/sagemaker/lineage/test_model_artifact.py index ca4dc2d94c..8d9048726d 100644 --- a/tests/integ/sagemaker/lineage/test_model_artifact.py +++ b/tests/integ/sagemaker/lineage/test_model_artifact.py @@ -12,11 +12,9 @@ # language governing permissions and limitations under the License. 
"""This module contains code to test SageMaker ``DatasetArtifact``""" from __future__ import absolute_import -from tests.integ.sagemaker.lineage.helpers import traverse_graph_forward, traverse_graph_back def test_endpoints( - sagemaker_session, model_artifact_associated_endpoints, endpoint_deployment_action_obj, endpoint_context_obj, @@ -32,44 +30,22 @@ def test_endpoints( def test_endpoint_contexts( static_model_artifact, - sagemaker_session, ): contexts_from_query = static_model_artifact.endpoint_contexts() - associations_from_api = traverse_graph_forward( - static_model_artifact.artifact_arn, sagemaker_session=sagemaker_session - ) - assert len(contexts_from_query) > 0 for context in contexts_from_query: - # assert that the contexts from the query - # appear in the association list from the lineage API - assert any( - x - for x in associations_from_api - if x["DestinationArn"] == context.context_arn and x["DestinationType"] == "Endpoint" - ) + assert context.context_type == "Endpoint" def test_dataset_artifacts( static_model_artifact, - sagemaker_session, ): artifacts_from_query = static_model_artifact.dataset_artifacts() - associations_from_api = traverse_graph_back( - static_model_artifact.artifact_arn, sagemaker_session=sagemaker_session - ) - assert len(artifacts_from_query) > 0 for artifact in artifacts_from_query: - # assert that the artifacts from the query - # appear in the association list from the lineage API - assert any( - x - for x in associations_from_api - if x["SourceArn"] == artifact.artifact_arn and x["SourceType"] == "DataSet" - ) + assert artifact.artifact_type == "DataSet" def test_training_job_arns( diff --git a/tests/integ/test_clarify_model_monitor.py b/tests/integ/test_clarify_model_monitor.py index 6891082285..3f48fa1032 100644 --- a/tests/integ/test_clarify_model_monitor.py +++ b/tests/integ/test_clarify_model_monitor.py @@ -53,6 +53,7 @@ HEADER_OF_LABEL = "Label" HEADERS_OF_FEATURES = ["F1", "F2", "F3", "F4", "F5", "F6", "F7"] 
ALL_HEADERS = [*HEADERS_OF_FEATURES, HEADER_OF_LABEL] +HEADER_OF_PREDICTION = "Decision" DATASET_TYPE = "text/csv" CONTENT_TYPE = DATASET_TYPE ACCEPT_TYPE = DATASET_TYPE @@ -325,7 +326,7 @@ def scheduled_explainability_monitor( ): monitor_schedule_name = utils.unique_name_from_base("explainability-monitor") analysis_config = ExplainabilityAnalysisConfig( - shap_config, model_config, headers=HEADERS_OF_FEATURES + shap_config, model_config, headers=HEADERS_OF_FEATURES, label_headers=[HEADER_OF_PREDICTION] ) s3_uri_monitoring_output = os.path.join( "s3://", diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 2fe674a203..58b681fd0e 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -2757,3 +2757,99 @@ def cleanup_feature_group(feature_group: FeatureGroup): except Exception as e: print(f"Delete FeatureGroup failed with error: {e}.") pass + + +def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name): + instance_count = ParameterInteger(name="InstanceCount", default_value=2) + + outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String) + + callback_steps = [ + CallbackStep( + name=f"callback-step{count}", + sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue", + inputs={"arg1": "foo"}, + outputs=[outputParam], + ) + for count in range(2000) + ] + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_count], + steps=callback_steps, + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + response = pipeline.describe() + assert len(json.loads(pipeline.describe()["PipelineDefinition"])["Steps"]) == 2000 + + pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] + response = pipeline.update(role) + update_arn = 
response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + update_arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_create_and_update_with_parallelism_config( + sagemaker_session, role, pipeline_name, region_name +): + instance_count = ParameterInteger(name="InstanceCount", default_value=2) + + outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String) + + callback_steps = [ + CallbackStep( + name=f"callback-step{count}", + sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue", + inputs={"arg1": "foo"}, + outputs=[outputParam], + ) + for count in range(500) + ] + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_count], + steps=callback_steps, + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role, parallelism_config={"MaxParallelExecutionSteps": 50}) + create_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + response = pipeline.describe() + assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 50 + + pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] + response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 55}) + update_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + update_arn, + ) + + response = pipeline.describe() + assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 55 + + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 17d3eabe92..595e7e1d0f 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -13,6 +13,7 @@ from __future__ 
import absolute_import from sagemaker.lineage.artifact import DatasetArtifact, ModelArtifact, Artifact from sagemaker.lineage.context import EndpointContext, Context +from sagemaker.lineage.action import Action from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest @@ -44,6 +45,143 @@ def test_lineage_query(sagemaker_session): assert response.vertices[1].lineage_entity == "Context" +def test_lineage_query_cross_account_same_artifact(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + { + "Arn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0", + "Type": "Endpoint", + "LineageType": "Artifact", + }, + { + "Arn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "Type": "Endpoint", + "LineageType": "Artifact", + }, + ], + "Edges": [ + { + "SourceArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0", + "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "AssociationType": "SAME_AS", + }, + { + "SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0", + "AssociationType": "SAME_AS", + }, + ], + } + + response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + assert len(response.edges) == 0 + assert len(response.vertices) == 1 + assert ( + response.vertices[0].arn + == "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0" + ) + assert response.vertices[0].lineage_source == "Endpoint" + assert response.vertices[0].lineage_entity == "Artifact" + + +def test_lineage_query_cross_account(sagemaker_session): + lineage_query = 
LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + { + "Arn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0", + "Type": "Endpoint", + "LineageType": "Artifact", + }, + { + "Arn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "Type": "Endpoint", + "LineageType": "Artifact", + }, + { + "Arn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd", + "Type": "Endpoint", + "LineageType": "Artifact", + }, + { + "Arn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh", + "Type": "Endpoint", + "LineageType": "Artifact", + }, + ], + "Edges": [ + { + "SourceArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0", + "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "AssociationType": "SAME_AS", + }, + { + "SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0", + "AssociationType": "SAME_AS", + }, + { + "SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0", + "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd", + "AssociationType": "ABC", + }, + { + "SourceArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd", + "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh", + "AssociationType": "DEF", + }, + ], + } + + response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + assert len(response.edges) == 2 + assert ( + response.edges[0].source_arn + == 
"arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0" + ) + assert ( + response.edges[0].destination_arn + == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd" + ) + assert response.edges[0].association_type == "ABC" + + assert ( + response.edges[1].source_arn + == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd" + ) + assert ( + response.edges[1].destination_arn + == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh" + ) + assert response.edges[1].association_type == "DEF" + + assert len(response.vertices) == 3 + assert ( + response.vertices[0].arn + == "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0" + ) + assert response.vertices[0].lineage_source == "Endpoint" + assert response.vertices[0].lineage_entity == "Artifact" + assert ( + response.vertices[1].arn + == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd" + ) + assert response.vertices[1].lineage_source == "Endpoint" + assert response.vertices[1].lineage_entity == "Artifact" + assert ( + response.vertices[2].arn + == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh" + ) + assert response.vertices[2].lineage_source == "Endpoint" + assert response.vertices[2].lineage_entity == "Artifact" + + def test_vertex_to_object_endpoint_context(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", @@ -240,10 +378,38 @@ def test_vertex_to_object_artifact(sagemaker_session): assert isinstance(artifact, Artifact) +def test_vertex_to_object_action(sagemaker_session): + vertex = Vertex( + arn="arn:aws:sagemaker:us-west-2:0123456789012:action/cp-m5-20210424t041405868z-1619237657-1-aws-endpoint", + lineage_entity=LineageEntityEnum.ACTION.value, + lineage_source="A", + sagemaker_session=sagemaker_session, + ) + + 
sagemaker_session.sagemaker_client.describe_action.return_value = { + "ActionName": "cp-m5-20210424t041405868z-1619237657-1-aws-endpoint", + "Source": { + "SourceUri": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3", + "SourceTypes": [], + }, + "ActionType": "A", + "Properties": {}, + "CreationTime": 1608224704.149, + "CreatedBy": {}, + "LastModifiedTime": 1608224704.149, + "LastModifiedBy": {}, + } + + action = vertex.to_lineage_object() + + assert action.action_name == "cp-m5-20210424t041405868z-1619237657-1-aws-endpoint" + assert isinstance(action, Action) + + def test_vertex_to_object_unconvertable(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", - lineage_entity=LineageEntityEnum.ACTION.value, + lineage_entity=LineageEntityEnum.TRIAL_COMPONENT.value, lineage_source=LineageSourceEnum.TENSORBOARD.value, sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index e13755f208..7c1d497d64 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -279,6 +279,7 @@ # for bias ANALYSIS_CONFIG_LABEL = "Label" ANALYSIS_CONFIG_HEADERS_OF_FEATURES = ["F1", "F2", "F3"] +ANALYSIS_CONFIG_LABEL_HEADERS = ["Decision"] ANALYSIS_CONFIG_ALL_HEADERS = [*ANALYSIS_CONFIG_HEADERS_OF_FEATURES, ANALYSIS_CONFIG_LABEL] ANALYSIS_CONFIG_LABEL_VALUES = [1] ANALYSIS_CONFIG_FACET_NAME = "F1" @@ -330,6 +331,11 @@ "content_type": CONTENT_TYPE, }, } +EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy.deepcopy(EXPLAINABILITY_ANALYSIS_CONFIG) +# noinspection PyTypeChecker +EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS["predictor"][ + "label_headers" +] = ANALYSIS_CONFIG_LABEL_HEADERS @pytest.fixture() @@ -1048,12 +1054,31 @@ def test_explainability_analysis_config(shap_config, 
model_config): explainability_config=shap_config, model_config=model_config, headers=ANALYSIS_CONFIG_HEADERS_OF_FEATURES, + label_headers=ANALYSIS_CONFIG_LABEL_HEADERS, ) - assert EXPLAINABILITY_ANALYSIS_CONFIG == config._to_dict() + assert EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS == config._to_dict() +@pytest.mark.parametrize( + "model_scores,explainability_analysis_config", + [ + (INFERENCE_ATTRIBUTE, EXPLAINABILITY_ANALYSIS_CONFIG), + ( + ModelPredictedLabelConfig( + label=INFERENCE_ATTRIBUTE, label_headers=ANALYSIS_CONFIG_LABEL_HEADERS + ), + EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS, + ), + ], +) def test_model_explainability_monitor_suggest_baseline( - model_explainability_monitor, sagemaker_session, data_config, shap_config, model_config + model_explainability_monitor, + sagemaker_session, + data_config, + shap_config, + model_config, + model_scores, + explainability_analysis_config, ): clarify_model_monitor = model_explainability_monitor # suggest baseline @@ -1061,12 +1086,12 @@ def test_model_explainability_monitor_suggest_baseline( data_config=data_config, explainability_config=shap_config, model_config=model_config, - model_scores=INFERENCE_ATTRIBUTE, + model_scores=model_scores, job_name=BASELINING_JOB_NAME, ) assert isinstance(clarify_model_monitor.latest_baselining_job, ClarifyBaseliningJob) assert ( - EXPLAINABILITY_ANALYSIS_CONFIG + explainability_analysis_config == clarify_model_monitor.latest_baselining_job_config.analysis_config._to_dict() ) clarify_baselining_job = clarify_model_monitor.latest_baselining_job @@ -1081,6 +1106,7 @@ def test_model_explainability_monitor_suggest_baseline( analysis_config=None, # will pick up config from baselining job baseline_job_name=BASELINING_JOB_NAME, endpoint_input=ENDPOINT_NAME, + explainability_analysis_config=explainability_analysis_config, # will pick up attributes from baselining job ) @@ -1133,6 +1159,7 @@ def test_model_explainability_monitor_created_with_config( 
sagemaker_session=sagemaker_session, analysis_config=analysis_config, constraints=CONSTRAINTS, + explainability_analysis_config=EXPLAINABILITY_ANALYSIS_CONFIG, ) # update schedule @@ -1263,6 +1290,7 @@ def _test_model_explainability_monitor_create_schedule( features_attribute=FEATURES_ATTRIBUTE, inference_attribute=str(INFERENCE_ATTRIBUTE), ), + explainability_analysis_config=None, ): # create schedule with patch( @@ -1278,7 +1306,7 @@ def _test_model_explainability_monitor_create_schedule( ) if not isinstance(analysis_config, str): upload.assert_called_once() - assert json.loads(upload.call_args[0][0]) == EXPLAINABILITY_ANALYSIS_CONFIG + assert json.loads(upload.call_args[0][0]) == explainability_analysis_config # validation expected_arguments = { diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 4b68abceeb..be90a8a876 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -21,9 +21,11 @@ from mock import Mock +from sagemaker import s3 from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.parallelism_config import ParallelismConfiguration from sagemaker.workflow.pipeline_experiment_config import ( PipelineExperimentConfig, PipelineExperimentConfigProperties, @@ -62,7 +64,9 @@ def role_arn(): @pytest.fixture def sagemaker_session_mock(): - return Mock() + session_mock = Mock() + session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") + return session_mock def test_pipeline_create(sagemaker_session_mock, role_arn): @@ -78,6 +82,47 @@ def test_pipeline_create(sagemaker_session_mock, role_arn): ) +def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_arn): + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[], + 
pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10), + sagemaker_session=sagemaker_session_mock, + ) + pipeline.create(role_arn=role_arn) + assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( + PipelineName="MyPipeline", + PipelineDefinition=pipeline.definition(), + RoleArn=role_arn, + ParallelismConfiguration={"MaxParallelExecutionSteps": 10}, + ) + + +def test_large_pipeline_create(sagemaker_session_mock, role_arn): + parameter = ParameterString("MyStr") + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000, + sagemaker_session=sagemaker_session_mock, + ) + + s3.S3Uploader.upload_string_as_file_body = Mock() + + pipeline.create(role_arn=role_arn) + + assert s3.S3Uploader.upload_string_as_file_body.called_with( + body=pipeline.definition(), s3_uri="s3://s3_bucket/MyPipeline" + ) + + assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( + PipelineName="MyPipeline", + PipelineDefinitionS3Location={"Bucket": "s3_bucket", "ObjectKey": "MyPipeline"}, + RoleArn=role_arn, + ) + + def test_pipeline_update(sagemaker_session_mock, role_arn): pipeline = Pipeline( name="MyPipeline", @@ -91,6 +136,47 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): ) +def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_arn): + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[], + pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10), + sagemaker_session=sagemaker_session_mock, + ) + pipeline.create(role_arn=role_arn) + assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + PipelineName="MyPipeline", + PipelineDefinition=pipeline.definition(), + RoleArn=role_arn, + ParallelismConfiguration={"MaxParallelExecutionSteps": 10}, + ) + + +def test_large_pipeline_update(sagemaker_session_mock, role_arn): + parameter = 
ParameterString("MyStr") + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000, + sagemaker_session=sagemaker_session_mock, + ) + + s3.S3Uploader.upload_string_as_file_body = Mock() + + pipeline.create(role_arn=role_arn) + + assert s3.S3Uploader.upload_string_as_file_body.called_with( + body=pipeline.definition(), s3_uri="s3://s3_bucket/MyPipeline" + ) + + assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + PipelineName="MyPipeline", + PipelineDefinitionS3Location={"Bucket": "s3_bucket", "ObjectKey": "MyPipeline"}, + RoleArn=role_arn, + ) + + def test_pipeline_upsert(sagemaker_session_mock, role_arn): sagemaker_session_mock.side_effect = [ ClientError( diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 69e030b567..248eda1aa5 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2323,8 +2323,8 @@ def test_different_code_location_kms_key(utils, sagemaker_session): obj = sagemaker_session.boto_session.resource("s3").Object obj.assert_called_with("another-location", "%s/source/sourcedir.tar.gz" % fw._current_job_name) - - obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None) + extra_args = {"ServerSideEncryption": "aws:kms"} + obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args) @patch("sagemaker.utils") diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index be70182be8..c2470a5ba6 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -24,6 +24,7 @@ from sagemaker import fw_utils from sagemaker.utils import name_from_image +from sagemaker.session_settings import SessionSettings TIMESTAMP = "2017-10-10-14-14-15" @@ -93,6 +94,40 @@ def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session): obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args) +@patch("sagemaker.utils") +def 
test_tar_and_upload_dir_s3_kms_enabled_by_default(utils, sagemaker_session): + bucket = "mybucket" + s3_key_prefix = "something/source" + script = "inference.py" + result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script) + + assert result == fw_utils.UploadedCode( + "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script + ) + + extra_args = {"ServerSideEncryption": "aws:kms"} + obj = sagemaker_session.resource("s3").Object("", "") + obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args) + + +@patch("sagemaker.utils") +def test_tar_and_upload_dir_s3_without_kms_with_overridden_settings(utils, sagemaker_session): + bucket = "mybucket" + s3_key_prefix = "something/source" + script = "inference.py" + settings = SessionSettings(encrypt_repacked_artifacts=False) + result = fw_utils.tar_and_upload_dir( + sagemaker_session, bucket, s3_key_prefix, script, settings=settings + ) + + assert result == fw_utils.UploadedCode( + "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script + ) + + obj = sagemaker_session.resource("s3").Object("", "") + obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None) + + def test_mp_config_partition_exists(): mp_parameters = {} with pytest.raises(ValueError): @@ -658,6 +693,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.9", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.10", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled_custom_mpi), diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 
5c0b217299..4b8ce1de20 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -27,6 +27,7 @@ from mock import call, patch, Mock, MagicMock import sagemaker +from sagemaker.session_settings import SessionSettings BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission" @@ -390,6 +391,13 @@ def test_repack_model_without_source_dir(tmp, fake_s3): "/code/inference.py", } + extra_args = {"ServerSideEncryption": "aws:kms"} + object_mock = fake_s3.object_mock + _, _, kwargs = object_mock.mock_calls[0] + + assert "ExtraArgs" in kwargs + assert kwargs["ExtraArgs"] == extra_args + def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake_s3): @@ -415,12 +423,20 @@ def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake "s3://fake/location", "s3://destination-bucket/model.tar.gz", fake_s3.sagemaker_session, + kms_key="kms_key", ) finally: os.chdir(cwd) assert list_tar_files(fake_s3.fake_upload_path, tmp) == {"/code/inference.py", "/model"} + extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "kms_key"} + object_mock = fake_s3.object_mock + _, _, kwargs = object_mock.mock_calls[0] + + assert "ExtraArgs" in kwargs + assert kwargs["ExtraArgs"] == extra_args + def test_repack_model_from_s3_to_s3(tmp, fake_s3): @@ -434,6 +450,7 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3): ) fake_s3.tar_and_upload("model-dir", "s3://fake/location") + fake_s3.sagemaker_session.settings = SessionSettings(encrypt_repacked_artifacts=False) sagemaker.utils.repack_model( "inference.py", @@ -450,6 +467,11 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3): "/model", } + object_mock = fake_s3.object_mock + _, _, kwargs = object_mock.mock_calls[0] + assert "ExtraArgs" in kwargs + assert kwargs["ExtraArgs"] is None + def test_repack_model_from_file_to_file(tmp): create_file_tree(tmp, ["model", "dependencies/a", "source-dir/inference.py"]) @@ -581,6 +603,7 @@ def __init__(self, tmp): 
self.sagemaker_session = MagicMock() self.location_map = {} self.current_bucket = None + self.object_mock = MagicMock() self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = ( self.download_file @@ -606,6 +629,7 @@ def tar_and_upload(self, path, fake_location): def mock_s3_upload(self): dst = os.path.join(self.tmp, "dst") + object_mock = self.object_mock class MockS3Object(object): def __init__(self, bucket, key): @@ -616,6 +640,7 @@ def upload_file(self, target, **kwargs): if self.bucket in BUCKET_WITHOUT_WRITING_PERMISSION: raise exceptions.S3UploadFailedError() shutil.copy2(target, dst) + object_mock.upload_file(target, **kwargs) self.sagemaker_session.boto_session.resource().Object = MockS3Object return dst diff --git a/tox.ini b/tox.ini index b8dc0292f9..d9e3b41b41 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ exclude = .tox tests/data/ venv/ + env/ max-complexity = 10