1515
1616from concurrent .futures import ThreadPoolExecutor
1717from dataclasses import dataclass , field
18- from typing import Any , Union , Dict , List , Tuple
18+ from typing import Any , Dict , List , Tuple
1919
20- from sagemaker .s3 import s3_path_join
2120from sagemaker .remote_function .core .serialization import deserialize_obj_from_s3
21+ from sagemaker .workflow .step_outputs import get_step
2222
2323
2424@dataclass
@@ -77,22 +77,11 @@ class _ExecutionVariable:
7777 name : str
7878
7979
80- @dataclass
81- class _S3BaseUriIdentifier :
82- """Identifies that the class refers to function step s3 base uri.
83-
84- The s3_base_uri = s3_root_uri + pipeline_name.
85- This identifier is resolved in function step runtime by SDK.
86- """
87-
88- NAME = "S3_BASE_URI"
89-
90-
9180@dataclass
9281class _DelayedReturn :
9382 """Delayed return from a function."""
9483
95- uri : List [ Union [ str , _Parameter , _ExecutionVariable ]]
84+ uri : _Properties
9685 reference_path : Tuple = field (default_factory = tuple )
9786
9887
@@ -164,26 +153,18 @@ def __init__(
164153 self ,
165154 delayed_returns : List [_DelayedReturn ],
166155 hmac_key : str ,
167- parameter_resolver : _ParameterResolver ,
168- execution_variable_resolver : _ExecutionVariableResolver ,
169- s3_base_uri : str ,
156+ properties_resolver : _PropertiesResolver ,
170157 ** settings ,
171158 ):
172159 """Resolve delayed return.
173160
174161 Args:
175162 delayed_returns: list of delayed returns to resolve.
176163 hmac_key: key used to compute the HMAC of the serialized function and arguments so their integrity can be verified on deserialization.
177- parameter_resolver: resolver used to pipeline parameters.
178- execution_variable_resolver: resolver used to resolve execution variables.
179- s3_base_uri (str): the s3 base uri of the function step that
180- the serialized artifacts will be uploaded to.
181- The s3_base_uri = s3_root_uri + pipeline_name.
164+ properties_resolver: resolver used to resolve step properties.
182165 **settings: settings to pass to the deserialization function.
183166 """
184- self ._s3_base_uri = s3_base_uri
185- self ._parameter_resolver = parameter_resolver
186- self ._execution_variable_resolver = execution_variable_resolver
167+ self ._properties_resolver = properties_resolver
187168 # different delayed returns can have the same uri, so we need to dedupe
188169 uris = {
189170 self ._resolve_delayed_return_uri (delayed_return ) for delayed_return in delayed_returns
@@ -214,18 +195,7 @@ def resolve(self, delayed_return: _DelayedReturn) -> Any:
214195
215196 def _resolve_delayed_return_uri (self , delayed_return : _DelayedReturn ):
216197 """Resolve the s3 uri of the delayed return."""
217-
218- uri = []
219- for component in delayed_return .uri :
220- if isinstance (component , _Parameter ):
221- uri .append (self ._parameter_resolver .resolve (component ))
222- elif isinstance (component , _ExecutionVariable ):
223- uri .append (self ._execution_variable_resolver .resolve (component ))
224- elif isinstance (component , _S3BaseUriIdentifier ):
225- uri .append (self ._s3_base_uri )
226- else :
227- uri .append (component )
228- return s3_path_join (* uri )
198+ return self ._properties_resolver .resolve (delayed_return .uri )
229199
230200
231201def _retrieve_child_item (delayed_return : _DelayedReturn , deserialized_obj : Any ):
@@ -241,7 +211,6 @@ def resolve_pipeline_variables(
241211 func_args : Tuple ,
242212 func_kwargs : Dict ,
243213 hmac_key : str ,
244- s3_base_uri : str ,
245214 ** settings ,
246215):
247216 """Resolve pipeline variables.
@@ -251,8 +220,6 @@ def resolve_pipeline_variables(
251220 func_args: function args.
252221 func_kwargs: function kwargs.
253222 hmac_key: key used to compute the HMAC of the serialized function and arguments so their integrity can be verified on deserialization.
254- s3_base_uri: the s3 base uri of the function step that the serialized artifacts
255- will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
256223 **settings: settings to pass to the deserialization function.
257224 """
258225
@@ -274,9 +241,7 @@ def resolve_pipeline_variables(
274241 delayed_return_resolver = _DelayedReturnResolver (
275242 delayed_returns = delayed_returns ,
276243 hmac_key = hmac_key ,
277- parameter_resolver = parameter_resolver ,
278- execution_variable_resolver = execution_variable_resolver ,
279- s3_base_uri = s3_base_uri ,
244+ properties_resolver = properties_resolver ,
280245 ** settings ,
281246 )
282247
@@ -322,39 +287,27 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
322287 func_args: function args.
323288 func_kwargs: function kwargs.
324289 """
290+ converted_func_args = tuple (_convert_pipeline_variable_to_pickleable (arg ) for arg in func_args )
291+ converted_func_kwargs = {
292+ key : _convert_pipeline_variable_to_pickleable (arg ) for key , arg in func_kwargs .items ()
293+ }
325294
326- from sagemaker .workflow .entities import PipelineVariable
327-
328- from sagemaker .workflow .execution_variables import ExecutionVariables
295+ return converted_func_args , converted_func_kwargs
329296
330- from sagemaker .workflow .function_step import DelayedReturn
331297
332- # Notes:
333- # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
334- # when defining function steps. After step-level arg serialization,
335- # it's hard to update the s3_base_uri in pipeline compile time.
336- # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
337- # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
338- # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
339- # should be retrieved from the pipeline's sagemaker_session.
340- def convert (arg ):
341- if isinstance (arg , DelayedReturn ):
342- return _DelayedReturn (
343- uri = [
344- _S3BaseUriIdentifier (),
345- ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
346- arg ._step .name ,
347- "results" ,
348- ],
349- reference_path = arg ._reference_path ,
350- )
298+ def _convert_pipeline_variable_to_pickleable (arg ):
299+ """Convert a pipeline variable to pickleable."""
300+ from sagemaker .workflow .entities import PipelineVariable
351301
352- if isinstance (arg , PipelineVariable ):
353- return arg ._pickleable
302+ from sagemaker .workflow .function_step import DelayedReturn
354303
355- return arg
304+ if isinstance (arg , DelayedReturn ):
305+ return _DelayedReturn (
306+ uri = get_step (arg )._properties .OutputDataConfig .S3OutputPath ._pickleable ,
307+ reference_path = arg ._reference_path ,
308+ )
356309
357- converted_func_args = tuple ( convert ( arg ) for arg in func_args )
358- converted_func_kwargs = { key : convert ( arg ) for key , arg in func_kwargs . items ()}
310+ if isinstance ( arg , PipelineVariable ):
311+ return arg . _pickleable
359312
360- return converted_func_args , converted_func_kwargs
313+ return arg
0 commit comments