diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 2ef34071fe..f5ff344edc 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -115,7 +115,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
     def boto_region_name(self):
         return self._region_name
 
-    def upload_data(self, path, bucket=None, key_prefix='data'):
+    def upload_data(self, path, bucket=None, key_prefix='data', extra_args=None):
         """Upload local file or directory to S3.
 
         If a single file is specified for upload, the resulting S3 object key is ``{key_prefix}/{filename}``
@@ -132,6 +132,10 @@ def upload_data(self, path, bucket=None, key_prefix='data'):
                 creates it).
             key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the prefix to
                 create a directory structure for the bucket content that it display in the S3 console.
+            extra_args (dict): Optional extra arguments that may be passed to the upload operation. Similar to
+                ExtraArgs parameter in S3 upload_file function. Please refer to the ExtraArgs parameter
+                documentation here:
+                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
 
         Returns:
             str: The S3 URI of the uploaded file(s). If a file is specified in the path argument, the URI format is:
@@ -158,7 +162,7 @@ def upload_data(self, path, bucket=None, key_prefix='data'):
         s3 = self.boto_session.resource('s3')
 
         for local_path, s3_key in files:
-            s3.Object(bucket, s3_key).upload_file(local_path)
+            s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)
 
         s3_uri = 's3://{}/{}'.format(bucket, key_prefix)
         # If a specific file was used as input (instead of a directory), we return the full S3 key
diff --git a/tests/integ/test_data_upload.py b/tests/integ/test_data_upload.py
new file mode 100755
index 0000000000..c009f137e9
--- /dev/null
+++ b/tests/integ/test_data_upload.py
@@ -0,0 +1,45 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import os
+
+from six.moves.urllib.parse import urlparse
+
+from tests.integ import DATA_DIR
+
+AES_ENCRYPTION_ENABLED = {'ServerSideEncryption': 'AES256'}
+
+
+def test_upload_data_absolute_file(sagemaker_session):
+    """Test the method ``Session.upload_data`` can upload one encrypted file to S3 bucket"""
+    data_path = os.path.join(DATA_DIR, 'upload_data_tests', 'file1.py')
+    uploaded_file = sagemaker_session.upload_data(data_path, extra_args=AES_ENCRYPTION_ENABLED)
+    parsed_url = urlparse(uploaded_file)
+    s3_client = sagemaker_session.boto_session.client('s3')
+    head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/'))
+    assert head['ServerSideEncryption'] == 'AES256'
+
+
+def test_upload_data_absolute_dir(sagemaker_session):
+    """Test the method ``Session.upload_data`` can upload encrypted objects to S3 bucket"""
+    data_path = os.path.join(DATA_DIR, 'upload_data_tests', 'nested_dir')
+    uploaded_dir = sagemaker_session.upload_data(data_path, extra_args=AES_ENCRYPTION_ENABLED)
+    parsed_url = urlparse(uploaded_dir)
+    s3_bucket = parsed_url.netloc
+    s3_prefix = parsed_url.path.lstrip('/')
+    s3_client = sagemaker_session.boto_session.client('s3')
+    for file in os.listdir(data_path):
+        s3_key = '{}/{}'.format(s3_prefix, file)
+        head = s3_client.head_object(Bucket=s3_bucket, Key=s3_key)
+        assert head['ServerSideEncryption'] == 'AES256'
diff --git a/tests/unit/test_upload_data.py b/tests/unit/test_upload_data.py
index bb7ccad858..9a4d9e5aeb 100644
--- a/tests/unit/test_upload_data.py
+++ b/tests/unit/test_upload_data.py
@@ -24,6 +24,7 @@
 SINGLE_FILE_NAME = 'file1.py'
 UPLOAD_DATA_TESTS_SINGLE_FILE = os.path.join(UPLOAD_DATA_TESTS_FILES_DIR, SINGLE_FILE_NAME)
 BUCKET_NAME = 'mybucket'
+AES_ENCRYPTION_ENABLED = {'ServerSideEncryption': 'AES256'}
 
 
 @pytest.fixture()
@@ -37,19 +38,46 @@ def sagemaker_session():
 def test_upload_data_absolute_dir(sagemaker_session):
     result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR)
 
-    uploaded_files = [args[0] for name, args, kwargs in sagemaker_session.boto_session.mock_calls
-                      if name == 'resource().Object().upload_file']
+    uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
+                                if name == 'resource().Object().upload_file']
     assert result_s3_uri == 's3://{}/data'.format(BUCKET_NAME)
-    assert len(uploaded_files) == 4
-    for file in uploaded_files:
+    assert len(uploaded_files_with_args) == 4
+    for file, kwargs in uploaded_files_with_args:
         assert os.path.exists(file)
+        assert kwargs['ExtraArgs'] is None
 
 
 def test_upload_data_absolute_file(sagemaker_session):
     result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE)
 
-    uploaded_files = [args[0] for name, args, kwargs in sagemaker_session.boto_session.mock_calls
-                      if name == 'resource().Object().upload_file']
+    uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
+                                if name == 'resource().Object().upload_file']
     assert result_s3_uri == 's3://{}/data/{}'.format(BUCKET_NAME, SINGLE_FILE_NAME)
-    assert len(uploaded_files) == 1
-    assert os.path.exists(uploaded_files[0])
+    assert len(uploaded_files_with_args) == 1
+    (file, kwargs) = uploaded_files_with_args[0]
+    assert os.path.exists(file)
+    assert kwargs['ExtraArgs'] is None
+
+
+def test_upload_data_aes_encrypted_absolute_dir(sagemaker_session):
+    result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR, extra_args=AES_ENCRYPTION_ENABLED)
+
+    uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
+                                if name == 'resource().Object().upload_file']
+    assert result_s3_uri == 's3://{}/data'.format(BUCKET_NAME)
+    assert len(uploaded_files_with_args) == 4
+    for file, kwargs in uploaded_files_with_args:
+        assert os.path.exists(file)
+        assert kwargs['ExtraArgs'] == AES_ENCRYPTION_ENABLED
+
+
+def test_upload_data_aes_encrypted_absolute_file(sagemaker_session):
+    result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE, extra_args=AES_ENCRYPTION_ENABLED)
+
+    uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
+                                if name == 'resource().Object().upload_file']
+    assert result_s3_uri == 's3://{}/data/{}'.format(BUCKET_NAME, SINGLE_FILE_NAME)
+    assert len(uploaded_files_with_args) == 1
+    (file, kwargs) = uploaded_files_with_args[0]
+    assert os.path.exists(file)
+    assert kwargs['ExtraArgs'] == AES_ENCRYPTION_ENABLED