From d09d2625e1fc9d2885f6967c847447dc99012dc3 Mon Sep 17 00:00:00 2001 From: chenxy Date: Tue, 21 Sep 2021 02:46:44 -0400 Subject: [PATCH 1/5] feature: Add EMRStep support in Sagemaker pipeline --- src/sagemaker/workflow/emr_step.py | 119 ++++++++++++ src/sagemaker/workflow/properties.py | 59 ++++-- src/sagemaker/workflow/steps.py | 1 + tests/data/workflow/emr-script.sh | 2 + tests/integ/test_workflow.py | 71 +++++++ .../unit/sagemaker/workflow/test_emr_step.py | 175 ++++++++++++++++++ .../sagemaker/workflow/test_properties.py | 13 ++ 7 files changed, 420 insertions(+), 20 deletions(-) create mode 100644 src/sagemaker/workflow/emr_step.py create mode 100644 tests/data/workflow/emr-script.sh create mode 100644 tests/unit/sagemaker/workflow/test_emr_step.py diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py new file mode 100644 index 0000000000..edae377d5e --- /dev/null +++ b/src/sagemaker/workflow/emr_step.py @@ -0,0 +1,119 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""The step definitions for workflow.""" +from __future__ import absolute_import + +from typing import List + +from sagemaker.workflow.entities import ( + RequestType, +) +from sagemaker.workflow.properties import ( + Properties, +) +from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig + + +class EMRStepConfig: + """Config for a Hadoop Jar step.""" + + def __init__( + self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None + ): + """Create a definition for input data used by an EMR cluster(job flow) step. + + See AWS documentation on the ``StepConfig`` API for more details on the parameters. + + Args: + args(List[str]): + A list of command line arguments passed to + the JAR file's main function when executed. + jar(str): A path to a JAR file run during the step. + main_class(str): The name of the main class in the specified Java file. + properties(List(dict)): A list of key-value pairs that are set when the step runs. + """ + self.jar = jar + self.args = args + self.main_class = main_class + self.properties = properties + + def to_request(self) -> RequestType: + """Convert EMRStepConfig object to request dict.""" + config = {"HadoopJarStep": {"Jar": self.jar}} + if self.args is not None: + config["HadoopJarStep"]["Args"] = self.args + if self.main_class is not None: + config["HadoopJarStep"]["MainClass"] = self.main_class + if self.properties is not None: + config["HadoopJarStep"]["Properties"] = self.properties + + return config + + +class EMRStep(Step): + """EMR step for workflow.""" + + def __init__( + self, + name: str, + display_name: str, + description: str, + cluster_id: str, + step_config: EMRStepConfig, + depends_on: List[str] = None, + cache_config: CacheConfig = None, + ): + """Constructs a LambdaStep. + + Args: + name(str): The name of the EMR step. + display_name(str): The display name of the EMR step. + description(str): The description of the EMR step. 
+ cluster_id(str): A string that uniquely identifies the cluster. + step_config(EMRStepConfig): One StepConfig to be executed by the job flow. + depends_on(List[str]): + A list of step names this `sagemaker.workflow.steps.EMRStep` depends on + cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. + + """ + super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on) + + emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()} + self.args = emr_step_args + self.cache_config = cache_config + + root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr") + root_property.__dict__["ClusterId"] = cluster_id + self._properties = root_property + + @property + def arguments(self) -> RequestType: + """The arguments dict that is used to call `AddJobFlowSteps`. + + NOTE: The AddJobFlowSteps request is not quite the args list that workflow needs. + The Name attribute in AddJobFlowSteps cannot be passed; it will be set during runtime. + In addition to that, we will also need to include emr job inputs and output config. 
+ """ + return self.args + + @property + def properties(self) -> RequestType: + """A Properties object representing the EMR DescribeStepResponse model""" + return self._properties + + def to_request(self) -> RequestType: + """Updates the dictionary with cache configuration.""" + request_dict = super().to_request() + if self.cache_config: + request_dict.update(self.cache_config.config) + return request_dict diff --git a/src/sagemaker/workflow/properties.py b/src/sagemaker/workflow/properties.py index 96147e8e8b..6e9aba4408 100644 --- a/src/sagemaker/workflow/properties.py +++ b/src/sagemaker/workflow/properties.py @@ -23,17 +23,24 @@ class PropertiesMeta(type): - """Load an internal shapes attribute from the botocore sagemaker service model.""" + """Load an internal shapes attribute from the botocore service model - _shapes = None + for sagemaker and emr service. + """ + + _shapes_map = dict() _primitive_types = {"string", "boolean", "integer", "float"} def __new__(mcs, *args, **kwargs): - """Loads up the shapes from the botocore sagemaker service model.""" - if mcs._shapes is None: + """Loads up the shapes from the botocore service model.""" + if len(mcs._shapes_map.keys()) == 0: loader = botocore.loaders.Loader() - model = loader.load_service_model("sagemaker", "service-2") - mcs._shapes = model["shapes"] + + sagemaker_model = loader.load_service_model("sagemaker", "service-2") + emr_model = loader.load_service_model("emr", "service-2") + mcs._shapes_map["sagemaker"] = sagemaker_model["shapes"] + mcs._shapes_map["emr"] = emr_model["shapes"] + return super().__new__(mcs, *args, **kwargs) @@ -45,32 +52,41 @@ def __init__( path: str, shape_name: str = None, shape_names: List[str] = None, + service_name: str = "sagemaker", ): """Create a Properties instance representing the given shape. Args: path (str): The parent path of the Properties instance. - shape_name (str): The botocore sagemaker service model shape name. 
- shape_names (str): A List of the botocore sagemaker service model shape name. + shape_name (str): The botocore service model shape name. + shape_names (str): A List of the botocore service model shape name. """ self._path = path shape_names = [] if shape_names is None else shape_names self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names + shapes = Properties._shapes_map.get(service_name, {}) + for name in self._shape_names: - shape = Properties._shapes.get(name, {}) + shape = shapes.get(name, {}) shape_type = shape.get("type") if shape_type in Properties._primitive_types: self.__str__ = name elif shape_type == "structure": members = shape["members"] for key, info in members.items(): - if Properties._shapes.get(info["shape"], {}).get("type") == "list": - self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"]) - elif Properties._shapes.get(info["shape"], {}).get("type") == "map": - self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"]) + if shapes.get(info["shape"], {}).get("type") == "list": + self.__dict__[key] = PropertiesList( + f"{path}.{key}", info["shape"], service_name + ) + elif shapes.get(info["shape"], {}).get("type") == "map": + self.__dict__[key] = PropertiesMap( + f"{path}.{key}", info["shape"], service_name + ) else: - self.__dict__[key] = Properties(f"{path}.{key}", info["shape"]) + self.__dict__[key] = Properties( + f"{path}.{key}", info["shape"], service_name=service_name + ) @property def expr(self): @@ -81,16 +97,17 @@ def expr(self): class PropertiesList(Properties): """PropertiesList for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None): + def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): """Create a PropertiesList instance representing the given shape. Args: path (str): The parent path of the PropertiesList instance. - shape_name (str): The botocore sagemaker service model shape name. 
- root_shape_name (str): The botocore sagemaker service model shape name. + shape_name (str): The botocore service model shape name. + service_name (str): The botocore service name. """ super(PropertiesList, self).__init__(path, shape_name) self.shape_name = shape_name + self.service_name = service_name self._items: Dict[Union[int, str], Properties] = dict() def __getitem__(self, item: Union[int, str]): @@ -100,7 +117,7 @@ def __getitem__(self, item: Union[int, str]): item (Union[int, str]): The index of the item in sequence. """ if item not in self._items.keys(): - shape = Properties._shapes.get(self.shape_name) + shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name) member = shape["member"]["shape"] if isinstance(item, str): property_item = Properties(f"{self._path}['{item}']", member) @@ -114,15 +131,17 @@ def __getitem__(self, item: Union[int, str]): class PropertiesMap(Properties): """PropertiesMap for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None): + def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): """Create a PropertiesMap instance representing the given shape. Args: path (str): The parent path of the PropertiesMap instance. shape_name (str): The botocore sagemaker service model shape name. + service_name (str): The botocore service name. """ super(PropertiesMap, self).__init__(path, shape_name) self.shape_name = shape_name + self.service_name = service_name self._items: Dict[Union[int, str], Properties] = dict() def __getitem__(self, item: Union[int, str]): @@ -132,7 +151,7 @@ def __getitem__(self, item: Union[int, str]): item (Union[int, str]): The index of the item in sequence. 
""" if item not in self._items.keys(): - shape = Properties._shapes.get(self.shape_name) + shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name) member = shape["value"]["shape"] if isinstance(item, str): property_item = Properties(f"{self._path}['{item}']", member) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 30eca68f66..329bd1d950 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -60,6 +60,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): LAMBDA = "Lambda" QUALITY_CHECK = "QualityCheck" CLARIFY_CHECK = "ClarifyCheck" + EMR = "EMR" @attr.s diff --git a/tests/data/workflow/emr-script.sh b/tests/data/workflow/emr-script.sh new file mode 100644 index 0000000000..aeee24ec95 --- /dev/null +++ b/tests/data/workflow/emr-script.sh @@ -0,0 +1,2 @@ +echo "This is emr test script..." +sleep 15 diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index de03608b27..f4699f3229 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -69,6 +69,7 @@ from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum +from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig from sagemaker.wrangler.processing import DataWranglerProcessor from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition from sagemaker.workflow.execution_variables import ExecutionVariables @@ -95,6 +96,7 @@ from tests.integ import DATA_DIR from tests.integ.kms_utils import get_or_create_kms_key from tests.integ.retry import retries +from tests.integ.vpc_test_utils import get_or_create_vpc_resources def ordered(obj): @@ -1148,6 +1150,75 @@ def test_two_step_lambda_pipeline_with_output_reference( pass +def test_two_steps_emr_pipeline( + 
sagemaker_session, role, pipeline_name, region_name, emr_cluster_id, emr_script_path +): + instance_count = ParameterInteger(name="InstanceCount", default_value=2) + + emr_step_config = EMRStepConfig( + jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", + args=[emr_script_path], + ) + + step_emr_1 = EMRStep( + name="emr-step-1", + cluster_id=emr_cluster_id, + display_name="emr_step_1", + description="MyEMRStepDescription", + step_config=emr_step_config, + ) + + step_emr_2 = EMRStep( + name="emr-step-2", + cluster_id=step_emr_1.properties.ClusterId, + display_name="emr_step_2", + description="MyEMRStepDescription", + step_config=emr_step_config, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_count], + steps=[step_emr_1, step_emr_2], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn + ) + + execution = pipeline.start() + try: + execution.wait(delay=60, max_attempts=5) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + assert execution_steps[0]["StepName"] == "emr-step-1" + assert execution_steps[0].get("FailureReason", "") == "" + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[1]["StepName"] == "emr-step-2" + assert execution_steps[1].get("FailureReason", "") == "" + assert execution_steps[1]["StepStatus"] == "Succeeded" + + pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] + response = pipeline.update(role) + update_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + update_arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_conditional_pytorch_training_model_registration( sagemaker_session, role, diff --git 
a/tests/unit/sagemaker/workflow/test_emr_step.py b/tests/unit/sagemaker/workflow/test_emr_step.py new file mode 100644 index 0000000000..e0dd81ebb5 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_emr_step.py @@ -0,0 +1,175 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json + +import pytest + +from mock import Mock + +from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig +from sagemaker.workflow.steps import CacheConfig +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.parameters import ParameterString + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name="us-west-2") + session_mock = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name="us-west-2", + config=None, + local_mode=False, + ) + return session_mock + + +def test_emr_step_with_one_step_config(sagemaker_session): + emr_step_config = EMRStepConfig( + jar="s3:/script-runner/script-runner.jar", + args=["--arg_0", "arg_0_value"], + main_class="com.my.main", + properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}], + ) + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=emr_step_config, + depends_on=["TestStep"], + cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"), + ) + 
emr_step.add_depends_on(["SecondTestStep"]) + assert emr_step.to_request() == { + "Name": "MyEMRStep", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Args": ["--arg_0", "arg_0_value"], + "Jar": "s3:/script-runner/script-runner.jar", + "MainClass": "com.my.main", + "Properties": [ + {"Key": "Foo", "Value": "Foo_value"}, + {"Key": "Bar", "Value": "Bar_value"}, + ], + } + }, + }, + "DependsOn": ["TestStep", "SecondTestStep"], + "DisplayName": "MyEMRStep", + "Description": "MyEMRStepDescription", + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + } + + assert emr_step.properties.ClusterId == "MyClusterID" + assert emr_step.properties.ActionOnFailure.expr == {"Get": "Steps.MyEMRStep.ActionOnFailure"} + assert emr_step.properties.Config.Args.expr == {"Get": "Steps.MyEMRStep.Config.Args"} + assert emr_step.properties.Config.Jar.expr == {"Get": "Steps.MyEMRStep.Config.Jar"} + assert emr_step.properties.Config.MainClass.expr == {"Get": "Steps.MyEMRStep.Config.MainClass"} + assert emr_step.properties.Id.expr == {"Get": "Steps.MyEMRStep.Id"} + assert emr_step.properties.Name.expr == {"Get": "Steps.MyEMRStep.Name"} + assert emr_step.properties.Status.State.expr == {"Get": "Steps.MyEMRStep.Status.State"} + assert emr_step.properties.Status.FailureDetails.Reason.expr == { + "Get": "Steps.MyEMRStep.Status.FailureDetails.Reason" + } + + +def test_pipeline_interpolates_emr_outputs(sagemaker_session): + parameter = ParameterString("MyStr") + + emr_step_config_1 = EMRStepConfig( + jar="s3:/script-runner/script-runner_1.jar", + args=["--arg_0", "arg_0_value"], + main_class="com.my.main", + properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}], + ) + + step_emr_1 = EMRStep( + name="emr_step_1", + cluster_id="MyClusterID", + display_name="emr_step_1", + description="MyEMRStepDescription", + depends_on=["TestStep"], + step_config=emr_step_config_1, + ) + + emr_step_config_2 = 
EMRStepConfig(jar="s3:/script-runner/script-runner_2.jar") + + step_emr_2 = EMRStep( + name="emr_step_2", + cluster_id="MyClusterID", + display_name="emr_step_2", + description="MyEMRStepDescription", + depends_on=["TestStep"], + step_config=emr_step_config_2, + ) + + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step_emr_1, step_emr_2], + sagemaker_session=sagemaker_session, + ) + + assert json.loads(pipeline.definition()) == { + "Version": "2020-12-01", + "Metadata": {}, + "Parameters": [{"Name": "MyStr", "Type": "String"}], + "PipelineExperimentConfig": { + "ExperimentName": {"Get": "Execution.PipelineName"}, + "TrialName": {"Get": "Execution.PipelineExecutionId"}, + }, + "Steps": [ + { + "Name": "emr_step_1", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Args": ["--arg_0", "arg_0_value"], + "Jar": "s3:/script-runner/script-runner_1.jar", + "MainClass": "com.my.main", + "Properties": [ + {"Key": "Foo", "Value": "Foo_value"}, + {"Key": "Bar", "Value": "Bar_value"}, + ], + } + }, + }, + "DependsOn": ["TestStep"], + "Description": "MyEMRStepDescription", + "DisplayName": "emr_step_1", + }, + { + "Name": "emr_step_2", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": {"Jar": "s3:/script-runner/script-runner_2.jar"} + }, + }, + "Description": "MyEMRStepDescription", + "DisplayName": "emr_step_2", + "DependsOn": ["TestStep"], + }, + ], + } diff --git a/tests/unit/sagemaker/workflow/test_properties.py b/tests/unit/sagemaker/workflow/test_properties.py index accaf46533..405de5c0b2 100644 --- a/tests/unit/sagemaker/workflow/test_properties.py +++ b/tests/unit/sagemaker/workflow/test_properties.py @@ -70,6 +70,19 @@ def test_properties_tuning_job(): } +def test_properties_emr_step(): + prop = Properties("Steps.MyStep", "Step", service_name="emr") + some_prop_names = ["Id", "Name", "Config", "ActionOnFailure", "Status"] + for name in 
some_prop_names: + assert name in prop.__dict__.keys() + + assert prop.Id.expr == {"Get": "Steps.MyStep.Id"} + assert prop.Name.expr == {"Get": "Steps.MyStep.Name"} + assert prop.ActionOnFailure.expr == {"Get": "Steps.MyStep.ActionOnFailure"} + assert prop.Config.Jar.expr == {"Get": "Steps.MyStep.Config.Jar"} + assert prop.Status.State.expr == {"Get": "Steps.MyStep.Status.State"} + + def test_properties_describe_model_package_output(): prop = Properties("Steps.MyStep", "DescribeModelPackageOutput") some_prop_names = ["ModelPackageName", "ModelPackageGroupName", "ModelPackageArn"] From 59e4be8690a4121687870b769ebbea4a8e01f630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=8F=85Ethan=20Cheng=F0=9F=98=8E?= Date: Wed, 12 Jan 2022 12:14:19 -0800 Subject: [PATCH 2/5] pr review feedback changes --- src/sagemaker/workflow/emr_step.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py index edae377d5e..8b244c78f2 100644 --- a/src/sagemaker/workflow/emr_step.py +++ b/src/sagemaker/workflow/emr_step.py @@ -73,13 +73,13 @@ def __init__( depends_on: List[str] = None, cache_config: CacheConfig = None, ): - """Constructs a LambdaStep. + """Constructs an EMRStep. Args: name(str): The name of the EMR step. display_name(str): The display name of the EMR step. description(str): The description of the EMR step. - cluster_id(str): A string that uniquely identifies the cluster. + cluster_id(str): The ID of the running EMR cluster. step_config(EMRStepConfig): One StepConfig to be executed by the job flow. 
depends_on(List[str]): A list of step names this `sagemaker.workflow.steps.EMRStep` depends on From 25f9121a6fb0a941167291ebd06018a1c3dbf9d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=8F=85Ethan=20Cheng=F0=9F=98=8E?= Date: Thu, 13 Jan 2022 17:14:55 -0800 Subject: [PATCH 3/5] remove actual emr step execution since we are already testing in canary --- tests/integ/test_workflow.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index f4699f3229..3510af7649 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -1151,18 +1151,18 @@ def test_two_step_lambda_pipeline_with_output_reference( def test_two_steps_emr_pipeline( - sagemaker_session, role, pipeline_name, region_name, emr_cluster_id, emr_script_path + sagemaker_session, role, pipeline_name, region_name ): instance_count = ParameterInteger(name="InstanceCount", default_value=2) emr_step_config = EMRStepConfig( jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", - args=[emr_script_path], + args=["dummy_emr_script_path"], ) step_emr_1 = EMRStep( name="emr-step-1", - cluster_id=emr_cluster_id, + cluster_id="j-1YONHTCP3YZKC", display_name="emr_step_1", description="MyEMRStepDescription", step_config=emr_step_config, @@ -1189,29 +1189,6 @@ def test_two_steps_emr_pipeline( assert re.match( fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn ) - - execution = pipeline.start() - try: - execution.wait(delay=60, max_attempts=5) - except WaiterError: - pass - - execution_steps = execution.list_steps() - assert len(execution_steps) == 2 - assert execution_steps[0]["StepName"] == "emr-step-1" - assert execution_steps[0].get("FailureReason", "") == "" - assert execution_steps[0]["StepStatus"] == "Succeeded" - assert execution_steps[1]["StepName"] == "emr-step-2" - assert execution_steps[1].get("FailureReason", "") == "" - assert 
execution_steps[1]["StepStatus"] == "Succeeded" - - pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] - response = pipeline.update(role) - update_arn = response["PipelineArn"] - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", - update_arn, - ) finally: try: pipeline.delete() From 0921686dbbaab97594e696d4cbf14f11f25c549b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=8F=85Ethan=20Cheng=F0=9F=98=8E?= Date: Thu, 13 Jan 2022 17:37:39 -0800 Subject: [PATCH 4/5] remove unused import --- tests/integ/test_workflow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 3510af7649..eee7c82f30 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -96,7 +96,6 @@ from tests.integ import DATA_DIR from tests.integ.kms_utils import get_or_create_kms_key from tests.integ.retry import retries -from tests.integ.vpc_test_utils import get_or_create_vpc_resources def ordered(obj): From 440d99ea943cac6aaf73a875a59133e3df9ee3b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=8F=85Ethan=20Cheng=F0=9F=98=8E?= Date: Fri, 14 Jan 2022 11:44:17 -0800 Subject: [PATCH 5/5] black-check failure fix --- tests/integ/test_workflow.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index eee7c82f30..4a3354470a 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -1149,9 +1149,7 @@ def test_two_step_lambda_pipeline_with_output_reference( pass -def test_two_steps_emr_pipeline( - sagemaker_session, role, pipeline_name, region_name -): +def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_name): instance_count = ParameterInteger(name="InstanceCount", default_value=2) emr_step_config = EMRStepConfig(