46 changes: 35 additions & 11 deletions src/stepfunctions/steps/states.py
@@ -254,27 +254,27 @@ def accept(self, visitor):

def add_retry(self, retry):
"""
Add a Retry block to the tail end of the list of retriers for the state.
Add a Retry block or a list of Retry blocks to the tail end of the list of retriers for the state.

Args:
retry (Retry): Retry block to add.
retry (Retry or list(Retry)): Retry block(s) to add.
"""
if Field.Retry in self.allowed_fields():
self.retries.append(retry)
self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry)
else:
raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__))
raise ValueError(f"{type(self).__name__} state does not support retry field. ")

def add_catch(self, catch):
"""
Add a Catch block to the tail end of the list of catchers for the state.
Add a Catch block or a list of Catch blocks to the tail end of the list of catchers for the state.

Args:
catch (Catch): Catch block to add.
catch (Catch or list(Catch)): Catch block(s) to add.
"""
if Field.Catch in self.allowed_fields():
self.catches.append(catch)
self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch)
else:
raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__))
raise ValueError(f"{type(self).__name__} state does not support catch field. ")

def to_dict(self):
result = super(State, self).to_dict()
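
For context on the hunk above: `add_retry` and `add_catch` now accept either a single block or a list of blocks, extending the state's retriers/catchers in order. A minimal usage sketch (assuming the public `stepfunctions.steps` classes exercised by the tests below; the state name and resource ARN are illustrative only):

```python
from stepfunctions import steps

# Hypothetical task; the resource ARN is illustrative only.
task = steps.Task(
    "TrainModel",
    resource="arn:aws:states:::sagemaker:createTrainingJob.sync",
)

# A single Retry block still works as before.
task.add_retry(
    steps.Retry(error_equals=["States.Timeout"], interval_seconds=5, max_attempts=3, backoff_rate=2)
)

# A list of Retry blocks is now appended to the state's retriers in one call.
task.add_retry([
    steps.Retry(error_equals=["States.TaskFailed"], interval_seconds=10, max_attempts=2),
    steps.Retry(error_equals=["States.ALL"], interval_seconds=30, max_attempts=1),
])

# add_catch accepts a single Catch block or a list in the same way.
task.add_catch([
    steps.Catch(error_equals=["States.ALL"], next_step=steps.Pass("HandleFailure")),
])
```
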
@@ -487,10 +487,12 @@ class Parallel(State):
A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
retry (Retry or list(Retry), optional): Retry block(s) to add to the list of Retriers that define the state's retry policy in case it encounters runtime errors. (default: None)
catch (Catch or list(Catch), optional): Catch block(s) to add to the list of Catchers that define a fallback state, executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined. (default: None)
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
@@ -500,6 +502,12 @@ def __init__(self, state_id, **kwargs):
super(Parallel, self).__init__(state_id, 'Parallel', **kwargs)
self.branches = []

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def allowed_fields(self):
return [
Field.Comment,
@@ -536,11 +544,13 @@ class Map(State):
A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
retry (Retry or list(Retry), optional): Retry block(s) to add to the list of Retriers that define the state's retry policy in case it encounters runtime errors. (default: None)
catch (Catch or list(Catch), optional): Catch block(s) to add to the list of Catchers that define a fallback state, executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined. (default: None)
items_path (str, optional): Path in the input for items to iterate over. (default: '$')
max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
comment (str, optional): Human-readable comment or description. (default: None)
@@ -551,6 +561,12 @@ def __init__(self, state_id, **kwargs):
"""
super(Map, self).__init__(state_id, 'Map', **kwargs)

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def attach_iterator(self, iterator):
"""
Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced.
@@ -586,10 +602,12 @@ class Task(State):
Task State causes the interpreter to execute the work identified by the state’s `resource` field.
"""

def __init__(self, state_id, **kwargs):
def __init__(self, state_id, retry=None, catch=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
retry (Retry or list(Retry), optional): Retry block(s) to add to the list of Retriers that define the state's retry policy in case it encounters runtime errors. (default: None)
catch (Catch or list(Catch), optional): Catch block(s) to add to the list of Catchers that define a fallback state, executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined. (default: None)
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
@@ -608,6 +626,12 @@ def __init__(self, state_id, **kwargs):
if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None:
raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.")

if retry:
self.add_retry(retry)

if catch:
self.add_catch(catch)

def allowed_fields(self):
return [
Field.Comment,
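
The same behavior is now available at construction time: `Task`, `Parallel`, and `Map` forward an optional `retry`/`catch` keyword argument to `add_retry`/`add_catch`. A hedged sketch of the new keyword arguments (state names and the resource ARN below are illustrative, not taken from this PR):

```python
from stepfunctions import steps

retry_policy = steps.Retry(error_equals=["States.TaskFailed"], interval_seconds=15, max_attempts=2, backoff_rate=4.0)
fallback = steps.Catch(error_equals=["States.ALL"], next_step=steps.Pass("CleanUp"))

# Single blocks or lists of blocks are both accepted.
task = steps.Task(
    "TrainModel",
    resource="arn:aws:states:::sagemaker:createTrainingJob.sync",  # illustrative ARN
    retry=retry_policy,
    catch=[fallback],
)

parallel = steps.Parallel("FanOut", retry=[retry_policy], catch=fallback)
map_state = steps.Map("ProcessItems", iterator=steps.Pass("PerItem"), retry=retry_policy)
```
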
107 changes: 106 additions & 1 deletion tests/integ/test_state_machine_definition.py
@@ -479,6 +479,59 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result)


def test_state_machine_creation_with_catch_in_constructor(sfn_client, sfn_role_arn, training_job_parameters):
catch_state_name = "TaskWithCatchState"
all_fail_error = "States.ALL"
all_error_state_name = "Catch All End"
catch_state_result = "Catch Result"
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"

# change the parameters to cause task state to fail
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"

asl_state_machine_definition = {
"StartAt": catch_state_name,
"States": {
catch_state_name: {
"Resource": task_resource,
"Parameters": training_job_parameters,
"Type": "Task",
"End": True,
"Catch": [
{
"ErrorEquals": [
all_fail_error
],
"Next": all_error_state_name
}
]
},
all_error_state_name: {
"Type": "Pass",
"Result": catch_state_result,
"End": True
}
}
}
task = steps.Task(
catch_state_name,
parameters=training_job_parameters,
resource=task_resource,
catch=steps.Catch(
error_equals=[all_fail_error],
next_step=steps.Pass(all_error_state_name, result=catch_state_result)
)
)

workflow = Workflow(
unique_name_from_base('Test_Catch_In_Constructor_Workflow'),
definition=task,
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result)


def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
retry_state_name = "RetryStateName"
all_fail_error = "Starts.ALL"
@@ -531,4 +584,56 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)


def test_state_machine_creation_with_retry_in_constructor(sfn_client, sfn_role_arn, training_job_parameters):
retry_state_name = "RetryStateName"
all_fail_error = "Starts.ALL"
interval_seconds = 1
max_attempts = 2
backoff_rate = 2
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"

# change the parameters to cause task state to fail
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"

asl_state_machine_definition = {
"StartAt": retry_state_name,
"States": {
retry_state_name: {
"Resource": task_resource,
"Parameters": training_job_parameters,
"Type": "Task",
"End": True,
"Retry": [
{
"ErrorEquals": [all_fail_error],
"IntervalSeconds": interval_seconds,
"MaxAttempts": max_attempts,
"BackoffRate": backoff_rate
}
]
}
}
}

task = steps.Task(
retry_state_name,
parameters=training_job_parameters,
resource=task_resource,
retry=steps.Retry(
error_equals=[all_fail_error],
interval_seconds=interval_seconds,
max_attempts=max_attempts,
backoff_rate=backoff_rate
)
)

workflow = Workflow(
unique_name_from_base('Test_Retry_In_Constructor_Workflow'),
definition=task,
role=sfn_role_arn
)

workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
64 changes: 63 additions & 1 deletion tests/unit/test_steps.py
@@ -469,4 +469,66 @@ def test_default_paths_not_converted_to_null():
assert '"OutputPath": null' not in task_state.to_json()



RETRY = Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2)
RETRIES = [RETRY, Retry(error_equals=['ErrorC'], interval_seconds=5)]
EXPECTED_RETRY = [{'ErrorEquals': ['ErrorA', 'ErrorB'], 'IntervalSeconds': 1, 'BackoffRate': 2, 'MaxAttempts': 2}]
EXPECTED_RETRIES = EXPECTED_RETRY + [{'ErrorEquals': ['ErrorC'], 'IntervalSeconds': 5}]

CATCH = Catch(error_equals=['States.ALL'], next_step=Pass('End State'))
CATCHES = [CATCH, Catch(error_equals=['States.TaskFailed'], next_step=Pass('Next State'))]
EXPECTED_CATCH = [{'ErrorEquals': ['States.ALL'], 'Next': 'End State'}]
EXPECTED_CATCHES = EXPECTED_CATCH + [{'ErrorEquals': ['States.TaskFailed'], 'Next': 'Next State'}]


@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [
(Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY),
(Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES),
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY),
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES),
(Task, 'Task', {}, RETRY, EXPECTED_RETRY),
(Task, 'Task', {}, RETRIES, EXPECTED_RETRIES)
])
def test_state_creation_with_retry(state, state_id, extra_args, retry, expected_retry):
step = state(state_id, retry=retry, **extra_args)
assert step.to_dict()['Retry'] == expected_retry


@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [
(Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH),
(Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES),
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH),
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES),
(Task, 'Task', {}, CATCH, EXPECTED_CATCH),
(Task, 'Task', {}, CATCHES, EXPECTED_CATCHES)
])
def test_state_creation_with_catch(state, state_id, extra_args, catch, expected_catch):
step = state(state_id, catch=catch, **extra_args)
assert step.to_dict()['Catch'] == expected_catch


@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [
(Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY),
(Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES),
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES),
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY),
(Task, 'Task', {}, RETRY, EXPECTED_RETRY),
(Task, 'Task', {}, RETRIES, EXPECTED_RETRIES)
])
def test_state_with_added_retry(state, state_id, extra_args, retry, expected_retry):
step = state(state_id, **extra_args)
step.add_retry(retry)
assert step.to_dict()['Retry'] == expected_retry


@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [
(Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH),
(Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES),
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH),
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES),
(Task, 'Task', {}, CATCH, EXPECTED_CATCH),
(Task, 'Task', {}, CATCHES, EXPECTED_CATCHES)
])
def test_state_with_added_catch(state, state_id, extra_args, catch, expected_catch):
step = state(state_id, **extra_args)
step.add_catch(catch)
assert step.to_dict()['Catch'] == expected_catch
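
Taken together, the unit tests above assert that the constructor arguments and the `add_retry`/`add_catch` calls serialize to the same ASL. A minimal standalone check along the same lines (a sketch assuming the same `stepfunctions.steps` imports used in these tests; the state names and resource ARN are hypothetical):

```python
from stepfunctions.steps import Task, Retry

# Hypothetical resource ARN used only for illustration.
resource = "arn:aws:states:::lambda:invoke"

retry = Retry(error_equals=["States.Timeout"], interval_seconds=3, max_attempts=2, backoff_rate=1.5)

# Passing retry in the constructor ...
via_constructor = Task("MyTaskA", resource=resource, retry=retry)

# ... should serialize the same way as adding it afterwards.
via_add_retry = Task("MyTaskB", resource=resource)
via_add_retry.add_retry(retry)

assert via_constructor.to_dict()["Retry"] == via_add_retry.to_dict()["Retry"]
```
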