-
Notifications
You must be signed in to change notification settings - Fork 88
Support placeholders for processing step #155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
927b24f
003b5e8
a7700a6
6b6443a
7f6ef30
00830f3
c708da7
2ea9e1f
17543ed
36e2ee8
ea40f7c
e499108
4c63229
34bb281
a098c61
06eb069
da99c92
37b2422
c433576
fd640ab
1dfa0e3
6143783
ebc5e22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,4 +22,5 @@ class MissingRequiredParameter(Exception): | |
|
||
|
||
class DuplicateStatesInChain(Exception): | ||
pass | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ | |
from stepfunctions.inputs import Placeholder | ||
from stepfunctions.steps.states import Task | ||
from stepfunctions.steps.fields import Field | ||
from stepfunctions.steps.utils import tags_dict_to_kv_list | ||
from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list | ||
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn | ||
|
||
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config | ||
|
@@ -30,6 +30,7 @@ | |
|
||
SAGEMAKER_SERVICE_NAME = "sagemaker" | ||
|
||
|
||
class SageMakerApi(Enum): | ||
CreateTrainingJob = "createTrainingJob" | ||
CreateTransformJob = "createTransformJob" | ||
|
@@ -477,9 +478,15 @@ class ProcessingStep(Task): | |
|
||
""" | ||
Creates a Task State to execute a SageMaker Processing Job. | ||
|
||
The following properties can be passed down as kwargs to the sagemaker.processing.Processor to be used dynamically | ||
in the processing job (compatible with Placeholders): role, image_uri, instance_count, instance_type, | ||
volume_size_in_gb, volume_kms_key, output_kms_key | ||
|
||
""" | ||
|
||
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs): | ||
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, | ||
container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, | ||
tags=None, **kwargs): | ||
""" | ||
Args: | ||
ca-nguyen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. | ||
|
@@ -491,15 +498,16 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp | |
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for | ||
the processing job. These can be specified as either path strings or | ||
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None). | ||
experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None) | ||
container_arguments ([str]): The arguments for a container used to run a processing job. | ||
container_entrypoint ([str]): The entrypoint for a container used to run a processing job. | ||
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker | ||
experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None) | ||
container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job. | ||
container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job. | ||
kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker | ||
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key, | ||
ARN of a KMS key, alias of a KMS key, or alias of a KMS key. | ||
The KmsKeyId is applied to all outputs. | ||
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) | ||
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource. | ||
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource. | ||
parameters(dict, optional): The value of this field becomes the effective input for the state. | ||
""" | ||
if wait_for_completion: | ||
""" | ||
|
@@ -518,22 +526,26 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp | |
SageMakerApi.CreateProcessingJob) | ||
|
||
if isinstance(job_name, str): | ||
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) | ||
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) | ||
else: | ||
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) | ||
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) | ||
|
||
if isinstance(job_name, Placeholder): | ||
ca-nguyen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
parameters['ProcessingJobName'] = job_name | ||
processing_parameters['ProcessingJobName'] = job_name | ||
|
||
if experiment_config is not None: | ||
parameters['ExperimentConfig'] = experiment_config | ||
processing_parameters['ExperimentConfig'] = experiment_config | ||
|
||
if tags: | ||
parameters['Tags'] = tags_dict_to_kv_list(tags) | ||
|
||
if 'S3Operations' in parameters: | ||
del parameters['S3Operations'] | ||
|
||
kwargs[Field.Parameters.value] = parameters | ||
processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags) | ||
|
||
if 'S3Operations' in processing_parameters: | ||
del processing_parameters['S3Operations'] | ||
|
||
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): | ||
# Update processing_parameters with input parameters | ||
merge_dicts(processing_parameters, kwargs[Field.Parameters.value], "Processing Parameters", | ||
ca-nguyen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
"Input Parameters") | ||
|
||
kwargs[Field.Parameters.value] = processing_parameters | ||
super(ProcessingStep, self).__init__(state_id, **kwargs) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ | |||||
|
||||||
import boto3 | ||||||
import logging | ||||||
from stepfunctions.inputs import Placeholder | ||||||
|
||||||
logger = logging.getLogger('stepfunctions') | ||||||
|
||||||
|
@@ -45,3 +46,24 @@ def get_aws_partition(): | |||||
return cur_partition | ||||||
|
||||||
return cur_partition | ||||||
|
||||||
|
||||||
def merge_dicts(first, second, first_name, second_name): | ||||||
|
def merge_dicts(first, second, first_name, second_name): | |
def merge_dicts(target, source): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 - I also like to push for doc strings where behaviour is not entirely intuitive, e.g. what happens if there are clashes, are overwrites allowed, etc.
ca-nguyen marked this conversation as resolved.
Show resolved
Hide resolved
ca-nguyen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
ca-nguyen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: Do we think this is useful? If not, can just use Python's built-in dict.update
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The built in update() does not take into account nested dictionary values - for ex:
d1 = {'a': {'aa': 1, 'bb': 2, 'c': 3}}
d2 = {'a': {'bb': 1}}
d1.update(d2)
print(d1)
Will have the following output: {'a': {'bb': 1}}
Since we would expect to get {'a': {'aa': 1, 'bb': 1, 'c': 3}}
, we can't use the update() function in our case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initially added them to facilitate troubleshooting, but I'm open to remove the logs if we deem them not useful enough or too noisy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the expected behaviour is well documented it seems unnecessary. Since the method only exists for logging, if we get rid of it there's less code to maintain. What do you think, @shivlaks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The built in update() does not take into account nested dictionary values
Missed this comment. Since we need a deep merge, dict.update is not going to work here
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
from sagemaker.tuner import HyperparameterTuner | ||
from sagemaker.processing import ProcessingInput, ProcessingOutput | ||
|
||
from stepfunctions.inputs import ExecutionInput | ||
from stepfunctions.steps import Chain | ||
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep | ||
from stepfunctions.workflow import Workflow | ||
|
@@ -352,3 +353,100 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien | |
# Cleanup | ||
state_machine_delete_wait(sfn_client, workflow.state_machine_arn) | ||
# End of Cleanup | ||
|
||
|
||
def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn, | ||
sagemaker_role_arn): | ||
region = boto3.session.Session().region_name | ||
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region) | ||
|
||
|
||
input_s3 = sagemaker_session.upload_data( | ||
path=os.path.join(DATA_DIR, 'sklearn_processing'), | ||
bucket=sagemaker_session.default_bucket(), | ||
key_prefix='integ-test-data/sklearn_processing/code' | ||
) | ||
|
||
output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing' | ||
|
||
|
||
inputs = [ | ||
ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'), | ||
ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code', | ||
input_name='code'), | ||
] | ||
|
||
outputs = [ | ||
ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data', | ||
output_name='train_data'), | ||
ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data', | ||
output_name='test_data'), | ||
] | ||
|
||
# Build workflow definition | ||
execution_input = ExecutionInput(schema={ | ||
'image_uri': str, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we're only using these values for test purposes, using the direct string values for better code readability |
||
'instance_count': int, | ||
'entrypoint': str, | ||
'role': str, | ||
'volume_size_in_gb': int, | ||
'max_runtime_in_seconds': int, | ||
'container_arguments': [str], | ||
}) | ||
|
||
parameters = { | ||
'AppSpecification': { | ||
'ContainerEntrypoint': execution_input['entrypoint'], | ||
'ImageUri': execution_input['image_uri'] | ||
}, | ||
'ProcessingResources': { | ||
'ClusterConfig': { | ||
'InstanceCount': execution_input['instance_count'], | ||
'VolumeSizeInGB': execution_input['volume_size_in_gb'] | ||
} | ||
}, | ||
'RoleArn': execution_input['role'], | ||
'StoppingCondition': { | ||
'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds'] | ||
} | ||
} | ||
|
||
job_name = generate_job_name() | ||
processing_step = ProcessingStep('create_processing_job_step', | ||
processor=sklearn_processor_fixture, | ||
job_name=job_name, | ||
inputs=inputs, | ||
outputs=outputs, | ||
container_arguments=execution_input['container_arguments'], | ||
container_entrypoint=execution_input['entrypoint'], | ||
parameters=parameters | ||
) | ||
workflow_graph = Chain([processing_step]) | ||
|
||
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): | ||
# Create workflow and check definition | ||
|
||
workflow = create_workflow_and_check_definition( | ||
workflow_graph=workflow_graph, | ||
workflow_name=unique_name_from_base("integ-test-processing-step-workflow"), | ||
sfn_client=sfn_client, | ||
sfn_role_arn=sfn_role_arn | ||
) | ||
|
||
execution_input = { | ||
'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', | ||
'instance_count': 1, | ||
'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'], | ||
'role': sagemaker_role_arn, | ||
'volume_size_in_gb': 30, | ||
'max_runtime_in_seconds': 500, | ||
'container_arguments': ['--train-test-split-ratio', '0.2'] | ||
} | ||
|
||
# Execute workflow | ||
execution = workflow.execute(inputs=execution_input) | ||
execution_output = execution.get_output(wait=True) | ||
|
||
# Check workflow output | ||
assert execution_output.get("ProcessingJobStatus") == "Completed" | ||
|
||
# Cleanup | ||
state_machine_delete_wait(sfn_client, workflow.state_machine_arn) | ||
# End of Cleanup | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice to see us embracing pep8 in files we touch 🙌
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🙌🙌🙌