2121
2222from mock import Mock
2323
24+ from sagemaker import s3
2425from sagemaker .workflow .execution_variables import ExecutionVariables
2526from sagemaker .workflow .parameters import ParameterString
2627from sagemaker .workflow .pipeline import Pipeline
28+ from sagemaker .workflow .parallelism_config import ParallelismConfiguration
2729from sagemaker .workflow .pipeline_experiment_config import (
2830 PipelineExperimentConfig ,
2931 PipelineExperimentConfigProperties ,
@@ -62,7 +64,9 @@ def role_arn():
6264
6365@pytest .fixture
6466def sagemaker_session_mock ():
65- return Mock ()
67+ session_mock = Mock ()
68+ session_mock .default_bucket = Mock (name = "default_bucket" , return_value = "s3_bucket" )
69+ return session_mock
6670
6771
6872def test_pipeline_create (sagemaker_session_mock , role_arn ):
@@ -78,6 +82,50 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
7882 )
7983
8084
85+ def test_pipeline_create_with_parallelism_config (sagemaker_session_mock , role_arn ):
86+ pipeline = Pipeline (
87+ name = "MyPipeline" ,
88+ parameters = [],
89+ steps = [],
90+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
91+ sagemaker_session = sagemaker_session_mock ,
92+ )
93+ pipeline .create (role_arn = role_arn )
94+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
95+ PipelineName = "MyPipeline" , PipelineDefinition = pipeline .definition (), RoleArn = role_arn ,
96+ ParallelismConfiguration = {
97+ "MaxParallelExecutionSteps" : 10
98+ }
99+ )
100+
101+
102+ def test_large_pipeline_create (sagemaker_session_mock , role_arn ):
103+ parameter = ParameterString ("MyStr" )
104+ pipeline = Pipeline (
105+ name = "MyPipeline" ,
106+ parameters = [parameter ],
107+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
108+ sagemaker_session = sagemaker_session_mock ,
109+ )
110+
111+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
112+
113+ pipeline .create (role_arn = role_arn )
114+
115+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
116+ body = pipeline .definition (),
117+ s3_uri = "s3://s3_bucket/MyPipeline" )
118+
119+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
120+ PipelineName = "MyPipeline" ,
121+ PipelineDefinitionS3Location = {
122+ "Bucket" : "s3_bucket" ,
123+ "ObjectKey" : "MyPipeline"
124+ },
125+ RoleArn = role_arn
126+ )
127+
128+
81129def test_pipeline_update (sagemaker_session_mock , role_arn ):
82130 pipeline = Pipeline (
83131 name = "MyPipeline" ,
@@ -91,6 +139,50 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
91139 )
92140
93141
142+ def test_pipeline_update_with_parallelism_config (sagemaker_session_mock , role_arn ):
143+ pipeline = Pipeline (
144+ name = "MyPipeline" ,
145+ parameters = [],
146+ steps = [],
147+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
148+ sagemaker_session = sagemaker_session_mock ,
149+ )
150+ pipeline .create (role_arn = role_arn )
151+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
152+ PipelineName = "MyPipeline" , PipelineDefinition = pipeline .definition (), RoleArn = role_arn ,
153+ ParallelismConfiguration = {
154+ "MaxParallelExecutionSteps" : 10
155+ }
156+ )
157+
158+
159+ def test_large_pipeline_update (sagemaker_session_mock , role_arn ):
160+ parameter = ParameterString ("MyStr" )
161+ pipeline = Pipeline (
162+ name = "MyPipeline" ,
163+ parameters = [parameter ],
164+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
165+ sagemaker_session = sagemaker_session_mock ,
166+ )
167+
168+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
169+
170+ pipeline .create (role_arn = role_arn )
171+
172+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
173+ body = pipeline .definition (),
174+ s3_uri = "s3://s3_bucket/MyPipeline" )
175+
176+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
177+ PipelineName = "MyPipeline" ,
178+ PipelineDefinitionS3Location = {
179+ "Bucket" : "s3_bucket" ,
180+ "ObjectKey" : "MyPipeline"
181+ },
182+ RoleArn = role_arn
183+ )
184+
185+
94186def test_pipeline_upsert (sagemaker_session_mock , role_arn ):
95187 sagemaker_session_mock .side_effect = [
96188 ClientError (
0 commit comments