From 2df563b9a4b98d697f79941363c0031f6769a20a Mon Sep 17 00:00:00 2001 From: "Nickolas J. Wilson" Date: Sat, 11 May 2019 09:11:56 -0500 Subject: [PATCH 1/3] feature: add encryption option to "record_set" --- src/sagemaker/amazon/amazon_estimator.py | 21 ++++++++++------ tests/unit/test_amazon_estimator.py | 32 ++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index a1b7511ffd..d43902b54d 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -159,7 +159,7 @@ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None if wait: self.latest_training_job.wait(logs=logs) - def record_set(self, train, labels=None, channel="train"): + def record_set(self, train, labels=None, channel="train", encrypt=False): """Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector. For the 2D ``ndarray`` ``train``, each row is converted to a :class:`~Record` object. @@ -177,8 +177,10 @@ def record_set(self, train, labels=None, channel="train"): Args: train (numpy.ndarray): A 2D numpy array of training data. labels (numpy.ndarray): A 1D numpy array of labels. Its length must be equal to the - number of rows in ``train``. + number of rows in ``train``. channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to. + encrypt (bool): Specifies whether the objects uploaded to S3 are encrypted on the + server side using AES-256 (default: ``False``). Returns: RecordSet: A RecordSet referencing the encoded, uploading training and label data. """ @@ -188,7 +190,8 @@ def record_set(self, train, labels=None, channel="train"): key_prefix = key_prefix + '{}-{}/'.format(type(self).__name__, sagemaker_timestamp()) key_prefix = key_prefix.lstrip('/') logger.debug('Uploading to bucket {} and key_prefix {}'.format(bucket, key_prefix)) - manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket, key_prefix, train, labels) + manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket, + key_prefix, train, labels, encrypt) logger.debug("Created manifest file {}".format(manifest_s3_file)) return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel) @@ -239,15 +242,17 @@ def _build_shards(num_shards, array): return shards -def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None): - """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` s3 objects, - stored in "s3://``bucket``/``key_prefix``/".""" +def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False): + """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` S3 objects, + stored in "s3://``bucket``/``key_prefix``/". Optionally ``encrypt`` the S3 objects using + AES-256.""" shards = _build_shards(num_shards, array) if labels is not None: label_shards = _build_shards(num_shards, labels) uploaded_files = [] if key_prefix[-1] != '/': key_prefix = key_prefix + '/' + extra_put_kwargs = {'ServerSideEncryption': 'AES256'} if encrypt else {} try: for shard_index, shard in enumerate(shards): with tempfile.TemporaryFile() as file: @@ -260,12 +265,12 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= file_name = "matrix_{}.pbr".format(shard_index_string) key = key_prefix + file_name logger.debug("Creating object {} in bucket {}".format(key, bucket)) - s3.Object(bucket, key).put(Body=file) + s3.Object(bucket, key).put(Body=file, **extra_put_kwargs) uploaded_files.append(file_name) manifest_key = key_prefix + ".amazon.manifest" manifest_str = json.dumps( [{'prefix': 's3://{}/{}'.format(bucket, key_prefix)}] + uploaded_files) - s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8')) + s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8'), **extra_put_kwargs) return "s3://{}/{}".format(bucket, manifest_key) except Exception as ex: # pylint: disable=broad-except try: diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index f7329b61ce..168d90d3e7 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from mock import Mock, patch, call +from mock import ANY, Mock, patch, call # Use PCA as a test implementation of AmazonAlgorithmEstimator from sagemaker.amazon.pca import PCA @@ -143,6 +143,22 @@ def test_prepare_for_training_list_no_train_channel(sagemaker_session): assert 'Must provide train channel.' in str(ex) +def test_prepare_for_training_encrypt(sagemaker_session): + pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS) + + train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]] + labels = [99, 85, 87, 2] + with patch('sagemaker.amazon.amazon_estimator.upload_numpy_to_s3_shards', + return_value='manfiest_file') as mock_upload: + pca.record_set(np.array(train), np.array(labels)) + pca.record_set(np.array(train), np.array(labels), encrypt=True) + + def make_upload_call(encrypt): + return call(ANY, ANY, ANY, ANY, ANY, ANY, encrypt) + + mock_upload.assert_has_calls([make_upload_call(False), make_upload_call(True)]) + + @patch('time.strftime', return_value=TIMESTAMP) def test_fit_ndarray(time, sagemaker_session): mock_s3 = Mock() @@ -185,9 +201,21 @@ def test_upload_numpy_to_s3_shards(): mock_s3 = Mock() mock_object = Mock() mock_s3.Object = Mock(return_value=mock_object) + mock_put = mock_s3.Object.return_value.put array = np.array([[j for j in range(10)] for i in range(10)]) labels = np.array([i for i in range(10)]) - upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels) + num_shards = 3 + num_objects = num_shards + 1 # Account for the manifest file. + + def make_all_put_calls(**kwargs): + return [call(Body=ANY, **kwargs) for i in range(num_objects)] + + upload_numpy_to_s3_shards(num_shards, mock_s3, BUCKET_NAME, "key-prefix", array, labels) mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_0.pbr')]) mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_1.pbr')]) mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_2.pbr')]) + mock_put.assert_has_calls(make_all_put_calls()) + + mock_put.reset() + upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels, encrypt=True) + mock_put.assert_has_calls(make_all_put_calls(ServerSideEncryption='AES256')) From 26f3bc25c5024236eb3c165c5ebeb70aa68632e2 Mon Sep 17 00:00:00 2001 From: "Nickolas J. Wilson" Date: Fri, 17 May 2019 15:52:30 -0500 Subject: [PATCH 2/3] change: add integration test for "record_set" --- tests/integ/test_record_set.py | 42 ++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/integ/test_record_set.py diff --git a/tests/integ/test_record_set.py b/tests/integ/test_record_set.py new file mode 100644 index 0000000000..a27e34d4fe --- /dev/null +++ b/tests/integ/test_record_set.py @@ -0,0 +1,42 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import gzip +import os +import pickle +import sys + +from six.moves.urllib.parse import urlparse + +from sagemaker import KMeans +from tests.integ import DATA_DIR + + +def test_record_set(sagemaker_session): + """Test the method ``AmazonAlgorithmEstimatorBase.record_set``. + + In particular, test that the objects uploaded to the S3 bucket are encrypted. + """ + data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') + pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + with gzip.open(data_path, 'rb') as f: + train_set, _, _ = pickle.load(f, **pickle_args) + kmeans = KMeans(role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.c4.xlarge', + k=10, sagemaker_session=sagemaker_session) + record_set = kmeans.record_set(train_set[0][:100], encrypt=True) + parsed_url = urlparse(record_set.s3_data) + s3_client = sagemaker_session.boto_session.client('s3') + head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/')) + assert head['ServerSideEncryption'] == 'AES256' From a84b1573f31bfb8ffb6f542f0d29e547ebb61813 Mon Sep 17 00:00:00 2001 From: "Nickolas J. Wilson" Date: Fri, 17 May 2019 19:00:02 -0500 Subject: [PATCH 3/3] change: lengthen a variable name --- tests/integ/test_record_set.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_record_set.py b/tests/integ/test_record_set.py index a27e34d4fe..96e2b84aa1 100644 --- a/tests/integ/test_record_set.py +++ b/tests/integ/test_record_set.py @@ -30,8 +30,8 @@ def test_record_set(sagemaker_session): """ data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} - with gzip.open(data_path, 'rb') as f: - train_set, _, _ = pickle.load(f, **pickle_args) + with gzip.open(data_path, 'rb') as file_object: + train_set, _, _ = pickle.load(file_object, **pickle_args) kmeans = KMeans(role='SageMakerRole', train_instance_count=1, train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session)