1515from abc import ABC , abstractmethod
1616
1717import json
18- from copy import deepcopy
1918from datetime import datetime
2019from typing import Dict , List , Union
2120from botocore .exceptions import ClientError
2221
2322from sagemaker .workflow .conditions import ConditionTypeEnum
23+ from sagemaker .workflow .function_step import DelayedReturn
2424from sagemaker .workflow .steps import StepTypeEnum , Step
2525from sagemaker .workflow .step_collections import StepCollection
2626from sagemaker .workflow .entities import PipelineVariable
@@ -87,7 +87,7 @@ def evaluate_step_arguments(self, step):
8787 def _parse_arguments (self , obj , step_name ):
8888 """Parse and evaluate arguments field"""
8989 if isinstance (obj , dict ):
90- obj_copy = deepcopy ( obj )
90+ obj_copy = {}
9191 for k , v in obj .items ():
9292 obj_copy [k ] = self ._parse_arguments (v , step_name )
9393 return obj_copy
@@ -108,16 +108,17 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name):
108108 elif isinstance (pipeline_variable , Parameter ):
109109 value = self .execution .pipeline_parameters .get (pipeline_variable .name )
110110 elif isinstance (pipeline_variable , Join ):
111- evaluated = [
112- str (self .evaluate_pipeline_variable (v , step_name )) for v in pipeline_variable .values
113- ]
114- value = pipeline_variable .on .join (evaluated )
111+ value = self ._evaluate_join_function (pipeline_variable , step_name )
115112 elif isinstance (pipeline_variable , Properties ):
116113 value = self ._evaluate_property_reference (pipeline_variable , step_name )
117114 elif isinstance (pipeline_variable , ExecutionVariable ):
118115 value = self ._evaluate_execution_variable (pipeline_variable )
119116 elif isinstance (pipeline_variable , JsonGet ):
120117 value = self ._evaluate_json_get_function (pipeline_variable , step_name )
118+ elif isinstance (pipeline_variable , DelayedReturn ):
119+ # DelayedReturn showing up in arguments, meaning that it's data referenced
120+ # We should convert it to JsonGet and evaluate the JsonGet object
121+ value = self ._evaluate_json_get_function (pipeline_variable ._to_json_get (), step_name )
121122 else :
122123 self .execution .update_step_failure (
123124 step_name , f"Unrecognized pipeline variable { pipeline_variable .expr } ."
@@ -127,6 +128,13 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name):
127128 self .execution .update_step_failure (step_name , f"{ pipeline_variable .expr } is undefined." )
128129 return value
129130
131+ def _evaluate_join_function (self , pipeline_variable , step_name ):
132+ """Evaluate join function runtime value"""
133+ evaluated = [
134+ str (self .evaluate_pipeline_variable (v , step_name )) for v in pipeline_variable .values
135+ ]
136+ return pipeline_variable .on .join (evaluated )
137+
130138 def _evaluate_property_reference (self , pipeline_variable , step_name ):
131139 """Evaluate property reference runtime value."""
132140 try :
@@ -156,6 +164,43 @@ def _evaluate_execution_variable(self, pipeline_variable):
156164
157165 def _evaluate_json_get_function (self , pipeline_variable , step_name ):
158166 """Evaluate join function runtime value."""
167+ s3_bucket = None
168+ s3_key = None
169+ try :
170+ if pipeline_variable .property_file :
171+ s3_bucket , s3_key = self ._evaluate_json_get_property_file_reference (
172+ pipeline_variable = pipeline_variable , step_name = step_name
173+ )
174+ else :
175+ # JsonGet's s3_uri can only be a Join function
176+ # This has been validated in _validate_json_get_function
177+ s3_uri = self ._evaluate_join_function (pipeline_variable .s3_uri , step_name )
178+ s3_bucket , s3_key = parse_s3_url (s3_uri )
179+
180+ file_content = self .sagemaker_session .read_s3_file (s3_bucket , s3_key )
181+ file_json = json .loads (file_content )
182+ return get_using_dot_notation (file_json , pipeline_variable .json_path )
183+ except ClientError as e :
184+ self .execution .update_step_failure (
185+ step_name ,
186+ f"Received an error while reading file { s3_path_join ('s3://' , s3_bucket , s3_key )} "
187+ f"from S3: { e .response .get ('Code' )} : { e .response .get ('Message' )} " ,
188+ )
189+ except json .JSONDecodeError :
190+ self .execution .update_step_failure (
191+ step_name ,
192+ f"Contents of file { s3_path_join ('s3://' , s3_bucket , s3_key )} are not "
193+ f"in valid JSON format." ,
194+ )
195+ except ValueError :
196+ self .execution .update_step_failure (
197+ step_name , f"Invalid json path '{ pipeline_variable .json_path } '"
198+ )
199+
200+ def _evaluate_json_get_property_file_reference (
201+ self , pipeline_variable : JsonGet , step_name : str
202+ ):
203+ """Evaluate JsonGet's property file reference to get s3 bucket and key"""
159204 property_file_reference = pipeline_variable .property_file
160205 property_file = None
161206 if isinstance (property_file_reference , str ):
@@ -180,28 +225,9 @@ def _evaluate_json_get_function(self, pipeline_variable, step_name):
180225 processing_output_s3_bucket = processing_step_response ["ProcessingOutputConfig" ]["Outputs" ][
181226 property_file .output_name
182227 ]["S3Output" ]["S3Uri" ]
183- try :
184- s3_bucket , s3_key_prefix = parse_s3_url (processing_output_s3_bucket )
185- file_content = self .sagemaker_session .read_s3_file (
186- s3_bucket , s3_path_join (s3_key_prefix , property_file .path )
187- )
188- file_json = json .loads (file_content )
189- return get_using_dot_notation (file_json , pipeline_variable .json_path )
190- except ClientError as e :
191- self .execution .update_step_failure (
192- step_name ,
193- f"Received an error while file reading file '{ property_file .path } ' from S3: "
194- f"{ e .response .get ('Code' )} : { e .response .get ('Message' )} " ,
195- )
196- except json .JSONDecodeError :
197- self .execution .update_step_failure (
198- step_name ,
199- f"Contents of property file '{ property_file .name } ' are not in valid JSON format." ,
200- )
201- except ValueError :
202- self .execution .update_step_failure (
203- step_name , f"Invalid json path '{ pipeline_variable .json_path } '"
204- )
228+ s3_bucket , s3_key_prefix = parse_s3_url (processing_output_s3_bucket )
229+ s3_key = s3_path_join (s3_key_prefix , property_file .path )
230+ return s3_bucket , s3_key
205231
206232
207233class _StepExecutor (ABC ):
0 commit comments