From c80b63de15f9ae65637d41a816fda9afe6d1dd6b Mon Sep 17 00:00:00 2001 From: Naresh Kumar Kolloju Date: Tue, 18 Jun 2019 18:32:48 -0700 Subject: [PATCH 1/3] change: [pr-827][followups]Improve documentation of some functions Also some unit test fixes. See comments from mario in https://github.com/aws/sagemaker-python-sdk/pull/827 --- src/sagemaker/session.py | 4 ++-- src/sagemaker/transformer.py | 2 ++ tests/integ/test_transformer.py | 4 +++- tests/unit/test_session.py | 9 ++++++++- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 75b7697668..b06cc9b143 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -514,9 +514,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m input_config (dict): A dictionary describing the input data (and its location) for the job. output_config (dict): A dictionary describing the output location for the job. resource_config (dict): A dictionary describing the resources to complete the job. - tags (list[dict]): List of tags for labeling a transform job. + tags (list[dict]): List of tags for labeling a transform job. For more information, + see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. data_processing(dict): A dictionary describing config for combining the input data and transformed data. - For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. """ transform_request = { 'TransformJobName': job_name, diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 136d1f7f2e..c5f0decdd3 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -99,8 +99,10 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t job_name (str): job name (default: None). If not specified, one will be generated. input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value '$', representing the entire input. + For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html Some examples: "$[1:]", "$.features"(default: None). output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output. + For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html Some examples: "$[1:]", "$.prediction" (default: None). join_source (str): The source of data to be joined to the transform output. It can be set to 'Input' meaning the entire input record will be joined to the inference result. diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 02047def33..8485c95da7 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -55,9 +55,10 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version): kms_key_arn = get_or_create_kms_key(sagemaker_session) output_filter = "$" + input_filter = "$" transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn, - input_filter=None, output_filter=output_filter, + input_filter=input_filter, output_filter=output_filter, join_source=None) with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): @@ -67,6 +68,7 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version): TransformJobName=transformer.latest_transform_job.name) assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId'] assert output_filter == job_desc['DataProcessing']['OutputFilter'] + assert input_filter == job_desc['DataProcessing']['InputFilter'] @pytest.mark.canary_quick diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index cc428692d4..e72d235c7b 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -578,17 +578,24 @@ def test_transform_pack_to_request(sagemaker_session): 'InstanceType': INSTANCE_TYPE, } + data_processing = { + 'OutputFilter': '$', + 'InputFilter': '$', + 'JoinSource': 'Input' + } + expected_args = { 'TransformJobName': JOB_NAME, 'ModelName': model_name, 'TransformInput': in_config, 'TransformOutput': out_config, 'TransformResources': resource_config, + 'DataProcessing': data_processing, } sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None, max_payload=None, env=None, input_config=in_config, output_config=out_config, - resource_config=resource_config, tags=None, data_processing=None) + resource_config=resource_config, tags=None, data_processing=data_processing) _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] assert actual_args == expected_args From 36045cf7d4bdcd32c99cc192a782e76e7290bd56 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Fri, 19 Jul 2019 11:27:33 -0700 Subject: [PATCH 2/3] fix quotes --- tests/unit/test_session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index de2ac6a592..fa48f58e3d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -677,9 +677,9 @@ def test_transform_pack_to_request(sagemaker_session): resource_config = {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE} data_processing = { - 'OutputFilter': '$', - 'InputFilter': '$', - 'JoinSource': 'Input' + "OutputFilter": "$", + "InputFilter": "$", + "JoinSource": "Input", } expected_args = { From e429269347cfa65db4d90a3b1b669a0fc2a2b421 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Fri, 19 Jul 2019 13:48:01 -0700 Subject: [PATCH 3/3] address black formatting --- tests/unit/test_session.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fa48f58e3d..714da580c8 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -676,11 +676,7 @@ def test_transform_pack_to_request(sagemaker_session): resource_config = {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE} - data_processing = { - "OutputFilter": "$", - "InputFilter": "$", - "JoinSource": "Input", - } + data_processing = {"OutputFilter": "$", "InputFilter": "$", "JoinSource": "Input"} expected_args = { "TransformJobName": JOB_NAME,