Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholders, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_. (Default: None)
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
"""
model_type = "FrameworkModel" if isinstance(model, FrameworkModel) else "Model"
if isinstance(model, FrameworkModel):
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
model_parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: didn't really have an objection to this parameter name since they do ultimately resolve to Parameters and you're calling the model_config method to assign it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was renamed to avoid confusion with the input parameters provided in the args

if model_name:
parameters['ModelName'] = model_name
model_parameters['ModelName'] = model_name
elif isinstance(model, Model):
parameters = {
model_parameters = {
'ExecutionRoleArn': model.role,
'ModelName': model_name or model.name,
'PrimaryContainer': {
Expand All @@ -315,13 +318,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
else:
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))

if 'S3Operations' in parameters:
del parameters['S3Operations']
if 'S3Operations' in model_parameters:
del model_parameters['S3Operations']

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)
model_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update model parameters with input parameters
merge_dicts(model_parameters, kwargs[Field.Parameters.value])

kwargs[Field.Parameters.value] = model_parameters

"""
Example resource arn: arn:aws:states:::sagemaker:createModel
Expand Down
63 changes: 58 additions & 5 deletions tests/integ/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,59 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
delete_sagemaker_model(model_name, sagemaker_session)
# End of Cleanup


def test_model_step_with_placeholders(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
    """Run a ModelStep whose model name, container mode and tags are placeholders.

    The step is built with ``ExecutionInput`` placeholders passed both through
    ``model_name`` and through the ``parameters`` argument. The workflow is
    then created, executed with concrete inputs, and its output is checked for
    a model ARN and an HTTP 200 status before the created resources are
    deleted.
    """
    # Declare every value that is supplied at execution time.
    schema = {'ModelName': str, 'Mode': str, 'Tags': list}
    execution_input = ExecutionInput(schema=schema)

    # Placeholders handed to the step through its `parameters` argument.
    primary_container = {'Mode': execution_input['Mode']}
    parameters = {
        'PrimaryContainer': primary_container,
        'Tags': execution_input['Tags'],
    }

    model_step = ModelStep(
        'create_model_step',
        model=trained_estimator.create_model(),
        model_name=execution_input['ModelName'],
        parameters=parameters,
    )
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create the workflow and check its definition.
        workflow_name = unique_name_from_base("integ-test-model-step-workflow")
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=workflow_name,
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn,
        )

        # Concrete values that fill in the placeholders for this execution.
        environment_tag = {'Key': 'Environment', 'Value': 'test'}
        inputs = {
            'ModelName': generate_job_name(),
            'Mode': 'SingleModel',
            'Tags': [environment_tag],
        }

        # Execute the workflow and wait for its output.
        execution_output = workflow.execute(inputs=inputs).get_output(wait=True)

        # Check the workflow output.
        model_arn = execution_output.get("ModelArn")
        assert model_arn is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup: remove the state machine first, then the created model.
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        created_model_name = get_resource_name_from_arn(model_arn).split("/")[1]
        delete_sagemaker_model(created_model_name, sagemaker_session)


def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
# Create transformer from previously created estimator
job_name = generate_job_name()
Expand Down Expand Up @@ -293,7 +346,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("EndpointConfigArn") is not None
assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200
Expand Down Expand Up @@ -334,7 +387,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)

# Check workflow output
endpoint_arn = execution_output.get("EndpointArn")
assert execution_output.get("EndpointArn") is not None
Expand Down Expand Up @@ -372,7 +425,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
max_jobs=2,
max_parallel_jobs=2,
)

# Build workflow definition
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=job_name, data=record_set_for_hyperparameter_tuning)
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
Expand All @@ -390,7 +443,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"

Expand Down Expand Up @@ -440,7 +493,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
sfn_client=sfn_client,
sfn_role_arn=sfn_role_arn
)

# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,36 @@ def test_model_step_creation(pca_model):
}


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_model_step_creation_with_placeholders(pca_model):
    """Check the Parameters that ModelStep renders when given placeholders.

    Placeholders are supplied three ways: ``model_name`` (a StepInput),
    ``tags`` (an ExecutionInput) and a nested entry in ``parameters``. The
    assertion expects each to appear as a ``<key>.$`` JSONPath entry next to
    the static values taken from the model.
    """
    execution_input = ExecutionInput(schema={
        'Environment': str,
        # Tags are a list of Key/Value dicts, so declare the placeholder as a
        # list rather than a str.
        'Tags': list
    })

    step_input = StepInput(schema={
        'ModelName': str
    })

    # Passed through `parameters`; the expected output below shows it merged
    # into the PrimaryContainer built from the model.
    parameters = {
        'PrimaryContainer': {
            'Environment': execution_input['Environment']
        }
    }
    step = ModelStep(
        'Create model',
        model=pca_model,
        model_name=step_input['ModelName'],
        tags=execution_input['Tags'],
        parameters=parameters
    )
    assert step.to_dict()['Parameters'] == {
        'ExecutionRoleArn': EXECUTION_ROLE,
        'ModelName.$': "$['ModelName']",
        'PrimaryContainer': {
            'Environment.$': "$$.Execution.Input['Environment']",
            'Image': pca_model.image_uri,
            'ModelDataUrl': pca_model.model_data
        },
        'Tags.$': "$$.Execution.Input['Tags']"
    }


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_model_step_creation_with_env(pca_model_with_env):
step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS)
Expand Down