feature: Add DataProcessing Fields for Batch Transform #827
Changes from 3 commits
```diff
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
         self.sagemaker_session = sagemaker_session or Session()
 
     def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
-                  job_name=None):
+                  job_name=None, input_filter=None, output_filter=None, join_source=None):
         """Start a new transform job.
 
         Args:
@@ -97,6 +97,11 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
             split_type (str): The record delimiter for the input object (default: 'None').
                 Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
             job_name (str): job name (default: None). If not specified, one will be generated.
+            input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for
+                inference. If you omit the field, it gets the value '$', representing the entire input. (default: None).
+            output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output. (default: None).
+            join_source (str): The source of data to be joined to the transform output. It can be set to 'Input' meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
+                Valid values: Input, None.
         """
         local_mode = self.sagemaker_session.local_mode
         if not local_mode and not data.startswith('s3://'):
@@ -116,7 +121,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
             self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name)
 
         self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
-                                                            split_type)
+                                                            split_type, input_filter, output_filter, join_source)
 
     def delete_model(self):
         """Delete the corresponding SageMaker model for this Transformer.
@@ -214,16 +219,19 @@ def _prepare_init_params_from_job_description(cls, job_details):
 class _TransformJob(_Job):
     @classmethod
-    def start_new(cls, transformer, data, data_type, content_type, compression_type, split_type):
+    def start_new(cls, transformer, data, data_type, content_type, compression_type,
+                  split_type, input_filter, output_filter, join_source):
         config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer)
+        data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source)
 
         transformer.sagemaker_session.transform(job_name=transformer._current_job_name,
                                                 model_name=transformer.model_name, strategy=transformer.strategy,
                                                 max_concurrent_transforms=transformer.max_concurrent_transforms,
                                                 max_payload=transformer.max_payload, env=transformer.env,
                                                 input_config=config['input_config'],
                                                 output_config=config['output_config'],
-                                                resource_config=config['resource_config'], tags=transformer.tags)
+                                                resource_config=config['resource_config'],
+                                                data_processing=data_processing, tags=transformer.tags)
 
         return cls(transformer.sagemaker_session, transformer._current_job_name)
```
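For context, a minimal usage sketch of the new `transform()` parameters. The model name, S3 prefix, and JSONPath expressions below are illustrative placeholders, not values taken from this PR:

```python
from sagemaker.transformer import Transformer

# Hypothetical model and data locations; only the three new keyword
# arguments (input_filter, output_filter, join_source) come from this PR.
transformer = Transformer(model_name='my-model',
                          instance_count=1,
                          instance_type='ml.c4.xlarge')

transformer.transform('s3://my-bucket/batch-input',
                      content_type='application/json',
                      split_type='Line',
                      # Send only the "features" attribute of each record to the container.
                      input_filter='$.features',
                      # Join each inference result back onto its original input record ...
                      join_source='Input',
                      # ... and keep only the selected attributes in the output written to S3.
                      output_filter="$['id','SageMakerOutput']")
```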
```diff
@@ -287,3 +295,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
             config['VolumeKmsKeyId'] = volume_kms_key
 
         return config
+
+    @staticmethod
+    def _prepare_data_processing(input_filter, output_filter, join_source):
+        config = {}
+
+        if input_filter is not None:
+            config['InputFilter'] = input_filter
+
+        if output_filter is not None:
+            config['OutputFilter'] = output_filter
+
+        if join_source is not None:
+            config['JoinSource'] = join_source
+
+        if len(config) == 0:
+            return None
+
+        return config
```
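A quick sketch of what the new `_prepare_data_processing` helper returns for a call like the one above, assuming the `_TransformJob` class lives in `sagemaker.transformer` as in the current SDK. The keys mirror the CreateTransformJob `DataProcessing` structure:

```python
from sagemaker.transformer import _TransformJob

# With all three arguments set, the helper builds the DataProcessing dict
# that start_new() now forwards to sagemaker_session.transform().
assert _TransformJob._prepare_data_processing(
    '$.features', "$['id','SageMakerOutput']", 'Input') == {
        'InputFilter': '$.features',
        'OutputFilter': "$['id','SageMakerOutput']",
        'JoinSource': 'Input',
}

# With nothing set, it returns None so the request omits DataProcessing entirely.
assert _TransformJob._prepare_data_processing(None, None, None) is None
```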
```diff
@@ -68,7 +68,7 @@ def _create_kms_key(kms_client,
                                      role_arn=role_arn,
                                      sagemaker_role=sagemaker_role)
     else:
-        principal = "{account_id}".format(account_id=account_id)
+        principal = '"{account_id}"'.format(account_id=account_id)
```
Contributor: Any special reason to include this change on this PR?
```diff
 
     response = kms_client.create_key(
         Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role),
```
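On the quoting change above, a small sketch of the difference, under the assumption that `KEY_POLICY` is a JSON policy template whose principal placeholder is not itself quoted (the template is not part of this diff):

```python
account_id = '123456789012'

# Before: the substituted token carries no quotes, so it does not read as a
# JSON string once it is formatted into the policy document.
principal_old = '{account_id}'.format(account_id=account_id)     # 123456789012

# After: the token carries its own double quotes, so it reads as a JSON
# string element in the rendered policy.
principal_new = '"{account_id}"'.format(account_id=account_id)   # "123456789012"
```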
Review comment: Usually we try to put the new argument at the end of the list so that it's not a breaking change if someone had been calling the function without labeling the args before.
Reply: I moved it to the end. Also fixed some unit tests.
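A generic illustration of the compatibility point raised above (function and parameter names here are simplified placeholders, not SDK code):

```python
# An existing caller that passes arguments positionally:
#   transform_v1('s3://bucket/prefix', 'ManifestFile', 'text/csv')
def transform_v1(data, data_type='S3Prefix', content_type=None):
    ...

# Appending the new keyword argument keeps that call working unchanged.
def transform_v2(data, data_type='S3Prefix', content_type=None, input_filter=None):
    ...

# Inserting it earlier silently rebinds old positional arguments:
# 'text/csv' would now land in input_filter instead of content_type.
def transform_broken(data, data_type='S3Prefix', input_filter=None, content_type=None):
    ...
```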