diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9d958810d8..75b7697668 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -499,7 +499,7 @@ def stop_tuning_job(self, name): raise def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env, - input_config, output_config, resource_config, tags): + input_config, output_config, resource_config, tags, data_processing): """Create an Amazon SageMaker transform job. Args: @@ -514,8 +514,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m input_config (dict): A dictionary describing the input data (and its location) for the job. output_config (dict): A dictionary describing the output location for the job. resource_config (dict): A dictionary describing the resources to complete the job. - tags (list[dict]): List of tags for labeling a training job. For more, see - https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + tags (list[dict]): List of tags for labeling a transform job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + data_processing (dict): A dictionary describing config for combining the input data and transformed data. 
""" transform_request = { 'TransformJobName': job_name, @@ -540,6 +541,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m if tags is not None: transform_request['Tags'] = tags + if data_processing is not None: + transform_request['DataProcessing'] = data_processing + LOGGER.info('Creating transform job with name: {}'.format(job_name)) LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4))) self.sagemaker_client.create_transform_job(**transform_request) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index ea43efe090..136d1f7f2e 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass self.sagemaker_session = sagemaker_session or Session() def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, - job_name=None): + job_name=None, input_filter=None, output_filter=None, join_source=None): """Start a new transform job. Args: @@ -97,6 +97,15 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'. job_name (str): job name (default: None). If not specified, one will be generated. + input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for + inference. If you omit the field, it gets the value '$', representing the entire input. + Some examples: "$[1:]", "$.features" (default: None). + output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output. + Some examples: "$[1:]", "$.prediction" (default: None). + join_source (str): The source of data to be joined to the transform output. 
It can be set to 'Input' + meaning the entire input record will be joined to the inference result. + You can use OutputFilter to select the useful portion before uploading to S3. (default: None). + Valid values: Input, None. """ local_mode = self.sagemaker_session.local_mode if not local_mode and not data.startswith('s3://'): @@ -116,7 +125,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name) self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type, - split_type) + split_type, input_filter, output_filter, join_source) def delete_model(self): """Delete the corresponding SageMaker model for this Transformer. @@ -214,8 +223,10 @@ def _prepare_init_params_from_job_description(cls, job_details): class _TransformJob(_Job): @classmethod - def start_new(cls, transformer, data, data_type, content_type, compression_type, split_type): + def start_new(cls, transformer, data, data_type, content_type, compression_type, + split_type, input_filter, output_filter, join_source): config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer) + data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source) transformer.sagemaker_session.transform(job_name=transformer._current_job_name, model_name=transformer.model_name, strategy=transformer.strategy, @@ -223,7 +234,8 @@ def start_new(cls, transformer, data, data_type, content_type, compression_type, max_payload=transformer.max_payload, env=transformer.env, input_config=config['input_config'], output_config=config['output_config'], - resource_config=config['resource_config'], tags=transformer.tags) + resource_config=config['resource_config'], + tags=transformer.tags, data_processing=data_processing) return cls(transformer.sagemaker_session, transformer._current_job_name) 
@@ -287,3 +299,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key): config['VolumeKmsKeyId'] = volume_kms_key return config + + @staticmethod + def _prepare_data_processing(input_filter, output_filter, join_source): + config = {} + + if input_filter is not None: + config['InputFilter'] = input_filter + + if output_filter is not None: + config['OutputFilter'] = output_filter + + if join_source is not None: + config['JoinSource'] = join_source + + if len(config) == 0: + return None + + return config diff --git a/tests/integ/kms_utils.py b/tests/integ/kms_utils.py index 921ec103b8..f5e5e0c5aa 100644 --- a/tests/integ/kms_utils.py +++ b/tests/integ/kms_utils.py @@ -68,7 +68,7 @@ def _create_kms_key(kms_client, role_arn=role_arn, sagemaker_role=sagemaker_role) else: - principal = "{account_id}".format(account_id=account_id) + principal = '"{account_id}"'.format(account_id=account_id) response = kms_client.create_key( Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role), diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index d47e0f7373..02047def33 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -54,8 +54,11 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version): key_prefix=transform_input_key_prefix) kms_key_arn = get_or_create_kms_key(sagemaker_session) + output_filter = "$" - transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn) + transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn, + input_filter=None, output_filter=output_filter, + join_source=None) with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): transformer.wait() @@ -63,6 +66,7 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version): job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job( 
TransformJobName=transformer.latest_transform_job.name) assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId'] + assert output_filter == job_desc['DataProcessing']['OutputFilter'] @pytest.mark.canary_quick @@ -232,7 +236,9 @@ def test_transform_byo_estimator(sagemaker_session): assert tags == model_tags -def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None): +def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None, + input_filter=None, output_filter=None, join_source=None): transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) - transformer.transform(transform_input, content_type='text/csv') + transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter, + output_filter=output_filter, join_source=join_source) return transformer diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 48250beda5..cc428692d4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -588,7 +588,7 @@ def test_transform_pack_to_request(sagemaker_session): sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None, max_payload=None, env=None, input_config=in_config, output_config=out_config, - resource_config=resource_config, tags=None) + resource_config=resource_config, tags=None, data_processing=None) _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] assert actual_args == expected_args @@ -603,7 +603,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): sagemaker_session.transform(job_name=JOB_NAME, model_name='my-model', strategy=strategy, max_concurrent_transforms=max_concurrent_transforms, env=env, max_payload=max_payload, input_config={}, output_config={}, - resource_config={}, tags=TAGS) + resource_config={}, tags=TAGS, data_processing=None) _, _, actual_args = 
sagemaker_session.sagemaker_client.method_calls[0] assert actual_args['BatchStrategy'] == strategy diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 856ff9a07a..d549a6ad37 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -98,13 +98,18 @@ def test_transform_with_all_params(start_new_job, transformer): content_type = 'text/csv' compression = 'Gzip' split = 'Line' + input_filter = "$.feature" + output_filter = "$['sagemaker_output', 'id']" + join_source = "Input" transformer.transform(DATA, S3_DATA_TYPE, content_type=content_type, compression_type=compression, split_type=split, - job_name=JOB_NAME) + job_name=JOB_NAME, input_filter=input_filter, output_filter=output_filter, + join_source=join_source) assert transformer._current_job_name == JOB_NAME assert transformer.output_path == OUTPUT_PATH - start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, split) + start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, + split, input_filter, output_filter, join_source) @patch('sagemaker.transformer.name_from_base') @@ -300,7 +305,8 @@ def test_start_new(transformer, sagemaker_session): transformer._current_job_name = JOB_NAME job = _TransformJob(sagemaker_session, JOB_NAME) - started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None) + started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None, + None, None, None) assert started_job.sagemaker_session == sagemaker_session sagemaker_session.transform.assert_called_once() @@ -392,6 +398,23 @@ def test_prepare_resource_config(): assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID} +def test_data_processing_config(): + actual_config = _TransformJob._prepare_data_processing("$", None, None) + assert actual_config == {'InputFilter': "$"} + + actual_config = 
_TransformJob._prepare_data_processing(None, "$", None) + assert actual_config == {'OutputFilter': "$"} + + actual_config = _TransformJob._prepare_data_processing(None, None, "Input") + assert actual_config == {'JoinSource': "Input"} + + actual_config = _TransformJob._prepare_data_processing("$[0]", "$[1]", "Input") + assert actual_config == {'InputFilter': "$[0]", 'OutputFilter': "$[1]", 'JoinSource': "Input"} + + actual_config = _TransformJob._prepare_data_processing(None, None, None) + assert actual_config is None + + def test_transform_job_wait(sagemaker_session): job = _TransformJob(sagemaker_session, JOB_NAME) job.wait()