2222import botocore
2323from botocore .exceptions import ClientError
2424
25+ from sagemaker import s3
2526from sagemaker ._studio import _append_project_tags
2627from sagemaker .session import Session
2728from sagemaker .workflow .callback_step import CallbackOutput , CallbackStep
3435from sagemaker .workflow .execution_variables import ExecutionVariables
3536from sagemaker .workflow .parameters import Parameter
3637from sagemaker .workflow .pipeline_experiment_config import PipelineExperimentConfig
38+ from sagemaker .workflow .parallelism_config import ParallelismConfiguration
3739from sagemaker .workflow .properties import Properties
3840from sagemaker .workflow .steps import Step
3941from sagemaker .workflow .step_collections import StepCollection
@@ -94,6 +96,7 @@ def create(
9496 role_arn : str ,
9597 description : str = None ,
9698 tags : List [Dict [str , str ]] = None ,
99+ parallelism_config : ParallelismConfiguration = None ,
97100 ) -> Dict [str , Any ]:
98101 """Creates a Pipeline in the Pipelines service.
99102
@@ -102,37 +105,62 @@ def create(
102105 description (str): A description of the pipeline.
103106 tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
104107 tags.
108+ parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
109+ that is applied to each of the executions of the pipeline. It takes precedence
110+ over the parallelism configuration of the parent pipeline.
105111
106112 Returns:
107113 A response dict from the service.
108114 """
109115 tags = _append_project_tags (tags )
110-
111- kwargs = self ._create_args (role_arn , description )
116+ kwargs = self ._create_args (role_arn , description , parallelism_config )
112117 update_args (
113118 kwargs ,
114119 Tags = tags ,
115120 )
116121 return self .sagemaker_session .sagemaker_client .create_pipeline (** kwargs )
117122
118- def _create_args (self , role_arn : str , description : str ):
123+ def _create_args (
124+ self , role_arn : str , description : str , parallelism_config : ParallelismConfiguration
125+ ):
119126 """Constructs the keyword argument dict for a create_pipeline call.
120127
121128 Args:
122129 role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
123130 description (str): A description of the pipeline.
131+ parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
132+ that is applied to each of the executions of the pipeline. It takes precedence
133+ over the parallelism configuration of the parent pipeline.
124134
125135 Returns:
126136 A keyword argument dict for calling create_pipeline.
127137 """
138+ pipeline_definition = self .definition ()
128139 kwargs = dict (
129140 PipelineName = self .name ,
130- PipelineDefinition = self .definition (),
131141 RoleArn = role_arn ,
132142 )
143+
144+ # If pipeline definition is large, upload to S3 bucket and
145+ # provide PipelineDefinitionS3Location to request instead.
146+ if len (pipeline_definition .encode ("utf-8" )) < 1024 * 100 :
147+ kwargs ["PipelineDefinition" ] = pipeline_definition
148+ else :
149+ desired_s3_uri = s3 .s3_path_join (
150+ "s3://" , self .sagemaker_session .default_bucket (), self .name
151+ )
152+ s3 .S3Uploader .upload_string_as_file_body (
153+ body = pipeline_definition ,
154+ desired_s3_uri = desired_s3_uri ,
155+ sagemaker_session = self .sagemaker_session ,
156+ )
157+ kwargs ["PipelineDefinitionS3Location" ] = {
158+ "Bucket" : self .sagemaker_session .default_bucket (),
159+ "ObjectKey" : self .name ,
160+ }
161+
133162 update_args (
134- kwargs ,
135- PipelineDescription = description ,
163+ kwargs , PipelineDescription = description , ParallelismConfiguration = parallelism_config
136164 )
137165 return kwargs
138166
@@ -146,24 +174,33 @@ def describe(self) -> Dict[str, Any]:
146174 """
147175 return self .sagemaker_session .sagemaker_client .describe_pipeline (PipelineName = self .name )
148176
149- def update (self , role_arn : str , description : str = None ) -> Dict [str , Any ]:
177+ def update (
178+ self ,
179+ role_arn : str ,
180+ description : str = None ,
181+ parallelism_config : ParallelismConfiguration = None ,
182+ ) -> Dict [str , Any ]:
150183 """Updates a Pipeline in the Workflow service.
151184
152185 Args:
153186 role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
154187 description (str): A description of the pipeline.
188+ parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
189+ that is applied to each of the executions of the pipeline. It takes precedence
190+ over the parallelism configuration of the parent pipeline.
155191
156192 Returns:
157193 A response dict from the service.
158194 """
159- kwargs = self ._create_args (role_arn , description )
195+ kwargs = self ._create_args (role_arn , description , parallelism_config )
160196 return self .sagemaker_session .sagemaker_client .update_pipeline (** kwargs )
161197
162198 def upsert (
163199 self ,
164200 role_arn : str ,
165201 description : str = None ,
166202 tags : List [Dict [str , str ]] = None ,
203+ parallelism_config : ParallelismConfiguration = None ,
167204 ) -> Dict [str , Any ]:
168205 """Creates a pipeline or updates it, if it already exists.
169206
@@ -172,12 +209,14 @@ def upsert(
172209 description (str): A description of the pipeline.
173210 tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
174211 tags.
212+ parallelism_config (Optional[Config for parallel steps, Parallelism configuration that
213+ is applied to each of. the executions
175214
176215 Returns:
177216 response dict from service
178217 """
179218 try :
180- response = self .create (role_arn , description , tags )
219+ response = self .create (role_arn , description , tags , parallelism_config )
181220 except ClientError as e :
182221 error = e .response ["Error" ]
183222 if (
@@ -215,6 +254,7 @@ def start(
215254 parameters : Dict [str , Union [str , bool , int , float ]] = None ,
216255 execution_display_name : str = None ,
217256 execution_description : str = None ,
257+ parallelism_config : ParallelismConfiguration = None ,
218258 ):
219259 """Starts a Pipeline execution in the Workflow service.
220260
@@ -223,6 +263,9 @@ def start(
223263 pipeline parameters.
224264 execution_display_name (str): The display name of the pipeline execution.
225265 execution_description (str): A description of the execution.
266+ parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
267+ that is applied to each of the executions of the pipeline. It takes precedence
268+ over the parallelism configuration of the parent pipeline.
226269
227270 Returns:
228271 A `_PipelineExecution` instance, if successful.
@@ -245,6 +288,7 @@ def start(
245288 PipelineParameters = format_start_parameters (parameters ),
246289 PipelineExecutionDescription = execution_description ,
247290 PipelineExecutionDisplayName = execution_display_name ,
291+ ParallelismConfiguration = parallelism_config ,
248292 )
249293 response = self .sagemaker_session .sagemaker_client .start_pipeline_execution (** kwargs )
250294 return _PipelineExecution (
0 commit comments