8 changes: 6 additions & 2 deletions src/sagemaker/session.py
@@ -495,7 +495,7 @@ def stop_tuning_job(self, name):
raise

def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env,
input_config, output_config, resource_config, tags):
input_config, output_config, resource_config, data_processing, tags):
Contributor: usually we try and put the new argument at the end of the list so that it's not a breaking change if someone had been calling the function without labeling the args before

Collaborator (Author): I moved it to the end. Also fixed some UTs.
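A toy sketch (hypothetical functions, not from this PR) of the breakage the reviewer is pointing at: inserting a parameter before `tags` rebinds positional arguments at old call sites, while appending it at the end does not.

```python
# Toy stand-ins for the old and new signatures (not SageMaker code).
def transform_appended(resource_config, tags, data_processing=None):
    return {'ResourceConfig': resource_config, 'Tags': tags,
            'DataProcessing': data_processing}

def transform_inserted(resource_config, data_processing, tags):
    return {'ResourceConfig': resource_config, 'Tags': tags,
            'DataProcessing': data_processing}

# A caller written against the old (resource_config, tags) signature:
old_args = ({'InstanceCount': 1}, [{'Key': 'team', 'Value': 'ml'}])

transform_appended(*old_args)  # OK: the tags list still binds to tags
transform_inserted(*old_args)  # TypeError: 'tags' is missing -- the tags list
                               # silently lands on data_processing instead
```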

"""Create an Amazon SageMaker transform job.

Args:
@@ -510,7 +510,8 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
input_config (dict): A dictionary describing the input data (and its location) for the job.
output_config (dict): A dictionary describing the output location for the job.
resource_config (dict): A dictionary describing the resources to complete the job.
tags (list[dict]): List of tags for labeling a training job. For more, see
data_processing (dict): A dictionary describing the configuration for combining the input data with the transformed data.
tags (list[dict]): List of tags for labeling a transform job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
"""
transform_request = {
@@ -536,6 +537,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
if tags is not None:
transform_request['Tags'] = tags

if data_processing is not None:
transform_request['DataProcessing'] = data_processing

LOGGER.info('Creating transform job with name: {}'.format(job_name))
LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4)))
self.sagemaker_client.create_transform_job(**transform_request)
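For reference, a sketch of the `DataProcessing` fragment this builds. The field names (`InputFilter`, `OutputFilter`, `JoinSource`) come from the CreateTransformJob API; the values here are purely illustrative:

```python
# Illustrative DataProcessing block as it would appear in the request
# sent to create_transform_job (example values, not defaults).
data_processing = {
    'InputFilter': '$.features',  # JSONPath applied to each input record
    'OutputFilter': '$',          # JSONPath applied to the (joined) output
    'JoinSource': 'Input',        # join each inference result to its input record
}
```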
34 changes: 30 additions & 4 deletions src/sagemaker/transformer.py
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
self.sagemaker_session = sagemaker_session or Session()

def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
job_name=None):
job_name=None, input_filter=None, output_filter=None, join_source=None):
"""Start a new transform job.

Args:
@@ -97,6 +97,11 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
split_type (str): The record delimiter for the input object (default: 'None').
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
job_name (str): job name (default: None). If not specified, one will be generated.
input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for
inference. If omitted, it defaults to '$', which passes the entire input (default: None).
output_filter (str): A JSONPath to select a portion of the joined/original output to save as the output (default: None).
Contributor: maybe add an example for the two filters?

Collaborator (Author): Added (see the sketch after this hunk).

join_source (str): The source of data to be joined to the transform output. Set it to 'Input' to join the entire input record to the inference result; you can then use output_filter to select the useful portion before uploading to S3 (default: None).
Valid values: 'Input', None.
"""
local_mode = self.sagemaker_session.local_mode
if not local_mode and not data.startswith('s3://'):
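A sketch of how the two filters and `join_source` interact on a single record (bucket name, record shape, and filter values are illustrative; `transformer` setup elided):

```python
# Suppose each input line is the JSON record:
#   {"id": 42, "features": [1.0, 2.0, 3.0]}
#
# input_filter='$.features'   -> only the features array is sent to the model
# join_source='Input'         -> the inference result is joined back onto the
#                                full input record under 'SageMakerOutput'
# output_filter="$['id','SageMakerOutput']"
#                             -> only the id and the prediction are saved to S3
transformer.transform('s3://my-bucket/transform-input',
                      content_type='application/json',
                      split_type='Line',
                      input_filter='$.features',
                      output_filter="$['id','SageMakerOutput']",
                      join_source='Input')
```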
@@ -116,7 +121,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name)

self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
split_type)
split_type, input_filter, output_filter, join_source)

def delete_model(self):
"""Delete the corresponding SageMaker model for this Transformer.
@@ -214,16 +219,19 @@ def _prepare_init_params_from_job_description(cls, job_details):

class _TransformJob(_Job):
@classmethod
def start_new(cls, transformer, data, data_type, content_type, compression_type, split_type):
def start_new(cls, transformer, data, data_type, content_type, compression_type,
split_type, input_filter, output_filter, join_source):
config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer)
data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source)

transformer.sagemaker_session.transform(job_name=transformer._current_job_name,
model_name=transformer.model_name, strategy=transformer.strategy,
max_concurrent_transforms=transformer.max_concurrent_transforms,
max_payload=transformer.max_payload, env=transformer.env,
input_config=config['input_config'],
output_config=config['output_config'],
resource_config=config['resource_config'], tags=transformer.tags)
resource_config=config['resource_config'],
data_processing=data_processing, tags=transformer.tags)

return cls(transformer.sagemaker_session, transformer._current_job_name)

@@ -287,3 +295,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
config['VolumeKmsKeyId'] = volume_kms_key

return config

@staticmethod
def _prepare_data_processing(input_filter, output_filter, join_source):
config = {}

if input_filter is not None:
config['InputFilter'] = input_filter

if output_filter is not None:
config['OutputFilter'] = output_filter

if join_source is not None:
config['JoinSource'] = join_source

if len(config) == 0:
return None

return config
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/integ/kms_utils.py
@@ -68,7 +68,7 @@ def _create_kms_key(kms_client,
role_arn=role_arn,
sagemaker_role=sagemaker_role)
else:
principal = "{account_id}".format(account_id=account_id)
principal = '"{account_id}"'.format(account_id=account_id)
Contributor: Any special reason to include this change on this PR?
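The thread records no answer, but a plausible reading (an assumption, not confirmed here) is that the principal is spliced into a JSON policy template, so the account id needs surrounding quotes to land as a JSON string rather than a bare number. A toy illustration with a hypothetical template:

```python
import json

# Hypothetical template in the spirit of KEY_POLICY: the principal value is
# interpolated directly into a JSON policy document.
TEMPLATE = '{{"Statement": [{{"Principal": {{"AWS": {principal}}}}}]}}'

print(json.loads(TEMPLATE.format(principal='"123456789012"')))
# {'Statement': [{'Principal': {'AWS': '123456789012'}}]}  <- string, as policies expect

print(json.loads(TEMPLATE.format(principal='123456789012')))
# {'Statement': [{'Principal': {'AWS': 123456789012}}]}    <- bare number
```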


response = kms_client.create_key(
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role),
4 changes: 2 additions & 2 deletions tests/integ/test_horovod.py
@@ -41,7 +41,7 @@ def instance_type(request):
@pytest.mark.canary_quick
def test_horovod(sagemaker_session, instance_type, tmpdir):
job_name = sagemaker.utils.unique_name_from_base('tf-horovod')
estimator = TensorFlow(entry_point=os.path.join(horovod_dir, 'test_hvd_basic.py'),
estimator = TensorFlow(entry_point=os.path.join(horovod_dir, 'hvd_basic.py'),
role='SageMakerRole',
train_instance_count=2,
train_instance_type=instance_type,
@@ -69,7 +69,7 @@ def test_horovod(sagemaker_session, instance_type, tmpdir):
def test_horovod_local_mode(sagemaker_local_session, instances, processes, tmpdir):
output_path = 'file://%s' % tmpdir
job_name = sagemaker.utils.unique_name_from_base('tf-horovod')
estimator = TensorFlow(entry_point=os.path.join(horovod_dir, 'test_hvd_basic.py'),
estimator = TensorFlow(entry_point=os.path.join(horovod_dir, 'hvd_basic.py'),
role='SageMakerRole',
train_instance_count=2,
train_instance_type='local',
9 changes: 6 additions & 3 deletions tests/integ/test_transformer.py
@@ -54,7 +54,8 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):

kms_key_arn = get_or_create_kms_key(sagemaker_session)

transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn,
input_filter=None, output_filter="$", join_source=None)
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
transformer.wait()
@@ -148,7 +149,9 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']


def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
input_filter=None, output_filter=None, join_source=None):
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
transformer.transform(transform_input, content_type='text/csv')
transformer.transform(transform_input, content_type='text/csv',
input_filter=input_filter, output_filter=output_filter, join_source=join_source)
return transformer
26 changes: 24 additions & 2 deletions tests/unit/test_transformer.py
@@ -98,13 +98,18 @@ def test_transform_with_all_params(start_new_job, transformer):
content_type = 'text/csv'
compression = 'Gzip'
split = 'Line'
input_filter = "$.feature"
output_filter = "$['sagemaker_output', 'id']"
join_source = "Input"

transformer.transform(DATA, S3_DATA_TYPE, content_type=content_type, compression_type=compression, split_type=split,
job_name=JOB_NAME)
job_name=JOB_NAME, input_filter=input_filter, output_filter=output_filter,
join_source=join_source)

assert transformer._current_job_name == JOB_NAME
assert transformer.output_path == OUTPUT_PATH
start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, split)
start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression,
split, input_filter, output_filter, join_source)


@patch('sagemaker.transformer.name_from_base')
@@ -392,6 +397,23 @@ def test_prepare_resource_config():
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}


def test_data_processing_config():
actual_config = _TransformJob._prepare_data_processing("$", None, None)
assert actual_config == {'InputFilter': "$"}

actual_config = _TransformJob._prepare_data_processing(None, "$", None)
assert actual_config == {'OutputFilter': "$"}

actual_config = _TransformJob._prepare_data_processing(None, None, "Input")
assert actual_config == {'JoinSource': "Input"}

actual_config = _TransformJob._prepare_data_processing("$[0]", "$[1]", "Input")
assert actual_config == {'InputFilter': "$[0]", 'OutputFilter': "$[1]", 'JoinSource': "Input"}

actual_config = _TransformJob._prepare_data_processing(None, None, None)
assert actual_config is None


def test_transform_job_wait(sagemaker_session):
job = _TransformJob(sagemaker_session, JOB_NAME)
job.wait()