2121
2222from mock import Mock
2323
24+ from sagemaker import s3
2425from sagemaker .workflow .execution_variables import ExecutionVariables
2526from sagemaker .workflow .parameters import ParameterString
2627from sagemaker .workflow .pipeline import Pipeline
28+ from sagemaker .workflow .parallelism_config import ParallelismConfiguration
2729from sagemaker .workflow .pipeline_experiment_config import (
2830 PipelineExperimentConfig ,
2931 PipelineExperimentConfigProperties ,
@@ -62,7 +64,9 @@ def role_arn():
6264
6365@pytest .fixture
6466def sagemaker_session_mock ():
65- return Mock ()
67+ session_mock = Mock ()
68+ session_mock .default_bucket = Mock (name = "default_bucket" , return_value = "s3_bucket" )
69+ return session_mock
6670
6771
6872def test_pipeline_create (sagemaker_session_mock , role_arn ):
@@ -78,6 +82,47 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
7882 )
7983
8084
85+ def test_pipeline_create_with_parallelism_config (sagemaker_session_mock , role_arn ):
86+ pipeline = Pipeline (
87+ name = "MyPipeline" ,
88+ parameters = [],
89+ steps = [],
90+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
91+ sagemaker_session = sagemaker_session_mock ,
92+ )
93+ pipeline .create (role_arn = role_arn )
94+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
95+ PipelineName = "MyPipeline" ,
96+ PipelineDefinition = pipeline .definition (),
97+ RoleArn = role_arn ,
98+ ParallelismConfiguration = {"MaxParallelExecutionSteps" : 10 },
99+ )
100+
101+
102+ def test_large_pipeline_create (sagemaker_session_mock , role_arn ):
103+ parameter = ParameterString ("MyStr" )
104+ pipeline = Pipeline (
105+ name = "MyPipeline" ,
106+ parameters = [parameter ],
107+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
108+ sagemaker_session = sagemaker_session_mock ,
109+ )
110+
111+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
112+
113+ pipeline .create (role_arn = role_arn )
114+
115+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
116+ body = pipeline .definition (), s3_uri = "s3://s3_bucket/MyPipeline"
117+ )
118+
119+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
120+ PipelineName = "MyPipeline" ,
121+ PipelineDefinitionS3Location = {"Bucket" : "s3_bucket" , "ObjectKey" : "MyPipeline" },
122+ RoleArn = role_arn ,
123+ )
124+
125+
81126def test_pipeline_update (sagemaker_session_mock , role_arn ):
82127 pipeline = Pipeline (
83128 name = "MyPipeline" ,
@@ -91,6 +136,47 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
91136 )
92137
93138
139+ def test_pipeline_update_with_parallelism_config (sagemaker_session_mock , role_arn ):
140+ pipeline = Pipeline (
141+ name = "MyPipeline" ,
142+ parameters = [],
143+ steps = [],
144+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
145+ sagemaker_session = sagemaker_session_mock ,
146+ )
147+ pipeline .create (role_arn = role_arn )
148+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
149+ PipelineName = "MyPipeline" ,
150+ PipelineDefinition = pipeline .definition (),
151+ RoleArn = role_arn ,
152+ ParallelismConfiguration = {"MaxParallelExecutionSteps" : 10 },
153+ )
154+
155+
156+ def test_large_pipeline_update (sagemaker_session_mock , role_arn ):
157+ parameter = ParameterString ("MyStr" )
158+ pipeline = Pipeline (
159+ name = "MyPipeline" ,
160+ parameters = [parameter ],
161+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
162+ sagemaker_session = sagemaker_session_mock ,
163+ )
164+
165+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
166+
167+ pipeline .create (role_arn = role_arn )
168+
169+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
170+ body = pipeline .definition (), s3_uri = "s3://s3_bucket/MyPipeline"
171+ )
172+
173+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
174+ PipelineName = "MyPipeline" ,
175+ PipelineDefinitionS3Location = {"Bucket" : "s3_bucket" , "ObjectKey" : "MyPipeline" },
176+ RoleArn = role_arn ,
177+ )
178+
179+
94180def test_pipeline_upsert (sagemaker_session_mock , role_arn ):
95181 sagemaker_session_mock .side_effect = [
96182 ClientError (
0 commit comments