1919
2020from sagemaker .s3 import s3_path_join
2121from sagemaker .remote_function .core .serialization import deserialize_obj_from_s3
22+ from sagemaker .workflow .step_outputs import get_step
2223
2324
2425@dataclass
@@ -92,7 +93,7 @@ class _S3BaseUriIdentifier:
9293class _DelayedReturn :
9394 """Delayed return from a function."""
9495
95- uri : List [Union [str , _Parameter , _ExecutionVariable ]]
96+ uri : Union [ _Properties , List [Union [str , _Parameter , _ExecutionVariable ] ]]
9697 reference_path : Tuple = field (default_factory = tuple )
9798
9899
@@ -164,6 +165,7 @@ def __init__(
164165 self ,
165166 delayed_returns : List [_DelayedReturn ],
166167 hmac_key : str ,
168+ properties_resolver : _PropertiesResolver ,
167169 parameter_resolver : _ParameterResolver ,
168170 execution_variable_resolver : _ExecutionVariableResolver ,
169171 s3_base_uri : str ,
@@ -174,6 +176,7 @@ def __init__(
174176 Args:
175177 delayed_returns: list of delayed returns to resolve.
176178 hmac_key: key used to encrypt serialized and deserialized function and arguments.
179+ properties_resolver: resolver used to resolve step properties.
177180 parameter_resolver: resolver used to pipeline parameters.
178181 execution_variable_resolver: resolver used to resolve execution variables.
179182 s3_base_uri (str): the s3 base uri of the function step that
@@ -184,6 +187,7 @@ def __init__(
184187 self ._s3_base_uri = s3_base_uri
185188 self ._parameter_resolver = parameter_resolver
186189 self ._execution_variable_resolver = execution_variable_resolver
190+ self ._properties_resolver = properties_resolver
187191 # different delayed returns can have the same uri, so we need to dedupe
188192 uris = {
189193 self ._resolve_delayed_return_uri (delayed_return ) for delayed_return in delayed_returns
@@ -214,7 +218,10 @@ def resolve(self, delayed_return: _DelayedReturn) -> Any:
214218
215219 def _resolve_delayed_return_uri (self , delayed_return : _DelayedReturn ):
216220 """Resolve the s3 uri of the delayed return."""
221+ if isinstance (delayed_return .uri , _Properties ):
222+ return self ._properties_resolver .resolve (delayed_return .uri )
217223
224+ # Keep the following old resolution logics to keep backward compatible
218225 uri = []
219226 for component in delayed_return .uri :
220227 if isinstance (component , _Parameter ):
@@ -274,6 +281,7 @@ def resolve_pipeline_variables(
274281 delayed_return_resolver = _DelayedReturnResolver (
275282 delayed_returns = delayed_returns ,
276283 hmac_key = hmac_key ,
284+ properties_resolver = properties_resolver ,
277285 parameter_resolver = parameter_resolver ,
278286 execution_variable_resolver = execution_variable_resolver ,
279287 s3_base_uri = s3_base_uri ,
@@ -325,27 +333,12 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
325333
326334 from sagemaker .workflow .entities import PipelineVariable
327335
328- from sagemaker .workflow .execution_variables import ExecutionVariables
329-
330336 from sagemaker .workflow .function_step import DelayedReturn
331337
332- # Notes:
333- # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
334- # when defining function steps. After step-level arg serialization,
335- # it's hard to update the s3_base_uri in pipeline compile time.
336- # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
337- # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
338- # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
339- # should be retrieved from the pipeline's sagemaker_session.
340338 def convert (arg ):
341339 if isinstance (arg , DelayedReturn ):
342340 return _DelayedReturn (
343- uri = [
344- _S3BaseUriIdentifier (),
345- ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
346- arg ._step .name ,
347- "results" ,
348- ],
341+ uri = get_step (arg )._properties .OutputDataConfig .S3OutputPath ._pickleable ,
349342 reference_path = arg ._reference_path ,
350343 )
351344
0 commit comments