@@ -30,6 +30,7 @@ class Context:
3030 property_references : Dict [str , str ] = field (default_factory = dict )
3131 serialize_output_to_json : bool = False
3232 func_step_s3_dir : str = None
33+ s3_base_uri : str = None
3334
3435
3536@dataclass
@@ -77,6 +78,17 @@ class _ExecutionVariable:
7778 name : str
7879
7980
81+ @dataclass
82+ class _S3BaseUriIdentifier :
83+ """Identifies that the class refers to function step s3 base uri.
84+
85+ The s3_base_uri = s3_root_uri + pipeline_name.
86+ This identifier is resolved in function step runtime by SDK.
87+ """
88+
89+ NAME = "S3_BASE_URI"
90+
91+
8092@dataclass
8193class _DelayedReturn :
8294 """Delayed return from a function."""
@@ -155,6 +167,7 @@ def __init__(
155167 hmac_key : str ,
156168 parameter_resolver : _ParameterResolver ,
157169 execution_variable_resolver : _ExecutionVariableResolver ,
170+ s3_base_uri : str ,
158171 ** settings ,
159172 ):
160173 """Resolve delayed return.
@@ -164,8 +177,11 @@ def __init__(
164177 hmac_key: key used to encrypt serialized and deserialized function and arguments.
165178 parameter_resolver: resolver used to pipeline parameters.
166179 execution_variable_resolver: resolver used to resolve execution variables.
180+ s3_base_uri (str): the s3 base uri of the function step that
181+ the DelayedReturn object associates with.
167182 **settings: settings to pass to the deserialization function.
168183 """
184+ self ._s3_base_uri = s3_base_uri
169185 self ._parameter_resolver = parameter_resolver
170186 self ._execution_variable_resolver = execution_variable_resolver
171187 # different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
205221 uri .append (self ._parameter_resolver .resolve (component ))
206222 elif isinstance (component , _ExecutionVariable ):
207223 uri .append (self ._execution_variable_resolver .resolve (component ))
224+ elif isinstance (component , _S3BaseUriIdentifier ):
225+ uri .append (self ._s3_base_uri )
208226 else :
209227 uri .append (component )
210228 return s3_path_join (* uri )
@@ -251,6 +269,7 @@ def resolve_pipeline_variables(
251269 hmac_key = hmac_key ,
252270 parameter_resolver = parameter_resolver ,
253271 execution_variable_resolver = execution_variable_resolver ,
272+ s3_base_uri = context .s3_base_uri ,
254273 ** settings ,
255274 )
256275
@@ -289,11 +308,10 @@ def resolve_pipeline_variables(
289308 return resolved_func_args , resolved_func_kwargs
290309
291310
292- def convert_pipeline_variables_to_pickleable (s3_base_uri : str , func_args : Tuple , func_kwargs : Dict ):
311+ def convert_pipeline_variables_to_pickleable (func_args : Tuple , func_kwargs : Dict ):
293312 """Convert pipeline variables to pickleable.
294313
295314 Args:
296- s3_base_uri: s3 base uri where artifacts are stored.
297315 func_args: function args.
298316 func_kwargs: function kwargs.
299317 """
@@ -304,11 +322,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,
304322
305323 from sagemaker .workflow .function_step import DelayedReturn
306324
325+ # Notes:
326+ # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
327+ # when defining function steps and after function. After step-level arg serialization,
328+ # it's hard to update the s3_base_uri in pipeline compile time.
329+ # Thus set a placeholder _S3BaseUriIdentifier here and let the runtime job to resolve it.
330+ # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
331+ # the sagemaker_session is not passed in the pipeline but the default s3_root_uri
332+ # should be retrieved from the pipeline's sagemaker_session.
307333 def convert (arg ):
308334 if isinstance (arg , DelayedReturn ):
309335 return _DelayedReturn (
310336 uri = [
311- s3_base_uri ,
337+ _S3BaseUriIdentifier () ,
312338 ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
313339 arg ._step .name ,
314340 "results" ,
0 commit comments