2424
2525from sagemaker ._studio import _append_project_tags
2626from sagemaker .session import Session
27- from sagemaker .workflow .callback_step import CallbackOutput
27+ from sagemaker .workflow .callback_step import CallbackOutput , CallbackStep
2828from sagemaker .workflow .entities import (
2929 Entity ,
3030 Expression ,
@@ -240,9 +240,12 @@ def definition(self) -> str:
240240 """Converts a request structure to string representation for workflow service calls."""
241241 request_dict = self .to_request ()
242242 request_dict ["PipelineExperimentConfig" ] = interpolate (
243- request_dict ["PipelineExperimentConfig" ]
243+ request_dict ["PipelineExperimentConfig" ], {}
244+ )
245+ callback_output_to_step_map = _map_callback_outputs (self .steps )
246+ request_dict ["Steps" ] = interpolate (
247+ request_dict ["Steps" ], callback_output_to_step_map = callback_output_to_step_map
244248 )
245- request_dict ["Steps" ] = interpolate (request_dict ["Steps" ])
246249
247250 return json .dumps (request_dict )
248251
@@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263266 return [{"Name" : name , "Value" : str (value )} for name , value in parameters .items ()]
264267
265268
266- def interpolate (request_obj : RequestType ) -> RequestType :
269+ def interpolate (
270+ request_obj : RequestType , callback_output_to_step_map : Dict [str , str ]
271+ ) -> RequestType :
267272 """Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268273
269274 Args:
270275 request_obj (RequestType): The request dict.
276+ callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
271277
272278 Returns:
273279 RequestType: The request dict with Parameter values replaced by their expression.
274280 """
275281 request_obj_copy = deepcopy (request_obj )
276- return _interpolate (request_obj_copy )
282+ return _interpolate (request_obj_copy , callback_output_to_step_map = callback_output_to_step_map )
277283
278284
279- def _interpolate (obj : Union [RequestType , Any ]):
285+ def _interpolate (obj : Union [RequestType , Any ], callback_output_to_step_map : Dict [ str , str ] ):
280286 """Walks the nested request dict, replacing Parameter type values with workflow expressions.
281287
282288 Args:
283289 obj (Union[RequestType, Any]): The request dict.
290+ callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
284291 """
285- if isinstance (obj , (Expression , Parameter , Properties , CallbackOutput )):
292+ if isinstance (obj , (Expression , Parameter , Properties )):
286293 return obj .expr
294+ if isinstance (obj , CallbackOutput ):
295+ step_name = callback_output_to_step_map [obj .output_name ]
296+ return obj .expr (step_name )
287297 if isinstance (obj , dict ):
288298 new = obj .__class__ ()
289299 for key , value in obj .items ():
290- new [key ] = interpolate (value )
300+ new [key ] = interpolate (value , callback_output_to_step_map )
291301 elif isinstance (obj , (list , set , tuple )):
292- new = obj .__class__ (interpolate (value ) for value in obj )
302+ new = obj .__class__ (interpolate (value , callback_output_to_step_map ) for value in obj )
293303 else :
294304 return obj
295305 return new
296306
297307
308+ def _map_callback_outputs (steps : List [Step ]):
309+ """Iterate over the provided steps, building a map of callback output parameters to step names.
310+
311+ Args:
312+ step (List[Step]): The steps list.
313+ """
314+
315+ callback_output_map = {}
316+ for step in steps :
317+ if isinstance (step , CallbackStep ):
318+ if step .outputs :
319+ for output in step .outputs :
320+ callback_output_map [output .output_name ] = step .name
321+
322+ return callback_output_map
323+
324+
298325def update_args (args : Dict [str , Any ], ** kwargs ):
299326 """Updates the request arguments dict with a value, if populated.
300327
0 commit comments