Skip to content
Merged
21 changes: 13 additions & 8 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None
if wait:
self.latest_training_job.wait(logs=logs)

def record_set(self, train, labels=None, channel="train"):
def record_set(self, train, labels=None, channel="train", encrypt=False):
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.

For the 2D ``ndarray`` ``train``, each row is converted to a :class:`~Record` object.
Expand All @@ -177,8 +177,10 @@ def record_set(self, train, labels=None, channel="train"):
Args:
train (numpy.ndarray): A 2D numpy array of training data.
labels (numpy.ndarray): A 1D numpy array of labels. Its length must be equal to the
number of rows in ``train``.
number of rows in ``train``.
channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to.
encrypt (bool): Specifies whether the objects uploaded to S3 are encrypted on the
server side using AES-256 (default: ``False``).
Returns:
RecordSet: A RecordSet referencing the encoded, uploading training and label data.
"""
Expand All @@ -188,7 +190,8 @@ def record_set(self, train, labels=None, channel="train"):
key_prefix = key_prefix + '{}-{}/'.format(type(self).__name__, sagemaker_timestamp())
key_prefix = key_prefix.lstrip('/')
logger.debug('Uploading to bucket {} and key_prefix {}'.format(bucket, key_prefix))
manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket, key_prefix, train, labels)
manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket,
key_prefix, train, labels, encrypt)
logger.debug("Created manifest file {}".format(manifest_s3_file))
return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel)

Expand Down Expand Up @@ -239,15 +242,17 @@ def _build_shards(num_shards, array):
return shards


def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None):
"""Upload the training ``array`` and ``labels`` arrays to ``num_shards`` s3 objects,
stored in "s3://``bucket``/``key_prefix``/"."""
def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False):
"""Upload the training ``array`` and ``labels`` arrays to ``num_shards`` S3 objects,
stored in "s3://``bucket``/``key_prefix``/". Optionally ``encrypt`` the S3 objects using
AES-256."""
shards = _build_shards(num_shards, array)
if labels is not None:
label_shards = _build_shards(num_shards, labels)
uploaded_files = []
if key_prefix[-1] != '/':
key_prefix = key_prefix + '/'
extra_put_kwargs = {'ServerSideEncryption': 'AES256'} if encrypt else {}
try:
for shard_index, shard in enumerate(shards):
with tempfile.TemporaryFile() as file:
Expand All @@ -260,12 +265,12 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
file_name = "matrix_{}.pbr".format(shard_index_string)
key = key_prefix + file_name
logger.debug("Creating object {} in bucket {}".format(key, bucket))
s3.Object(bucket, key).put(Body=file)
s3.Object(bucket, key).put(Body=file, **extra_put_kwargs)
uploaded_files.append(file_name)
manifest_key = key_prefix + ".amazon.manifest"
manifest_str = json.dumps(
[{'prefix': 's3://{}/{}'.format(bucket, key_prefix)}] + uploaded_files)
s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8'))
s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8'), **extra_put_kwargs)
return "s3://{}/{}".format(bucket, manifest_key)
except Exception as ex: # pylint: disable=broad-except
try:
Expand Down
42 changes: 42 additions & 0 deletions tests/integ/test_record_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import gzip
import os
import pickle
import sys

from six.moves.urllib.parse import urlparse

from sagemaker import KMeans
from tests.integ import DATA_DIR


def test_record_set(sagemaker_session):
"""Test the method ``AmazonAlgorithmEstimatorBase.record_set``.

In particular, test that the objects uploaded to the S3 bucket are encrypted.
"""
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
with gzip.open(data_path, 'rb') as file_object:
train_set, _, _ = pickle.load(file_object, **pickle_args)
kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
train_instance_type='ml.c4.xlarge',
k=10, sagemaker_session=sagemaker_session)
record_set = kmeans.record_set(train_set[0][:100], encrypt=True)
parsed_url = urlparse(record_set.s3_data)
s3_client = sagemaker_session.boto_session.client('s3')
head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/'))
assert head['ServerSideEncryption'] == 'AES256'
32 changes: 30 additions & 2 deletions tests/unit/test_amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import numpy as np
import pytest
from mock import Mock, patch, call
from mock import ANY, Mock, patch, call

# Use PCA as a test implementation of AmazonAlgorithmEstimator
from sagemaker.amazon.pca import PCA
Expand Down Expand Up @@ -143,6 +143,22 @@ def test_prepare_for_training_list_no_train_channel(sagemaker_session):
assert 'Must provide train channel.' in str(ex)


def test_prepare_for_training_encrypt(sagemaker_session):
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)

train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
labels = [99, 85, 87, 2]
with patch('sagemaker.amazon.amazon_estimator.upload_numpy_to_s3_shards',
return_value='manfiest_file') as mock_upload:
pca.record_set(np.array(train), np.array(labels))
pca.record_set(np.array(train), np.array(labels), encrypt=True)

def make_upload_call(encrypt):
return call(ANY, ANY, ANY, ANY, ANY, ANY, encrypt)

mock_upload.assert_has_calls([make_upload_call(False), make_upload_call(True)])


@patch('time.strftime', return_value=TIMESTAMP)
def test_fit_ndarray(time, sagemaker_session):
mock_s3 = Mock()
Expand Down Expand Up @@ -185,9 +201,21 @@ def test_upload_numpy_to_s3_shards():
mock_s3 = Mock()
mock_object = Mock()
mock_s3.Object = Mock(return_value=mock_object)
mock_put = mock_s3.Object.return_value.put
array = np.array([[j for j in range(10)] for i in range(10)])
labels = np.array([i for i in range(10)])
upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels)
num_shards = 3
num_objects = num_shards + 1 # Account for the manifest file.

def make_all_put_calls(**kwargs):
return [call(Body=ANY, **kwargs) for i in range(num_objects)]

upload_numpy_to_s3_shards(num_shards, mock_s3, BUCKET_NAME, "key-prefix", array, labels)
mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_0.pbr')])
mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_1.pbr')])
mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_2.pbr')])
mock_put.assert_has_calls(make_all_put_calls())

mock_put.reset()
upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels, encrypt=True)
mock_put.assert_has_calls(make_all_put_calls(ServerSideEncryption='AES256'))