Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
parameters['TrainingJobName'] = job_name

if hyperparameters is not None:
parameters['HyperParameters'] = hyperparameters
merged_hyperparameters = {}
if estimator.hyperparameters() is not None:
merged_hyperparameters.update(estimator.hyperparameters())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re:

should we warn/error instead of an in-place merge when we find duplicate keys?

Throwing an error wouldn't be helpful; it's more breaking (you can't create what you could before) and doesn't address the desired functionality in #99. Logging at an INFO level may be helpful for duplicate keys.

We should also document if not already that parameters gets totally overridden by training_config. This isn't the case for all service integration steps. Maybe we should adopt an update strategy there too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was originally thinking that we should log it as a warning, since INFO tends to include a lot of noise — but I'm not strongly opinionated about it.

I'll re-spin the update calls to assemble the dicts so that they log something on duplicate keys.
absolutely agree that we need to document.

This isn't the case for all service integration steps. Maybe we should adopt an update strategy there too?

great call. wasn't on my radar, but I'm in favour of adopting the update strategy. The more consistent it is across things in the SDK, the more intuitive and idiomatic it will feel for users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make the changes to service integration steps in a separate PR. let me know if you had a different thought/idea of where we should be documenting this behaviour @wong-a

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the other way around actually. In all service integration steps besides sagemaker, the constructor accepts a parameters argument that becomes Parameters . We can update sagemaker step classes to accept a parameters dict which can be an escape hatch for full API coverage or override any explicitly exposed arguments.

For example, DynamoDBUpdateItem the caller must specify all parameters in the parameters field. The constructor doesn't have a table_name argument to set or other required fields:
https://github.com/aws/aws-step-functions-data-science-sdk-python/blob/main/src/stepfunctions/steps/service.py#L140-L167

Whereas the sagemaker steps always construct parameters using the sagemaker SDK and some special arguments in the constructor. You could provide parameters because the constructor accepts **kwargs, but it won't do anything.
https://github.com/aws/aws-step-functions-data-science-sdk-python/blob/main/src/stepfunctions/steps/sagemaker.py#L477-L479

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean. I'm in favour of adding that parameters property and will address it in another PR. Escape hatches are powerful because it'll give users a path forward without requiring first class support to be developed and released.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's out of scope of this PR. Can you create an issue for tracking?

merged_hyperparameters.update(hyperparameters)
parameters['HyperParameters'] = merged_hyperparameters

if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config
Expand Down
129 changes: 129 additions & 0 deletions tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,135 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
'End': True
}

@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
    """Verify TrainingStep merges constructor hyperparameters with the estimator's.

    The estimator already carries framework hyperparameters (e.g.
    ``training_steps``, ``sagemaker_program``); passing ``hyperparameters``
    to the TrainingStep constructor should add to — not replace — that set
    in the generated ``HyperParameters`` block of the state parameters.

    NOTE: the original function was missing the ``test_`` prefix, so pytest
    never collected it; renamed so it actually runs.
    """
    step = TrainingStep('Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters={
            # Unique to the constructor: must appear alongside estimator values.
            'key': 'value'
        }
    )

    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AlgorithmSpecification': {
                'TrainingImage': TENSORFLOW_IMAGE,
                'TrainingInputMode': 'File'
            },
            'InputDataConfig': [
                {
                    'DataSource': {
                        'S3DataSource': {
                            'S3DataDistributionType': 'FullyReplicated',
                            'S3DataType': 'S3Prefix',
                            'S3Uri': 's3://sagemaker/train'
                        }
                    },
                    'ChannelName': 'train'
                }
            ],
            'OutputDataConfig': {
                'S3OutputPath': 's3://sagemaker/models'
            },
            'DebugHookConfig': {
                'S3OutputPath': 's3://sagemaker/models/debug'
            },
            'StoppingCondition': {
                'MaxRuntimeInSeconds': 86400
            },
            'ResourceConfig': {
                'InstanceCount': 1,
                'InstanceType': 'ml.p2.xlarge',
                'VolumeSizeInGB': 30
            },
            'RoleArn': EXECUTION_ROLE,
            # Expected merge: every estimator hyperparameter plus 'key' from
            # the constructor.
            'HyperParameters': {
                'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
                'evaluation_steps': '100',
                'key': 'value',
                'sagemaker_container_log_level': '20',
                'sagemaker_job_name': '"tensorflow-job"',
                'sagemaker_program': '"tf_train.py"',
                'sagemaker_region': '"us-east-1"',
                'sagemaker_submit_directory': '"s3://sagemaker/source"',
                'training_steps': '1000',
            },
            'TrainingJobName': 'tensorflow-job',
        },
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True
    }


@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
    """Verify constructor hyperparameters win on key collisions with the estimator.

    The estimator sets ``training_steps`` to ``'1000'``; supplying the same
    key via the TrainingStep constructor should override it (the constructor
    dict is applied last in the merge), yielding ``'500'`` in the generated
    ``HyperParameters``.

    NOTE: the original function was missing the ``test_`` prefix, so pytest
    never collected it; renamed so it actually runs.
    """
    step = TrainingStep('Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters={
            # set as 1000 in estimator
            'training_steps': '500'
        }
    )

    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AlgorithmSpecification': {
                'TrainingImage': TENSORFLOW_IMAGE,
                'TrainingInputMode': 'File'
            },
            'InputDataConfig': [
                {
                    'DataSource': {
                        'S3DataSource': {
                            'S3DataDistributionType': 'FullyReplicated',
                            'S3DataType': 'S3Prefix',
                            'S3Uri': 's3://sagemaker/train'
                        }
                    },
                    'ChannelName': 'train'
                }
            ],
            'OutputDataConfig': {
                'S3OutputPath': 's3://sagemaker/models'
            },
            'DebugHookConfig': {
                'S3OutputPath': 's3://sagemaker/models/debug'
            },
            'StoppingCondition': {
                'MaxRuntimeInSeconds': 86400
            },
            'ResourceConfig': {
                'InstanceCount': 1,
                'InstanceType': 'ml.p2.xlarge',
                'VolumeSizeInGB': 30
            },
            'RoleArn': EXECUTION_ROLE,
            'HyperParameters': {
                'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
                'evaluation_steps': '100',
                'sagemaker_container_log_level': '20',
                'sagemaker_job_name': '"tensorflow-job"',
                'sagemaker_program': '"tf_train.py"',
                'sagemaker_region': '"us-east-1"',
                'sagemaker_submit_directory': '"s3://sagemaker/source"',
                # Constructor value overrides the estimator's '1000'.
                'training_steps': '500',
            },
            'TrainingJobName': 'tensorflow-job',
        },
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True
    }


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_transform_step_creation(pca_transformer):
Expand Down