1919
2020from sagemaker .s3 import s3_path_join
2121from sagemaker .remote_function .core .serialization import deserialize_obj_from_s3
22+ from sagemaker .workflow .step_outputs import get_step
2223
2324
2425@dataclass
@@ -166,6 +167,7 @@ def __init__(
166167 hmac_key : str ,
167168 parameter_resolver : _ParameterResolver ,
168169 execution_variable_resolver : _ExecutionVariableResolver ,
170+ properties_resolver : _PropertiesResolver ,
169171 s3_base_uri : str ,
170172 ** settings ,
171173 ):
@@ -184,6 +186,7 @@ def __init__(
184186 self ._s3_base_uri = s3_base_uri
185187 self ._parameter_resolver = parameter_resolver
186188 self ._execution_variable_resolver = execution_variable_resolver
189+ self ._properties_resolver = properties_resolver
187190 # different delayed returns can have the same uri, so we need to dedupe
188191 uris = {
189192 self ._resolve_delayed_return_uri (delayed_return ) for delayed_return in delayed_returns
@@ -223,6 +226,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
223226 uri .append (self ._execution_variable_resolver .resolve (component ))
224227 elif isinstance (component , _S3BaseUriIdentifier ):
225228 uri .append (self ._s3_base_uri )
229+ elif isinstance (component , _Properties ):
230+ uri .append (self ._properties_resolver .resolve (component ))
226231 else :
227232 uri .append (component )
228233 return s3_path_join (* uri )
@@ -276,6 +281,7 @@ def resolve_pipeline_variables(
276281 hmac_key = hmac_key ,
277282 parameter_resolver = parameter_resolver ,
278283 execution_variable_resolver = execution_variable_resolver ,
284+ properties_resolver = properties_resolver ,
279285 s3_base_uri = s3_base_uri ,
280286 ** settings ,
281287 )
@@ -322,39 +328,43 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
322328 func_args: function args.
323329 func_kwargs: function kwargs.
324330 """
331+ converted_func_args = tuple (_convert_pipeline_variable_to_pickleable (arg ) for arg in func_args )
332+ converted_func_kwargs = {
333+ key : _convert_pipeline_variable_to_pickleable (arg ) for key , arg in func_kwargs .items ()
334+ }
325335
326- from sagemaker .workflow .entities import PipelineVariable
327-
328- from sagemaker .workflow .execution_variables import ExecutionVariables
329-
330- from sagemaker .workflow .function_step import DelayedReturn
331-
332- # Notes:
333- # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
334- # when defining function steps. After step-level arg serialization,
335- # it's hard to update the s3_base_uri in pipeline compile time.
336- # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
337- # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
338- # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
339- # should be retrieved from the pipeline's sagemaker_session.
340- def convert (arg ):
341- if isinstance (arg , DelayedReturn ):
342- return _DelayedReturn (
343- uri = [
344- _S3BaseUriIdentifier (),
345- ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
346- arg ._step .name ,
347- "results" ,
348- ],
349- reference_path = arg ._reference_path ,
350- )
336+ return converted_func_args , converted_func_kwargs
351337
352- if isinstance (arg , PipelineVariable ):
353- return arg ._pickleable
354338
355- return arg
339+ def _convert_pipeline_variable_to_pickleable (arg ):
340+ """Convert a pipeline variable to pickleable."""
341+ from sagemaker .workflow .entities import PipelineVariable
356342
357- converted_func_args = tuple (convert (arg ) for arg in func_args )
358- converted_func_kwargs = {key : convert (arg ) for key , arg in func_kwargs .items ()}
343+ from sagemaker .workflow .function_step import DelayedReturn
359344
360- return converted_func_args , converted_func_kwargs
345+ if isinstance (arg , DelayedReturn ):
346+ # Notes:
347+ # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
348+ # when defining function steps. After step-level arg serialization,
349+ # it's hard to update the s3_base_uri in pipeline compile time.
350+ # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
351+ # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
352+ # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
353+ # should be retrieved from the pipeline's sagemaker_session.
354+
355+ container_args = get_step (arg )._properties .AlgorithmSpecification .ContainerArguments
356+ execution_id = container_args [11 ]._pickleable
357+ return _DelayedReturn (
358+ uri = [
359+ _S3BaseUriIdentifier (),
360+ execution_id ,
361+ arg ._step .name ,
362+ "results" ,
363+ ],
364+ reference_path = arg ._reference_path ,
365+ )
366+
367+ if isinstance (arg , PipelineVariable ):
368+ return arg ._pickleable
369+
370+ return arg
0 commit comments