
Commit 76119f9

feat: Support selective pipeline execution for function step
1 parent f2b47ab commit 76119f9
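
In practice, this change lets a @step-decorated (function) step be re-run through SageMaker Pipelines selective execution while its upstream inputs are read back from the source execution. A rough usage sketch based on the integration test added below; the pipeline object, `step_output_b`, and the first `execution` handle are illustrative stand-ins taken from that test:

    from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
    from sagemaker.workflow.step_outputs import get_step

    # Run the full pipeline once and keep the execution handle.
    execution = pipeline.start()
    execution.wait()

    # Re-run only the selected function step; the non-selected upstream steps
    # are reused from the source execution identified by its ARN.
    selective_execution = pipeline.start(
        selective_execution_config=SelectiveExecutionConfig(
            source_pipeline_execution_arn=execution.arn,
            selected_steps=[get_step(step_output_b).name],
        )
    )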

File tree (5 files changed: +123 −42 lines)

  src/sagemaker/remote_function/core/pipeline_variables.py
  src/sagemaker/remote_function/job.py
  src/sagemaker/workflow/function_step.py
  tests/integ/sagemaker/workflow/helpers.py
  tests/integ/sagemaker/workflow/test_step_decorator.py


src/sagemaker/remote_function/core/pipeline_variables.py

Lines changed: 41 additions & 31 deletions
@@ -19,6 +19,7 @@
 
 from sagemaker.s3 import s3_path_join
 from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
+from sagemaker.workflow.step_outputs import get_step
 
 
 @dataclass
@@ -166,6 +167,7 @@ def __init__(
         hmac_key: str,
         parameter_resolver: _ParameterResolver,
         execution_variable_resolver: _ExecutionVariableResolver,
+        properties_resolver: _PropertiesResolver,
         s3_base_uri: str,
         **settings,
     ):
@@ -184,6 +186,7 @@ def __init__(
         self._s3_base_uri = s3_base_uri
         self._parameter_resolver = parameter_resolver
         self._execution_variable_resolver = execution_variable_resolver
+        self._properties_resolver = properties_resolver
         # different delayed returns can have the same uri, so we need to dedupe
         uris = {
             self._resolve_delayed_return_uri(delayed_return) for delayed_return in delayed_returns
@@ -223,6 +226,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
                 uri.append(self._execution_variable_resolver.resolve(component))
             elif isinstance(component, _S3BaseUriIdentifier):
                 uri.append(self._s3_base_uri)
+            elif isinstance(component, _Properties):
+                uri.append(self._properties_resolver.resolve(component))
             else:
                 uri.append(component)
         return s3_path_join(*uri)
@@ -276,6 +281,7 @@ def resolve_pipeline_variables(
        hmac_key=hmac_key,
        parameter_resolver=parameter_resolver,
        execution_variable_resolver=execution_variable_resolver,
+       properties_resolver=properties_resolver,
        s3_base_uri=s3_base_uri,
        **settings,
    )
@@ -322,39 +328,43 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
         func_args: function args.
         func_kwargs: function kwargs.
     """
+    converted_func_args = tuple(_convert_pipeline_variable_to_pickleable(arg) for arg in func_args)
+    converted_func_kwargs = {
+        key: _convert_pipeline_variable_to_pickleable(arg) for key, arg in func_kwargs.items()
+    }
 
-    from sagemaker.workflow.entities import PipelineVariable
-
-    from sagemaker.workflow.execution_variables import ExecutionVariables
-
-    from sagemaker.workflow.function_step import DelayedReturn
-
-    # Notes:
-    # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
-    #    when defining function steps. After step-level arg serialization,
-    #    it's hard to update the s3_base_uri in pipeline compile time.
-    #    Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
-    # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
-    #    the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
-    #    should be retrieved from the pipeline's sagemaker_session.
-    def convert(arg):
-        if isinstance(arg, DelayedReturn):
-            return _DelayedReturn(
-                uri=[
-                    _S3BaseUriIdentifier(),
-                    ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
-                    arg._step.name,
-                    "results",
-                ],
-                reference_path=arg._reference_path,
-            )
+    return converted_func_args, converted_func_kwargs
 
-        if isinstance(arg, PipelineVariable):
-            return arg._pickleable
 
-        return arg
+def _convert_pipeline_variable_to_pickleable(arg):
+    """Convert a pipeline variable to pickleable."""
+    from sagemaker.workflow.entities import PipelineVariable
 
-    converted_func_args = tuple(convert(arg) for arg in func_args)
-    converted_func_kwargs = {key: convert(arg) for key, arg in func_kwargs.items()}
+    from sagemaker.workflow.function_step import DelayedReturn
 
-    return converted_func_args, converted_func_kwargs
+    if isinstance(arg, DelayedReturn):
+        # Notes:
+        # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
+        #    when defining function steps. After step-level arg serialization,
+        #    it's hard to update the s3_base_uri in pipeline compile time.
+        #    Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
+        # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
+        #    the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
+        #    should be retrieved from the pipeline's sagemaker_session.
+
+        container_args = get_step(arg)._properties.AlgorithmSpecification.ContainerArguments
+        execution_id = container_args[11]._pickleable
+        return _DelayedReturn(
+            uri=[
+                _S3BaseUriIdentifier(),
+                execution_id,
+                arg._step.name,
+                "results",
+            ],
+            reference_path=arg._reference_path,
+        )
+
+    if isinstance(arg, PipelineVariable):
+        return arg._pickleable
+
+    return arg
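
Note on the change above: a DelayedReturn argument is now pickled with a result URI whose execution-id component is a _Properties reference into the upstream step's AlgorithmSpecification.ContainerArguments (element 11, which carries the execution id for function steps per the variable name in this change), instead of ExecutionVariables.PIPELINE_EXECUTION_ID. Under selective execution that property resolves against the source pipeline execution, so the downstream job reads results from the prefix where they were actually written. A minimal sketch of how _resolve_delayed_return_uri joins the resolved components at job runtime; the concrete values below are made-up placeholders:

    from sagemaker.s3 import s3_path_join

    s3_base_uri = "s3://my-bucket/my-prefix/my-pipeline"  # resolved _S3BaseUriIdentifier (placeholder)
    execution_id = "source-execution-id"                  # resolved _Properties component (placeholder)
    step_name = "sum"                                     # arg._step.name (placeholder)

    result_uri = s3_path_join(s3_base_uri, execution_id, step_name, "results")
    # -> "s3://my-bucket/my-prefix/my-pipeline/source-execution-id/sum/results"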

src/sagemaker/remote_function/job.py

Lines changed: 13 additions & 1 deletion
@@ -59,7 +59,11 @@
 from sagemaker.s3 import s3_path_join, S3Uploader
 from sagemaker import vpc_utils
 from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
-from sagemaker.remote_function.core.pipeline_variables import Context
+from sagemaker.remote_function.core.pipeline_variables import (
+    Context,
+    _convert_pipeline_variable_to_pickleable,
+    _Properties,
+)
 from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
     RuntimeEnvironmentManager,
     _DependencySettings,
@@ -72,6 +76,7 @@
     copy_workdir,
     resolve_custom_file_filter_from_config_file,
 )
+from sagemaker.workflow.function_step import DelayedReturn
 
 if TYPE_CHECKING:
     from sagemaker.workflow.entities import PipelineVariable
@@ -804,6 +809,13 @@ def compile(
             if isinstance(arg, (Parameter, ExecutionVariable, Properties)):
                 container_args.extend([arg.expr["Get"], arg.to_string()])
 
+            if isinstance(arg, DelayedReturn):
+                uri = _convert_pipeline_variable_to_pickleable(arg).uri
+                for uri_element in uri:
+                    if not isinstance(uri_element, _Properties):
+                        continue
+                    container_args.extend([uri_element.path, {"Get": uri_element.path}])
+
         if run_info is not None:
             container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))])
         elif _RunContext.get_current_run() is not None:
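
Note on the compile() change above: for each DelayedReturn argument, the training job's container arguments now get a pair of the property path and a {"Get": path} expression, following the same path/placeholder pattern already used for Parameter, ExecutionVariable, and Properties a few lines earlier. SageMaker Pipelines substitutes the Get expression with the concrete value when it creates the job, leaving a (path, value) pair the job runtime can look up by path. An illustrative fragment of what gets appended; the exact path string below is an assumption, not taken from this diff:

    # uri_element.path might look like this for an upstream step named "sum" (assumed format):
    path = "Steps.sum.AlgorithmSpecification.ContainerArguments"
    container_args.extend([path, {"Get": path}])
    # At execution time the {"Get": ...} entry is replaced by the resolved value.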

src/sagemaker/workflow/function_step.py

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,7 @@
 )
 
 from sagemaker.workflow.execution_variables import ExecutionVariables
+from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.retry import RetryPolicy
 from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
 from sagemaker.workflow.step_collections import StepCollection
@@ -100,6 +101,9 @@ def __init__(
         self._step_kwargs = kwargs
 
         self.__job_settings = None
+        self._properties = Properties(
+            step_name=name, step=self, shape_name="DescribeTrainingJobResponse"
+        )
 
         (
             self._converted_func_args,
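
Note on the change above: every function step now carries a Properties object shaped like DescribeTrainingJobResponse (function steps execute as training jobs), which is what allows the rest of this commit to reference the step's training-job properties. For example, as consumed in pipeline_variables.py above; `step_output` here is an illustrative DelayedReturn:

    from sagemaker.workflow.step_outputs import get_step

    props = get_step(step_output)._properties
    # A Properties reference resolved by the workflow service at runtime, not concrete values:
    container_args_ref = props.AlgorithmSpecification.ContainerArguments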

tests/integ/sagemaker/workflow/helpers.py

Lines changed: 25 additions & 9 deletions
@@ -39,18 +39,24 @@ def create_and_execute_pipeline(
     step_result_type=None,
     step_result_value=None,
     wait_duration=400,  # seconds
+    selective_execution_config=None,
 ):
-    response = pipeline.create(role)
-
-    create_arn = response["PipelineArn"]
-    assert re.match(
-        rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
-        create_arn,
+    create_arn = None
+    if not selective_execution_config:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+    execution = pipeline.start(
+        parameters=execution_parameters, selective_execution_config=selective_execution_config
     )
 
-    execution = pipeline.start(parameters=execution_parameters)
-    response = execution.describe()
-    assert response["PipelineArn"] == create_arn
+    if create_arn:
+        response = execution.describe()
+        assert response["PipelineArn"] == create_arn
 
     wait_pipeline_execution(execution=execution, delay=20, max_attempts=int(wait_duration / 20))
 
@@ -71,6 +77,16 @@ def create_and_execute_pipeline(
     if step_result_value:
         result = execution.result(execution_steps[0]["StepName"])
         assert result == step_result_value, f"Expected {step_result_value}, instead found {result}"
+
+    if selective_execution_config:
+        for exe_step in execution_steps:
+            if exe_step["StepName"] in selective_execution_config.selected_steps:
+                continue
+            assert (
+                exe_step["SelectiveExecutionResult"]["SourcePipelineExecutionArn"]
+                == selective_execution_config.source_pipeline_execution_arn
+            )
+
     return execution, execution_steps
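
Note on the helper change above: when a selective_execution_config is passed, the helper skips pipeline creation and the ARN check (the pipeline already exists from the source run), forwards the config to pipeline.start, and asserts that every step not in selected_steps was reused from the source execution. An illustrative shape of such a reused step entry in execution_steps, limited to the keys the assertion touches; the names and ARN are placeholders:

    reused_step = {
        "StepName": "step_a",
        "SelectiveExecutionResult": {
            "SourcePipelineExecutionArn": "arn:aws:sagemaker:us-west-2:111122223333:pipeline/my-pipeline/execution/abc123"
        },
    }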

tests/integ/sagemaker/workflow/test_step_decorator.py

Lines changed: 40 additions & 1 deletion
@@ -39,6 +39,7 @@
     ParameterFloat,
     ParameterBoolean,
 )
+from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
 from sagemaker.workflow.step_outputs import get_step
 from sagemaker.workflow.steps import ProcessingStep
 
@@ -246,6 +247,20 @@ def sum(a, b):
     )
 
     try:
+        execution, _ = create_and_execute_pipeline(
+            pipeline=pipeline,
+            pipeline_name=pipeline_name,
+            region_name=region_name,
+            role=role,
+            no_of_steps=2,
+            last_step_name="sum",
+            execution_parameters=dict(),
+            step_status="Succeeded",
+            step_result_type=int,
+            step_result_value=7,
+        )
+
+        # Test Selective Pipeline Execution on function step1 -> [select: function step2]
         create_and_execute_pipeline(
             pipeline=pipeline,
             pipeline_name=pipeline_name,
@@ -256,7 +271,13 @@ def sum(a, b):
             execution_parameters=dict(),
             step_status="Succeeded",
             step_result_type=int,
+            step_result_value=7,
+            selective_execution_config=SelectiveExecutionConfig(
+                source_pipeline_execution_arn=execution.arn,
+                selected_steps=[get_step(step_output_b).name],
+            ),
         )
+
     finally:
         try:
             pipeline.delete()
@@ -379,7 +400,7 @@ def func_2(*args):
     )
 
     try:
-        create_and_execute_pipeline(
+        execution, _ = create_and_execute_pipeline(
             pipeline=pipeline,
             pipeline_name=pipeline_name,
             region_name=region_name,
@@ -392,6 +413,24 @@ def func_2(*args):
             step_result_value=(3, True, 2.0, "string", "Completed", 3),
             wait_duration=600,
         )
+        # Test Selective Pipeline Execution on regular step -> [select: function step]
+        execution, _ = create_and_execute_pipeline(
+            pipeline=pipeline,
+            pipeline_name=pipeline_name,
+            region_name=region_name,
+            role=role,
+            no_of_steps=3,
+            last_step_name="func",
+            execution_parameters=dict(param_a=10),
+            step_status="Succeeded",
+            step_result_type=tuple,
+            step_result_value=(10, True, 2.0, "string", "Completed", 3),
+            wait_duration=600,
+            selective_execution_config=SelectiveExecutionConfig(
+                source_pipeline_execution_arn=execution.arn,
+                selected_steps=[get_step(final_output).name],
+            ),
+        )
 
     finally:
         try:
