diff --git a/doc/overview.rst b/doc/overview.rst
index cf525cd4a4..a05a233f45 100644
--- a/doc/overview.rst
+++ b/doc/overview.rst
@@ -299,6 +299,114 @@ Here are some examples of creating estimators with Git support:
 Git support can be used not only for training jobs, but also for hosting models. The usage is the same as the above,
 and ``git_config`` should be provided when creating model objects, e.g. ``TensorFlowModel``, ``MXNetModel``, ``PyTorchModel``.
 
+Use File Systems as Training Inputs
+-------------------------------------
+Amazon SageMaker supports using Amazon Elastic File System (EFS) and Amazon FSx for Lustre as data sources during training.
+To use these data sources, create a file system (EFS or FSx for Lustre) and mount it on an Amazon EC2 instance.
+For more information about setting up EFS and FSx, see the following documentation:
+
+- `Using File Systems in Amazon EFS `__
+- `Getting Started with Amazon FSx for Lustre `__
+
+In general, you use either the ``FileSystemInput`` or ``FileSystemRecordSet`` class, which encapsulates
+all of the arguments the service needs to use EFS or FSx for Lustre.
+
+Here are examples of how to use Amazon EFS as input for training:
+
+.. code:: python
+
+    # This example shows how to use the FileSystemInput class
+    # Configure an estimator with subnets and security groups from your VPC. The EFS volume must be in
+    # the same VPC as your Amazon EC2 instance.
+    estimator = TensorFlow(entry_point='tensorflow_mnist/mnist.py',
+                           role='SageMakerRole',
+                           train_instance_count=1,
+                           train_instance_type='ml.c4.xlarge',
+                           subnets=['subnet-1', 'subnet-2'],
+                           security_group_ids=['sg-1'])
+
+    file_system_input = FileSystemInput(file_system_id='fs-1',
+                                        file_system_type='EFS',
+                                        directory_path='tensorflow',
+                                        file_system_access_mode='ro')
+
+    # Start an Amazon SageMaker training job with EFS using the FileSystemInput class
+    estimator.fit(file_system_input)
+
+.. code:: python
+
+    # This example shows how to use the FileSystemRecordSet class
+    # Configure an estimator with subnets and security groups from your VPC. The EFS volume must be in
+    # the same VPC as your Amazon EC2 instance.
+    kmeans = KMeans(role='SageMakerRole',
+                    train_instance_count=1,
+                    train_instance_type='ml.c4.xlarge',
+                    k=10,
+                    subnets=['subnet-1', 'subnet-2'],
+                    security_group_ids=['sg-1'])
+
+    records = FileSystemRecordSet(file_system_id='fs-1',
+                                  file_system_type='EFS',
+                                  directory_path='kmeans',
+                                  num_records=784,
+                                  feature_dim=784)
+
+    # Start an Amazon SageMaker training job with EFS using the FileSystemRecordSet class
+    kmeans.fit(records)
+
+Here are examples of how to use Amazon FSx for Lustre as input for training:
+
+.. code:: python
+
+    # This example shows how to use the FileSystemInput class
+    # Configure an estimator with subnets and security groups from your VPC. The VPC should be the same
+    # as the one you chose for your Amazon EC2 instance.
+    estimator = TensorFlow(entry_point='tensorflow_mnist/mnist.py',
+                           role='SageMakerRole',
+                           train_instance_count=1,
+                           train_instance_type='ml.c4.xlarge',
+                           subnets=['subnet-1', 'subnet-2'],
+                           security_group_ids=['sg-1'])
+
+    file_system_input = FileSystemInput(file_system_id='fs-2',
+                                        file_system_type='FSxLustre',
+                                        directory_path='tensorflow',
+                                        file_system_access_mode='ro')
+
+    # Start an Amazon SageMaker training job with FSx using the FileSystemInput class
+    estimator.fit(file_system_input)
+
+.. code:: python
+
+    # This example shows how to use the FileSystemRecordSet class
+    # Configure an estimator with subnets and security groups from your VPC. The VPC should be the same
+    # as the one you chose for your Amazon EC2 instance.
+    kmeans = KMeans(role='SageMakerRole',
+                    train_instance_count=1,
+                    train_instance_type='ml.c4.xlarge',
+                    k=10,
+                    subnets=['subnet-1', 'subnet-2'],
+                    security_group_ids=['sg-1'])
+
+    records = FileSystemRecordSet(file_system_id='fs-2',
+                                  file_system_type='FSxLustre',
+                                  directory_path='kmeans',
+                                  num_records=784,
+                                  feature_dim=784)
+
+    # Start an Amazon SageMaker training job with FSx using the FileSystemRecordSet class
+    kmeans.fit(records)
+
+Data sources from EFS and FSx can also be used for hyperparameter tuning jobs. The usage is the same as above,
+as the sketch below shows.
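+
+The following sketch assumes the TensorFlow ``estimator`` from the EFS example above; the objective metric name,
+metric regex, and hyperparameter range are illustrative placeholders rather than values required by the service.
+
+.. code:: python
+
+    from sagemaker.inputs import FileSystemInput
+    from sagemaker.parameter import IntegerParameter
+    from sagemaker.tuner import HyperparameterTuner
+
+    # Channel configuration pointing at data stored on the EFS volume
+    file_system_input = FileSystemInput(file_system_id='fs-1',
+                                        file_system_type='EFS',
+                                        directory_path='tensorflow',
+                                        file_system_access_mode='ro')
+
+    # Illustrative tuner; the metric name, regex, and range are placeholders
+    tuner = HyperparameterTuner(estimator=estimator,
+                                objective_metric_name='accuracy',
+                                metric_definitions=[{'Name': 'accuracy', 'Regex': 'accuracy = ([0-9\\.]+)'}],
+                                hyperparameter_ranges={'epochs': IntegerParameter(1, 2)},
+                                max_jobs=2,
+                                max_parallel_jobs=2)
+
+    # tuner.fit() accepts the same file system inputs as estimator.fit()
+    tuner.fit(file_system_input)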
+
+A few important notes:
+
+- Local mode is not supported when using EFS or FSx as a data source.
+
+- Pipe mode is not supported when using EFS as a data source.
+
 Training Metrics
 ----------------
 The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
diff --git a/setup.py b/setup.py
index bb1909e6e0..c6b0c6fea8 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
@@ -34,7 +34,7 @@ def read_version():
 
 # Declare minimal set for installation
 required_packages = [
-    "boto3>=1.9.169",
+    "boto3>=1.9.213",
     "numpy>=1.9.0",
     "protobuf>=3.1",
     "scipy>=0.19.0",
@@ -42,6 +42,7 @@
     "protobuf3-to-dict>=0.1.5",
     "docker-compose>=1.23.0",
     "requests>=2.20.0, <2.21",
+    "fabric>=2.0",
 ]
 
 # enum is introduced in Python 3.4. Installing enum back port
diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index fe60250be8..83e7b7c56f 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -1,4 +1,4 @@
-# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
@@ -23,6 +23,7 @@
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.amazon.common import write_numpy_to_dense_tensor
 from sagemaker.estimator import EstimatorBase, _TrainingJob
+from sagemaker.inputs import FileSystemInput
 from sagemaker.model import NEO_IMAGE_ACCOUNT
 from sagemaker.session import s3_input
 from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
@@ -281,6 +282,55 @@ def records_s3_input(self):
         return s3_input(self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type)
 
 
+class FileSystemRecordSet(object):
+    """Amazon SageMaker channel configuration for a file system data source
+    for Amazon algorithms.
+    """
+
+    def __init__(
+        self,
+        file_system_id,
+        file_system_type,
+        directory_path,
+        num_records,
+        feature_dim,
+        file_system_access_mode="ro",
+        channel="train",
+    ):
+        """Initialize a ``FileSystemRecordSet`` object.
+
+        Args:
+            file_system_id (str): An Amazon file system ID starting with 'fs-'.
+            file_system_type (str): The type of file system used for the input.
+ Valid values: 'EFS', 'FSxLustre'. + directory_path (str): Relative path to the root directory (mount point) in + the file system. Reference: + https://docs.aws.amazon.com/efs/latest/ug/mounting-fs.html and + https://docs.aws.amazon.com/efs/latest/ug/wt1-test.html + num_records (int): The number of records in the set. + feature_dim (int): The dimensionality of "values" arrays in the Record features, + and label (if each Record is labeled). + file_system_access_mode (str): Permissions for read and write. + Valid values: 'ro' or 'rw'. Defaults to 'ro'. + channel (str): The SageMaker Training Job channel this RecordSet should be bound to + """ + + self.file_system_input = FileSystemInput( + file_system_id, file_system_type, directory_path, file_system_access_mode + ) + self.feature_dim = feature_dim + self.num_records = num_records + self.channel = channel + + def __repr__(self): + """Return an unambiguous representation of this RecordSet""" + return str((FileSystemRecordSet, self.__dict__)) + + def data_channel(self): + """Return a dictionary to represent the training data in a channel for use with ``fit()``""" + return {self.channel: self.file_system_input} + + def _build_shards(num_shards, array): """ Args: diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index c61a727779..6e79e26ff8 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -308,21 +308,21 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None): about the training data. This can be one of three types: * (str) the S3 location where training data is saved. - * (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for training data, you can specify a dict mapping channel names to strings or :func:`~sagemaker.session.s3_input` objects. - * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide additional information as well as the path to the training dataset. See :func:`sagemaker.session.s3_input` for full details. - wait (bool): Whether the call should wait until the job completes - (default: True). - logs (bool): Whether to show the logs produced by the job. Only - meaningful when wait is True (default: True). - job_name (str): Training job name. If not specified, the estimator - generates a default job name, based on the training image name - and current timestamp. + * (sagemaker.session.FileSystemInput) - channel configuration for + a file system data source that can provide additional information as well as + the path to the training dataset. + + wait (bool): Whether the call should wait until the job completes (default: True). + logs (bool): Whether to show the logs produced by the job. + Only meaningful when wait is True (default: True). + job_name (str): Training job name. If not specified, the estimator generates + a default job name, based on the training image name and current timestamp. """ self._prepare_for_training(job_name=job_name) diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py new file mode 100644 index 0000000000..856612353d --- /dev/null +++ b/src/sagemaker/inputs.py @@ -0,0 +1,146 @@ +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Amazon SageMaker channel configurations for S3 data sources and file system data sources""" +from __future__ import absolute_import, print_function + +FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"] +FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"] + + +class s3_input(object): + """Amazon SageMaker channel configurations for S3 data sources. + + Attributes: + config (dict[str, dict]): A SageMaker ``DataSource`` referencing + a SageMaker ``S3DataSource``. + """ + + def __init__( + self, + s3_data, + distribution="FullyReplicated", + compression=None, + content_type=None, + record_wrapping=None, + s3_data_type="S3Prefix", + input_mode=None, + attribute_names=None, + shuffle_config=None, + ): + """Create a definition for input data used by an SageMaker training job. + See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. + + Args: + s3_data (str): Defines the location of s3 data to train on. + distribution (str): Valid values: 'FullyReplicated', 'ShardedByS3Key' + (default: 'FullyReplicated'). + compression (str): Valid values: 'Gzip', None (default: None). This is used only in + Pipe input mode. + content_type (str): MIME type of the input data (default: None). + record_wrapping (str): Valid values: 'RecordIO' (default: None). + s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. + If 'S3Prefix', ``s3_data`` defines a prefix of s3 objects to train on. + All objects with s3 keys beginning with ``s3_data`` will be used to train. + If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data`` defines a + single S3 manifest file or augmented manifest file (respectively), + listing the S3 data to train on. Both the ManifestFile and + AugmentedManifestFile formats are described in the SageMaker API documentation: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html + input_mode (str): Optional override for this channel's input mode (default: None). + By default, channels will use the input mode defined on + ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore + that setting if this parameter is set. + + * None - Amazon SageMaker will use the input mode specified in the ``Estimator`` + * 'File' - Amazon SageMaker copies the training dataset from the S3 location to + a local directory. + * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via + a Unix-named pipe. + + attribute_names (list[str]): A list of one or more attribute names to use that are + found in a specified AugmentedManifestFile. + shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on + this channel. 
See the SageMaker API documentation for more info: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + """ + + self.config = { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": distribution, + "S3DataType": s3_data_type, + "S3Uri": s3_data, + } + } + } + + if compression is not None: + self.config["CompressionType"] = compression + if content_type is not None: + self.config["ContentType"] = content_type + if record_wrapping is not None: + self.config["RecordWrapperType"] = record_wrapping + if input_mode is not None: + self.config["InputMode"] = input_mode + if attribute_names is not None: + self.config["DataSource"]["S3DataSource"]["AttributeNames"] = attribute_names + if shuffle_config is not None: + self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} + + +class FileSystemInput(object): + """Amazon SageMaker channel configurations for file system data sources. + + Attributes: + config (dict[str, dict]): A Sagemaker File System ``DataSource``. + """ + + def __init__( + self, file_system_id, file_system_type, directory_path, file_system_access_mode="ro" + ): + """Create a new file system input used by an SageMaker training job. + + Args: + file_system_id (str): An Amazon file system ID starting with 'fs-'. + file_system_type (str): The type of file system used for the input. + Valid values: 'EFS', 'FSxLustre'. + directory_path (str): Relative path to the root directory (mount point) in + the file system. + Reference: https://docs.aws.amazon.com/efs/latest/ug/mounting-fs.html and + https://docs.aws.amazon.com/fsx/latest/LustreGuide/mount-fs-auto-mount-onreboot.html + file_system_access_mode (str): Permissions for read and write. + Valid values: 'ro' or 'rw'. Defaults to 'ro'. + """ + + if file_system_type not in FILE_SYSTEM_TYPES: + raise ValueError( + "Unrecognized file system type: %s. Valid values: %s." + % (file_system_type, ", ".join(FILE_SYSTEM_TYPES)) + ) + + if file_system_access_mode not in FILE_SYSTEM_ACCESS_MODES: + raise ValueError( + "Unrecognized file system access mode: %s. Valid values: %s." + % (file_system_access_mode, ", ".join(FILE_SYSTEM_ACCESS_MODES)) + ) + + self.config = { + "DataSource": { + "FileSystemDataSource": { + "FileSystemId": file_system_id, + "FileSystemType": file_system_type, + "DirectoryPath": directory_path, + "FileSystemAccessMode": file_system_access_mode, + } + } + } diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 6f8a1af028..79d41c8db1 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. 
A copy of @@ -16,6 +16,7 @@ from abc import abstractmethod from six import string_types +from sagemaker.inputs import FileSystemInput from sagemaker.local import file_input from sagemaker.session import s3_input @@ -127,8 +128,9 @@ def _format_inputs_to_input_config(inputs, validate_uri=True): # Deferred import due to circular dependency from sagemaker.amazon.amazon_estimator import RecordSet + from sagemaker.amazon.amazon_estimator import FileSystemRecordSet - if isinstance(inputs, RecordSet): + if isinstance(inputs, (RecordSet, FileSystemRecordSet)): inputs = inputs.data_channel() input_dict = {} @@ -143,10 +145,11 @@ def _format_inputs_to_input_config(inputs, validate_uri=True): input_dict[k] = _Job._format_string_uri_input(v, validate_uri) elif isinstance(inputs, list): input_dict = _Job._format_record_set_list_input(inputs) + elif isinstance(inputs, FileSystemInput): + input_dict["training"] = inputs else: - raise ValueError( - "Cannot format input {}. Expecting one of str, dict or s3_input".format(inputs) - ) + msg = "Cannot format input {}. Expecting one of str, dict, s3_input or FileSystemInput" + raise ValueError(msg.format(inputs)) channels = [ _Job._convert_input_to_channel(name, input) for name, input in input_dict.items() @@ -185,14 +188,12 @@ def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, in ) if isinstance(uri_input, str): return s3_input(uri_input, content_type=content_type, input_mode=input_mode) - if isinstance(uri_input, s3_input): - return uri_input - if isinstance(uri_input, file_input): + if isinstance(uri_input, (s3_input, file_input, FileSystemInput)): return uri_input + raise ValueError( - "Cannot format input {}. Expecting one of str, s3_input, or file_input".format( - uri_input - ) + "Cannot format input {}. 
Expecting one of str, s3_input, file_input or " + "FileSystemInput".format(uri_input) ) @staticmethod @@ -263,22 +264,24 @@ def _format_model_uri_input(model_uri, validate_uri=True): @staticmethod def _format_record_set_list_input(inputs): - # Deferred import due to circular dependency """ Args: inputs: """ - from sagemaker.amazon.amazon_estimator import RecordSet + # Deferred import due to circular dependency + from sagemaker.amazon.amazon_estimator import FileSystemRecordSet, RecordSet input_dict = {} for record in inputs: - if not isinstance(record, RecordSet): - raise ValueError("List compatible only with RecordSets.") + if not isinstance(record, (RecordSet, FileSystemRecordSet)): + raise ValueError("List compatible only with RecordSets or FileSystemRecordSets.") if record.channel in input_dict: raise ValueError("Duplicate channels not allowed.") - - input_dict[record.channel] = record.records_s3_input() + if isinstance(record, RecordSet): + input_dict[record.channel] = record.records_s3_input() + if isinstance(record, FileSystemRecordSet): + input_dict[record.channel] = record.file_system_input return input_dict diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index a3426e0ddf..617982de52 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -29,6 +29,9 @@ import sagemaker.logs from sagemaker import vpc_utils + +# import s3_input for backward compatibility +from sagemaker.inputs import s3_input # noqa # pylint: disable=unused-import from sagemaker.user_agent import prepend_user_agent from sagemaker.utils import ( name_from_image, @@ -1605,87 +1608,6 @@ def get_execution_role(sagemaker_session=None): raise ValueError(message.format(arn)) -class s3_input(object): - """Amazon SageMaker channel configurations for S3 data sources. - - Attributes: - config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker - ``S3DataSource``. - """ - - def __init__( - self, - s3_data, - distribution="FullyReplicated", - compression=None, - content_type=None, - record_wrapping=None, - s3_data_type="S3Prefix", - input_mode=None, - attribute_names=None, - shuffle_config=None, - ): - """Create a definition for input data used by an SageMaker training job. - - See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. - - Args: - s3_data (str): Defines the location of s3 data to train on. - distribution (str): Valid values: 'FullyReplicated', 'ShardedByS3Key' - (default: 'FullyReplicated'). - compression (str): Valid values: 'Gzip', None (default: None). This is used only in - Pipe input mode. - content_type (str): MIME type of the input data (default: None). - record_wrapping (str): Valid values: 'RecordIO' (default: None). - s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. - If 'S3Prefix', ``s3_data`` defines a prefix of s3 objects to train on. All objects - with s3 keys beginning with ``s3_data`` will be used to train. If 'ManifestFile' - or 'AugmentedManifestFile', then ``s3_data`` defines a single s3 manifest file or - augmented manifest file (respectively), listing the s3 data to train on. Both the - ManifestFile and AugmentedManifestFile formats are described in the SageMaker API - documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html - input_mode (str): Optional override for this channel's input mode (default: None). 
By - default, channels will use the input mode defined on - ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore that setting - if this parameter is set. - * None - Amazon SageMaker will use the input mode specified in the - ``Estimator``. - * 'File' - Amazon SageMaker copies the training dataset from the S3 location - to a local directory. - * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via - a Unix-named pipe. - - attribute_names (list[str]): A list of one or more attribute names to use that are - found in a specified AugmentedManifestFile. - shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on - this channel. See the SageMaker API documentation for more info: - https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html - """ - - self.config = { - "DataSource": { - "S3DataSource": { - "S3DataDistributionType": distribution, - "S3DataType": s3_data_type, - "S3Uri": s3_data, - } - } - } - - if compression is not None: - self.config["CompressionType"] = compression - if content_type is not None: - self.config["ContentType"] = content_type - if record_wrapping is not None: - self.config["RecordWrapperType"] = record_wrapping - if input_mode is not None: - self.config["InputMode"] = input_mode - if attribute_names is not None: - self.config["DataSource"]["S3DataSource"]["AttributeNames"] = attribute_names - if shuffle_config is not None: - self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} - - class ShuffleConfig(object): """ Used to configure channel shuffling using a seed. See SageMaker documentation for diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index debbc5c013..4ac5a8b690 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -20,7 +20,11 @@ from enum import Enum import sagemaker -from sagemaker.amazon.amazon_estimator import RecordSet, AmazonAlgorithmEstimatorBase +from sagemaker.amazon.amazon_estimator import ( + RecordSet, + AmazonAlgorithmEstimatorBase, + FileSystemRecordSet, +) from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.estimator import Framework @@ -315,39 +319,41 @@ def fit(self, inputs=None, job_name=None, include_cls_metadata=False, **kwargs): any of the following forms: * (str) - The S3 location where training data is saved. - - * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple - channels for training data, you can specify a dict mapping channel - names to strings or :func:`~sagemaker.session.s3_input` - objects. - + * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - + If using multiple channels for training data, you can specify + a dict mapping channel names to strings or + :func:`~sagemaker.session.s3_input` objects. * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can - provide additional information about the training dataset. See - :func:`sagemaker.session.s3_input` for full details. - + provide additional information about the training dataset. + See :func:`sagemaker.session.s3_input` for full details. 
+ * (sagemaker.session.FileSystemInput) - channel configuration for + a file system data source that can provide additional information as well as + the path to the training dataset. * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of - Amazon :class:~`Record` objects serialized and stored in - S3. For use with an estimator for an Amazon algorithm. - + Amazon :class:~`Record` objects serialized and stored in S3. + For use with an estimator for an Amazon algorithm. + * (sagemaker.amazon.amazon_estimator.FileSystemRecordSet) - + Amazon SageMaker channel configuration for a file system data source for + Amazon algorithms. * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of - :class:~`sagemaker.amazon.amazon_estimator.RecordSet` - objects, where each instance is a different channel of - training data. - job_name (str): Tuning job name. If not specified, the tuner - generates a default job name, based on the training image name - and current timestamp. - include_cls_metadata (bool): Whether or not the hyperparameter - tuning job should include information about the estimator class - (default: False). This information is passed as a - hyperparameter, so if the algorithm you are using cannot handle - unknown hyperparameters (e.g. an Amazon SageMaker built-in - algorithm that does not have a custom estimator in the Python - SDK), then set ``include_cls_metadata`` to ``False``. - **kwargs: Other arguments needed for training. Please refer to the - ``fit()`` method of the associated estimator to see what other - arguments are needed. + :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, + where each instance is a different channel of training data. + * (list[sagemaker.amazon.amazon_estimator.FileSystemRecordSet]) - A list of + :class:~`sagemaker.amazon.amazon_estimator.FileSystemRecordSet` objects, + where each instance is a different channel of training data. + + job_name (str): Tuning job name. If not specified, the tuner generates + a default job name, based on the training image name and current timestamp. + include_cls_metadata (bool): Whether or not the hyperparameter tuning job should include + information about the estimator class (default: False). This information is passed + as a hyperparameter, so if the algorithm you are using cannot handle + unknown hyperparameters (e.g. an Amazon SageMaker built-in algorithm that + does not have a custom estimator in the Python SDK), then set + ``include_cls_metadata`` to ``False``. + **kwargs: Other arguments needed for training. Please refer to the ``fit()`` method of + the associated estimator to see what other arguments are needed. """ - if isinstance(inputs, (list, RecordSet)): + if isinstance(inputs, (list, RecordSet, FileSystemRecordSet)): self.estimator._prepare_for_training(inputs, **kwargs) else: self.estimator._prepare_for_training(job_name) diff --git a/tests/data/protobuf_data/matrix_0.pbr b/tests/data/protobuf_data/matrix_0.pbr new file mode 100644 index 0000000000..6e97b47c81 Binary files /dev/null and b/tests/data/protobuf_data/matrix_0.pbr differ diff --git a/tests/integ/file_system_input_utils.py b/tests/integ/file_system_input_utils.py new file mode 100644 index 0000000000..deb8ff8569 --- /dev/null +++ b/tests/integ/file_system_input_utils.py @@ -0,0 +1,350 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import collections +import logging +import os +from os import path +import stat +import tempfile +import time +import uuid + +from botocore.exceptions import ClientError +from fabric import Connection + +from tests.integ.vpc_test_utils import check_or_create_vpc_resources_efs_fsx + +VPC_NAME = "sagemaker-efs-fsx-vpc" +EFS_CREATION_TOKEN = str(uuid.uuid4()) +PREFIX = "ec2_fs_key_" +KEY_NAME = PREFIX + str(uuid.uuid4().hex.upper()[0:8]) +ROLE_NAME = "SageMakerRole" +REGION = "us-west-2" +EC2_INSTANCE_TYPE = "t2.micro" +AMI_ID = "ami-082b5a644766e0e6f" +MIN_COUNT = 1 +MAX_COUNT = 1 +TIME_SLEEP_DURATION = 10 + +RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data") +MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist") +MNIST_LOCAL_DATA = os.path.join(MNIST_RESOURCE_PATH, "data") +ONE_P_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "protobuf_data") +ONE_P_LOCAL_DATA = os.path.join(ONE_P_RESOURCE_PATH, "matrix_0.pbr") + +SCRIPTS_FOLDER = os.path.join(os.path.dirname(__file__), "..", "scripts") +FS_MOUNT_SCRIPT = os.path.join(SCRIPTS_FOLDER, "fs_mount_setup.sh") +FILE_NAME = KEY_NAME + ".pem" +KEY_PATH = os.path.join(tempfile.gettempdir(), FILE_NAME) +STORAGE_CAPACITY_IN_BYTES = 3600 + +FsResources = collections.namedtuple( + "FsResources", + [ + "key_name", + "key_path", + "role_name", + "subnet_id", + "security_group_ids", + "file_system_efs_id", + "file_system_fsx_id", + "ec2_instance_id", + "mount_efs_target_id", + ], +) + + +def set_up_efs_fsx(sagemaker_session): + _check_or_create_key_pair(sagemaker_session) + _check_or_create_iam_profile_and_attach_role(sagemaker_session) + subnet_ids, security_group_ids = check_or_create_vpc_resources_efs_fsx( + sagemaker_session, REGION, VPC_NAME + ) + + ec2_instance = _create_ec2_instance( + sagemaker_session, + AMI_ID, + EC2_INSTANCE_TYPE, + KEY_NAME, + MIN_COUNT, + MAX_COUNT, + security_group_ids, + subnet_ids[0], + ) + + file_system_efs_id = _check_or_create_efs(sagemaker_session) + mount_efs_target_id = _create_efs_mount(sagemaker_session, file_system_efs_id) + + file_system_fsx_id = _check_or_create_fsx(sagemaker_session) + + fs_resources = FsResources( + KEY_NAME, + KEY_PATH, + ROLE_NAME, + subnet_ids[0], + security_group_ids, + file_system_efs_id, + file_system_fsx_id, + ec2_instance.id, + mount_efs_target_id, + ) + + try: + connected_instance = _connect_ec2_instance(ec2_instance) + _upload_data_and_mount_fs(connected_instance, file_system_efs_id, file_system_fsx_id) + except Exception: + tear_down(sagemaker_session, fs_resources) + raise + + return fs_resources + + +def _connect_ec2_instance(ec2_instance): + public_ip_address = ec2_instance.public_ip_address + connected_instance = Connection( + host=public_ip_address, port=22, user="ec2-user", connect_kwargs={"key_filename": KEY_PATH} + ) + return connected_instance + + +def _upload_data_and_mount_fs(connected_instance, file_system_efs_id, file_system_fsx_id): + connected_instance.put(FS_MOUNT_SCRIPT, ".") + connected_instance.run("mkdir temp_tf; mkdir temp_one_p", in_stream=False) + for dir_name, subdir_list, file_list in os.walk(MNIST_LOCAL_DATA): + for fname in 
file_list: + local_file = os.path.join(MNIST_LOCAL_DATA, fname) + connected_instance.put(local_file, "temp_tf/") + connected_instance.put(ONE_P_LOCAL_DATA, "temp_one_p/") + connected_instance.run( + "sudo sh fs_mount_setup.sh {} {}".format(file_system_efs_id, file_system_fsx_id), + in_stream=False, + ) + + +def _check_or_create_efs(sagemaker_session): + efs_client = sagemaker_session.boto_session.client("efs") + file_system_exists = False + efs_id = "" + try: + create_response = efs_client.create_file_system(CreationToken=EFS_CREATION_TOKEN) + efs_id = create_response["FileSystemId"] + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "FileSystemAlreadyExists": + file_system_exists = True + logging.warning( + "File system with given creation token %s already exists", EFS_CREATION_TOKEN + ) + else: + raise + + if file_system_exists: + desc = efs_client.describe_file_systems(CreationToken=EFS_CREATION_TOKEN) + efs_id = desc["FileSystems"][0]["FileSystemId"] + mount_target_id = efs_client.describe_mount_targets(FileSystemId=efs_id)["MountTargets"][0][ + "MountTargetId" + ] + return efs_id, mount_target_id + + for _ in retries(50, "Checking EFS creating status"): + desc = efs_client.describe_file_systems(CreationToken=EFS_CREATION_TOKEN) + status = desc["FileSystems"][0]["LifeCycleState"] + if status == "available": + break + + return efs_id + + +def _create_efs_mount(sagemaker_session, file_system_id): + subnet_ids, security_group_ids = check_or_create_vpc_resources_efs_fsx( + sagemaker_session, REGION, VPC_NAME + ) + efs_client = sagemaker_session.boto_session.client("efs") + mount_response = efs_client.create_mount_target( + FileSystemId=file_system_id, SubnetId=subnet_ids[0], SecurityGroups=security_group_ids + ) + mount_target_id = mount_response["MountTargetId"] + + for _ in retries(50, "Checking EFS mounting target status"): + desc = efs_client.describe_mount_targets(MountTargetId=mount_target_id) + status = desc["MountTargets"][0]["LifeCycleState"] + if status == "available": + break + + return mount_target_id + + +def _check_or_create_fsx(sagemaker_session): + fsx_client = sagemaker_session.boto_session.client("fsx") + subnet_ids, security_group_ids = check_or_create_vpc_resources_efs_fsx( + sagemaker_session, REGION, VPC_NAME + ) + create_response = fsx_client.create_file_system( + FileSystemType="LUSTRE", + StorageCapacity=STORAGE_CAPACITY_IN_BYTES, + SubnetIds=[subnet_ids[0]], + SecurityGroupIds=security_group_ids, + ) + fsx_id = create_response["FileSystem"]["FileSystemId"] + + for _ in retries(50, "Checking FSX creating status"): + desc = fsx_client.describe_file_systems(FileSystemIds=[fsx_id]) + status = desc["FileSystems"][0]["Lifecycle"] + if status == "AVAILABLE": + break + + return fsx_id + + +def _create_ec2_instance( + sagemaker_session, + image_id, + instance_type, + key_name, + min_count, + max_count, + security_group_ids, + subnet_id, +): + ec2_resource = sagemaker_session.boto_session.resource("ec2") + ec2_instances = ec2_resource.create_instances( + ImageId=image_id, + InstanceType=instance_type, + KeyName=key_name, + MinCount=min_count, + MaxCount=max_count, + IamInstanceProfile={"Name": ROLE_NAME}, + DryRun=False, + NetworkInterfaces=[ + { + "SubnetId": subnet_id, + "DeviceIndex": 0, + "AssociatePublicIpAddress": True, + "Groups": security_group_ids, + } + ], + ) + + ec2_instances[0].wait_until_running() + ec2_instances[0].reload() + ec2_client = sagemaker_session.boto_session.client("ec2") + + for _ in retries(30, "Checking EC2 
creation status"): + statuses = ec2_client.describe_instance_status(InstanceIds=[ec2_instances[0].id]) + status = statuses["InstanceStatuses"][0] + if status["InstanceStatus"]["Status"] == "ok" and status["SystemStatus"]["Status"] == "ok": + break + return ec2_instances[0] + + +def _check_key_pair_and_cleanup_old_artifacts(sagemaker_session): + ec2_client = sagemaker_session.boto_session.client("ec2") + response = ec2_client.describe_key_pairs(Filters=[{"Name": "key-name", "Values": [KEY_NAME]}]) + if len(response["KeyPairs"]) > 0 and not path.exists(KEY_PATH): + ec2_client.delete_key_pair(KeyName=KEY_NAME) + if len(response["KeyPairs"]) == 0 and path.exists(KEY_PATH): + os.remove(KEY_PATH) + return len(response["KeyPairs"]) > 0 and path.exists(KEY_PATH) + + +def _check_or_create_key_pair(sagemaker_session): + if _check_key_pair_and_cleanup_old_artifacts(sagemaker_session): + return + ec2_client = sagemaker_session.boto_session.client("ec2") + key_pair = ec2_client.create_key_pair(KeyName=KEY_NAME) + with open(KEY_PATH, "w") as file: + file.write(key_pair["KeyMaterial"]) + fd = os.open(KEY_PATH, os.O_RDONLY) + os.fchmod(fd, stat.S_IREAD) + + +def _delete_key_pair(sagemaker_session): + ec2_client = sagemaker_session.boto_session.client("ec2") + ec2_client.delete_key_pair(KeyName=KEY_NAME) + os.remove(KEY_PATH) + + +def _terminate_instance(ec2_resource, instance_ids): + ec2_resource.instances.filter(InstanceIds=instance_ids).terminate() + + +def _check_or_create_iam_profile_and_attach_role(sagemaker_session): + if _instance_profile_exists(sagemaker_session): + return + iam_client = sagemaker_session.boto_session.client("iam") + iam_client.create_instance_profile(InstanceProfileName=ROLE_NAME) + iam_client.add_role_to_instance_profile(InstanceProfileName=ROLE_NAME, RoleName=ROLE_NAME) + + for _ in retries(30, "Checking EC2 instance profile creating status"): + profile_info = iam_client.get_instance_profile(InstanceProfileName=ROLE_NAME) + if profile_info["InstanceProfile"]["Roles"][0]["RoleName"] == ROLE_NAME: + break + + +def _instance_profile_exists(sagemaker_session): + iam = sagemaker_session.boto_session.client("iam") + try: + iam.get_instance_profile(InstanceProfileName=ROLE_NAME) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + if error_code == "NoSuchEntity" and ROLE_NAME in message: + return False + else: + raise + return True + + +def retries(max_retry_count, exception_message_prefix): + current_retry_count = 0 + while current_retry_count <= max_retry_count: + yield current_retry_count + + current_retry_count += 1 + time.sleep(TIME_SLEEP_DURATION) + + raise Exception( + "{} has reached the maximum retry count {}".format( + exception_message_prefix, max_retry_count + ) + ) + + +def tear_down(sagemaker_session, fs_resources): + fsx_client = sagemaker_session.boto_session.client("fsx") + file_system_fsx_id = fs_resources.file_system_fsx_id + fsx_client.delete_file_system(FileSystemId=file_system_fsx_id) + + efs_client = sagemaker_session.boto_session.client("efs") + mount_efs_target_id = fs_resources.mount_efs_target_id + efs_client.delete_mount_target(MountTargetId=mount_efs_target_id) + + file_system_efs_id = fs_resources.file_system_efs_id + for _ in retries(30, "Checking mount target deleting status"): + desc = efs_client.describe_mount_targets(FileSystemId=file_system_efs_id) + if len(desc["MountTargets"]) > 0: + status = desc["MountTargets"][0]["LifeCycleState"] + if status == "deleted": + break + else: + break + + 
efs_client.delete_file_system(FileSystemId=file_system_efs_id) + + ec2_resource = sagemaker_session.boto_session.resource("ec2") + instance_id = fs_resources.ec2_instance_id + _terminate_instance(ec2_resource, [instance_id]) + + _delete_key_pair(sagemaker_session) diff --git a/tests/integ/s3_utils.py b/tests/integ/s3_utils.py new file mode 100644 index 0000000000..10ed99ec76 --- /dev/null +++ b/tests/integ/s3_utils.py @@ -0,0 +1,29 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 +from six.moves.urllib.parse import urlparse + + +def assert_s3_files_exist(sagemaker_session, s3_url, files): + parsed_url = urlparse(s3_url) + region = sagemaker_session.boto_region_name + s3 = boto3.client("s3", region_name=region) + contents = s3.list_objects_v2(Bucket=parsed_url.netloc, Prefix=parsed_url.path.lstrip("/"))[ + "Contents" + ] + for f in files: + found = [x["Key"] for x in contents if x["Key"].endswith(f)] + if not found: + raise ValueError("File {} is not found under {}".format(f, s3_url)) diff --git a/tests/integ/test_kmeans_efs_fsx.py b/tests/integ/test_kmeans_efs_fsx.py new file mode 100644 index 0000000000..c30c6fdbcb --- /dev/null +++ b/tests/integ/test_kmeans_efs_fsx.py @@ -0,0 +1,219 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest + +from sagemaker import KMeans +from sagemaker.amazon.amazon_estimator import FileSystemRecordSet +from sagemaker.parameter import IntegerParameter, CategoricalParameter +from sagemaker.tuner import HyperparameterTuner +from sagemaker.utils import unique_name_from_base +from tests.integ import TRAINING_DEFAULT_TIMEOUT_MINUTES, TUNING_DEFAULT_TIMEOUT_MINUTES +from tests.integ.s3_utils import assert_s3_files_exist +from tests.integ.file_system_input_utils import set_up_efs_fsx, tear_down +from tests.integ.timeout import timeout + +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" +TRAIN_INSTANCE_COUNT = 1 +OBJECTIVE_METRIC_NAME = "test:msd" +EFS_DIR_PATH = "/one_p_mnist" +FSX_DIR_PATH = "/fsx/one_p_mnist" +MAX_JOBS = 2 +MAX_PARALLEL_JOBS = 2 +K = 10 +NUM_RECORDS = 784 +FEATURE_DIM = 784 + + +@pytest.fixture(scope="module") +def efs_fsx_setup(sagemaker_session): + fs_resources = set_up_efs_fsx(sagemaker_session) + try: + yield fs_resources + finally: + tear_down(sagemaker_session, fs_resources) + + +@pytest.mark.canary_quick +def test_kmeans_efs(efs_fsx_setup, sagemaker_session): + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + role = efs_fsx_setup.role_name + kmeans = KMeans( + role=role, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + k=K, + sagemaker_session=sagemaker_session, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + file_system_efs_id = efs_fsx_setup.file_system_efs_id + records = FileSystemRecordSet( + file_system_id=file_system_efs_id, + file_system_type="EFS", + directory_path=EFS_DIR_PATH, + num_records=NUM_RECORDS, + feature_dim=FEATURE_DIM, + ) + + job_name = unique_name_from_base("kmeans-efs") + kmeans.fit(records, job_name=job_name) + model_path, _ = kmeans.model_data.rsplit("/", 1) + assert_s3_files_exist(sagemaker_session, model_path, ["model.tar.gz"]) + + +@pytest.mark.canary_quick +def test_kmeans_fsx(efs_fsx_setup, sagemaker_session): + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + role = efs_fsx_setup.role_name + kmeans = KMeans( + role=role, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + k=K, + sagemaker_session=sagemaker_session, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + file_system_fsx_id = efs_fsx_setup.file_system_fsx_id + records = FileSystemRecordSet( + file_system_id=file_system_fsx_id, + file_system_type="FSxLustre", + directory_path=FSX_DIR_PATH, + num_records=NUM_RECORDS, + feature_dim=FEATURE_DIM, + ) + + job_name = unique_name_from_base("kmeans-fsx") + kmeans.fit(records, job_name=job_name) + model_path, _ = kmeans.model_data.rsplit("/", 1) + assert_s3_files_exist(sagemaker_session, model_path, ["model.tar.gz"]) + + +def test_tuning_kmeans_efs(efs_fsx_setup, sagemaker_session): + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + role = efs_fsx_setup.role_name + kmeans = KMeans( + role=role, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + k=K, + sagemaker_session=sagemaker_session, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + hyperparameter_ranges = { + "extra_center_factor": IntegerParameter(4, 10), + "mini_batch_size": IntegerParameter(10, 100), + "epochs": 
IntegerParameter(1, 2), + "init_method": CategoricalParameter(["kmeans++", "random"]), + } + + with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES): + tuner = HyperparameterTuner( + estimator=kmeans, + objective_metric_name=OBJECTIVE_METRIC_NAME, + hyperparameter_ranges=hyperparameter_ranges, + objective_type="Minimize", + max_jobs=MAX_JOBS, + max_parallel_jobs=MAX_PARALLEL_JOBS, + ) + + file_system_efs_id = efs_fsx_setup.file_system_efs_id + train_records = FileSystemRecordSet( + file_system_id=file_system_efs_id, + file_system_type="EFS", + directory_path=EFS_DIR_PATH, + num_records=NUM_RECORDS, + feature_dim=FEATURE_DIM, + ) + + test_records = FileSystemRecordSet( + file_system_id=file_system_efs_id, + file_system_type="EFS", + directory_path=EFS_DIR_PATH, + num_records=NUM_RECORDS, + feature_dim=FEATURE_DIM, + channel="test", + ) + + job_name = unique_name_from_base("tune-kmeans-efs") + tuner.fit([train_records, test_records], job_name=job_name) + tuner.wait() + best_training_job = tuner.best_training_job() + assert best_training_job + + +def test_tuning_kmeans_fsx(efs_fsx_setup, sagemaker_session): + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + role = efs_fsx_setup.role_name + kmeans = KMeans( + role=role, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + k=K, + sagemaker_session=sagemaker_session, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + hyperparameter_ranges = { + "extra_center_factor": IntegerParameter(4, 10), + "mini_batch_size": IntegerParameter(10, 100), + "epochs": IntegerParameter(1, 2), + "init_method": CategoricalParameter(["kmeans++", "random"]), + } + + with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES): + tuner = HyperparameterTuner( + estimator=kmeans, + objective_metric_name=OBJECTIVE_METRIC_NAME, + hyperparameter_ranges=hyperparameter_ranges, + objective_type="Minimize", + max_jobs=MAX_JOBS, + max_parallel_jobs=MAX_PARALLEL_JOBS, + ) + + file_system_fsx_id = efs_fsx_setup.file_system_fsx_id + train_records = FileSystemRecordSet( + file_system_id=file_system_fsx_id, + file_system_type="FSxLustre", + directory_path=FSX_DIR_PATH, + num_records=NUM_RECORDS, + feature_dim=FEATURE_DIM, + ) + + test_records = FileSystemRecordSet( + file_system_id=file_system_fsx_id, + file_system_type="FSxLustre", + directory_path=FSX_DIR_PATH, + num_records=NUM_RECORDS, + feature_dim=FEATURE_DIM, + channel="test", + ) + + job_name = unique_name_from_base("tune-kmeans-fsx") + tuner.fit([train_records, test_records], job_name=job_name) + tuner.wait() + best_training_job = tuner.best_training_job() + assert best_training_job diff --git a/tests/integ/test_tf_efs_fsx.py b/tests/integ/test_tf_efs_fsx.py new file mode 100644 index 0000000000..02c4dd95bc --- /dev/null +++ b/tests/integ/test_tf_efs_fsx.py @@ -0,0 +1,202 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os +import time + +import pytest + +from sagemaker.inputs import FileSystemInput +from sagemaker.parameter import IntegerParameter +from sagemaker.tensorflow import TensorFlow +from sagemaker.tuner import HyperparameterTuner +from sagemaker.utils import unique_name_from_base +from tests.integ import TRAINING_DEFAULT_TIMEOUT_MINUTES, TUNING_DEFAULT_TIMEOUT_MINUTES +from tests.integ.s3_utils import assert_s3_files_exist +from tests.integ.file_system_input_utils import tear_down, set_up_efs_fsx +from tests.integ.timeout import timeout + +RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data") +MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist") +SCRIPT = os.path.join(MNIST_RESOURCE_PATH, "mnist.py") +TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler") +INSTANCE_TYPE = "ml.c4.xlarge" +EFS_DIR_PATH = "/tensorflow" +FSX_DIR_PATH = "/fsx/tensorflow" +MAX_JOBS = 2 +MAX_PARALLEL_JOBS = 2 +PY_VERSION = "py3" + + +@pytest.fixture(scope="module") +def efs_fsx_setup(sagemaker_session): + fs_resources = set_up_efs_fsx(sagemaker_session) + try: + yield fs_resources + finally: + tear_down(sagemaker_session, fs_resources) + + +@pytest.mark.canary_quick +def test_mnist_efs(efs_fsx_setup, sagemaker_session): + role = efs_fsx_setup.role_name + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + + estimator = TensorFlow( + entry_point=SCRIPT, + role=role, + train_instance_count=1, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + py_version=PY_VERSION, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + file_system_efs_id = efs_fsx_setup.file_system_efs_id + file_system_input = FileSystemInput( + file_system_id=file_system_efs_id, file_system_type="EFS", directory_path=EFS_DIR_PATH + ) + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit(inputs=file_system_input, job_name=unique_name_from_base("test-mnist-efs")) + + assert_s3_files_exist( + sagemaker_session, + estimator.model_dir, + ["graph.pbtxt", "model.ckpt-0.index", "model.ckpt-0.meta"], + ) + + +@pytest.mark.canary_quick +def test_mnist_lustre(efs_fsx_setup, sagemaker_session): + role = efs_fsx_setup.role_name + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + + estimator = TensorFlow( + entry_point=SCRIPT, + role=role, + train_instance_count=1, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + py_version=PY_VERSION, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + file_system_fsx_id = efs_fsx_setup.file_system_fsx_id + file_system_input = FileSystemInput( + file_system_id=file_system_fsx_id, file_system_type="FSxLustre", directory_path=FSX_DIR_PATH + ) + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit(inputs=file_system_input, job_name=unique_name_from_base("test-mnist-lustre")) + assert_s3_files_exist( + sagemaker_session, + estimator.model_dir, + ["graph.pbtxt", "model.ckpt-0.index", "model.ckpt-0.meta"], + ) + + +def test_tuning_tf_script_mode_efs(efs_fsx_setup, sagemaker_session): + role = efs_fsx_setup.role_name + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + + estimator = TensorFlow( + entry_point=SCRIPT, + role=role, + 
train_instance_count=1, + train_instance_type=INSTANCE_TYPE, + script_mode=True, + sagemaker_session=sagemaker_session, + py_version=PY_VERSION, + framework_version=TensorFlow.LATEST_VERSION, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + hyperparameter_ranges = {"epochs": IntegerParameter(1, 2)} + objective_metric_name = "accuracy" + metric_definitions = [{"Name": objective_metric_name, "Regex": "accuracy = ([0-9\\.]+)"}] + tuner = HyperparameterTuner( + estimator, + objective_metric_name, + hyperparameter_ranges, + metric_definitions, + max_jobs=MAX_JOBS, + max_parallel_jobs=MAX_PARALLEL_JOBS, + ) + + file_system_efs_id = efs_fsx_setup.file_system_efs_id + file_system_input = FileSystemInput( + file_system_id=file_system_efs_id, file_system_type="EFS", directory_path=EFS_DIR_PATH + ) + + with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES): + tuning_job_name = unique_name_from_base("test-tuning-tf-script-mode-efs", max_length=32) + tuner.fit(file_system_input, job_name=tuning_job_name) + time.sleep(15) + tuner.wait() + best_training_job = tuner.best_training_job() + assert best_training_job + + +def test_tuning_tf_script_mode_lustre(efs_fsx_setup, sagemaker_session): + role = efs_fsx_setup.role_name + subnets = [efs_fsx_setup.subnet_id] + security_group_ids = efs_fsx_setup.security_group_ids + + estimator = TensorFlow( + entry_point=SCRIPT, + role=role, + train_instance_count=1, + train_instance_type=INSTANCE_TYPE, + script_mode=True, + sagemaker_session=sagemaker_session, + py_version=PY_VERSION, + framework_version=TensorFlow.LATEST_VERSION, + subnets=subnets, + security_group_ids=security_group_ids, + ) + + hyperparameter_ranges = {"epochs": IntegerParameter(1, 2)} + objective_metric_name = "accuracy" + metric_definitions = [{"Name": objective_metric_name, "Regex": "accuracy = ([0-9\\.]+)"}] + tuner = HyperparameterTuner( + estimator, + objective_metric_name, + hyperparameter_ranges, + metric_definitions, + max_jobs=MAX_JOBS, + max_parallel_jobs=MAX_PARALLEL_JOBS, + ) + + file_system_fsx_id = efs_fsx_setup.file_system_fsx_id + file_system_input = FileSystemInput( + file_system_id=file_system_fsx_id, file_system_type="FSxLustre", directory_path=FSX_DIR_PATH + ) + + with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES): + tuning_job_name = unique_name_from_base("test-tuning-tf-script-mode-lustre", max_length=32) + tuner.fit(file_system_input, job_name=tuning_job_name) + time.sleep(15) + tuner.wait() + best_training_job = tuner.best_training_job() + assert best_training_job diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index fcb9ea516e..3f20dd8a26 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -18,13 +18,12 @@ import pytest -import boto3 from sagemaker.tensorflow import TensorFlow -from six.moves.urllib.parse import urlparse from sagemaker.utils import unique_name_from_base import tests.integ from tests.integ import timeout +from tests.integ.s3_utils import assert_s3_files_exist ROLE = "SageMakerRole" @@ -56,10 +55,10 @@ def test_mnist(sagemaker_session, instance_type): with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-mnist")) - _assert_s3_files_exist( + assert_s3_files_exist( + sagemaker_session, estimator.model_dir, ["graph.pbtxt", "model.ckpt-0.index", "model.ckpt-0.meta"], - sagemaker_session.boto_region_name, ) df = estimator.training_job_analytics.dataframe() assert 
df.size > 0 @@ -119,10 +118,10 @@ def test_mnist_distributed(sagemaker_session, instance_type): with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed")) - _assert_s3_files_exist( + assert_s3_files_exist( + sagemaker_session, estimator.model_dir, ["graph.pbtxt", "model.ckpt-0.index", "model.ckpt-0.meta"], - sagemaker_session.boto_region_name, ) @@ -200,18 +199,6 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type): assert expected_result == result -def _assert_s3_files_exist(s3_url, files, region): - parsed_url = urlparse(s3_url) - s3 = boto3.client("s3", region_name=region) - contents = s3.list_objects_v2(Bucket=parsed_url.netloc, Prefix=parsed_url.path.lstrip("/"))[ - "Contents" - ] - for f in files: - found = [x["Key"] for x in contents if x["Key"].endswith(f)] - if not found: - raise ValueError("File {} is not found under {}".format(f, s3_url)) - - def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=15): actual_tags = None for _ in range(retries): diff --git a/tests/integ/vpc_test_utils.py b/tests/integ/vpc_test_utils.py index 717d2a9a52..ec3f01a51e 100644 --- a/tests/integ/vpc_test_utils.py +++ b/tests/integ/vpc_test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -37,17 +37,91 @@ def _get_security_id_by_name(ec2_client, name): return desc["SecurityGroups"][0]["GroupId"] +def _security_group_ids_by_vpc_id(sagemaker_session, vpc_id): + ec2_resource = sagemaker_session.boto_session.resource("ec2") + security_group_ids = [] + vpc = ec2_resource.Vpc(vpc_id) + for sg in vpc.security_groups.all(): + security_group_ids.append(sg.id) + return security_group_ids + + def _vpc_exists(ec2_client, name): desc = ec2_client.describe_vpcs(Filters=[{"Name": "tag-value", "Values": [name]}]) return len(desc["Vpcs"]) > 0 -def _get_route_table_id(ec2_client, vpc_id): +def _vpc_id_by_name(ec2_client, name): + desc = ec2_client.describe_vpcs(Filters=[{"Name": "tag-value", "Values": [name]}]) + vpc_id = desc["Vpcs"][0]["VpcId"] + return vpc_id + + +def _route_table_id(ec2_client, vpc_id): desc = ec2_client.describe_route_tables(Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]) return desc["RouteTables"][0]["RouteTableId"] -def _create_vpc_with_name(ec2_client, region, name): +def check_or_create_vpc_resources_efs_fsx(sagemaker_session, region, name=VPC_NAME): + # use lock to prevent race condition when tests are running concurrently + with lock.lock(LOCK_PATH): + ec2_client = sagemaker_session.boto_session.client("ec2") + + if _vpc_exists(ec2_client, name): + vpc_id = _vpc_id_by_name(ec2_client, name) + return ( + _get_subnet_ids_by_name(ec2_client, name), + _security_group_ids_by_vpc_id(sagemaker_session, vpc_id), + ) + else: + return _create_vpc_with_name_efs_fsx(ec2_client, region, name) + + +def _create_vpc_with_name_efs_fsx(ec2_client, region, name): + vpc_id, [subnet_id_a, subnet_id_b], security_group_id = _create_vpc_resources( + ec2_client, region, name + ) + ec2_client.modify_vpc_attribute(EnableDnsHostnames={"Value": True}, VpcId=vpc_id) + + ig = ec2_client.create_internet_gateway() + internet_gateway_id = ig["InternetGateway"]["InternetGatewayId"] + 
ec2_client.attach_internet_gateway(InternetGatewayId=internet_gateway_id, VpcId=vpc_id) + + route_table_id = _route_table_id(ec2_client, vpc_id) + ec2_client.create_route( + DestinationCidrBlock="0.0.0.0/0", GatewayId=internet_gateway_id, RouteTableId=route_table_id + ) + ec2_client.associate_route_table(RouteTableId=route_table_id, SubnetId=subnet_id_a) + ec2_client.associate_route_table(RouteTableId=route_table_id, SubnetId=subnet_id_b) + + ec2_client.authorize_security_group_ingress( + GroupId=security_group_id, + IpPermissions=[ + { + "IpProtocol": "tcp", + "FromPort": 988, + "ToPort": 988, + "UserIdGroupPairs": [{"GroupId": security_group_id}], + }, + { + "IpProtocol": "tcp", + "FromPort": 2049, + "ToPort": 2049, + "UserIdGroupPairs": [{"GroupId": security_group_id}], + }, + { + "IpProtocol": "tcp", + "FromPort": 22, + "ToPort": 22, + "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "For SSH to EC2"}], + }, + ], + ) + + return [subnet_id_a], [security_group_id] + + +def _create_vpc_resources(ec2_client, region, name): vpc_id = ec2_client.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"] print("created vpc: {}".format(vpc_id)) @@ -68,9 +142,7 @@ def _create_vpc_with_name(ec2_client, region, name): s for s in ec2_client.describe_vpc_endpoint_services()["ServiceNames"] if s.endswith("s3") ][0] ec2_client.create_vpc_endpoint( - VpcId=vpc_id, - ServiceName=s3_service, - RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)], + VpcId=vpc_id, ServiceName=s3_service, RouteTableIds=[_route_table_id(ec2_client, vpc_id)] ) print("created s3 vpc endpoint") @@ -97,6 +169,13 @@ def _create_vpc_with_name(ec2_client, region, name): Tags=[{"Key": "Name", "Value": name}], ) + return vpc_id, [subnet_id_a, subnet_id_b], security_group_id + + +def _create_vpc_with_name(ec2_client, region, name): + vpc_id, [subnet_id_a, subnet_id_b], security_group_id = _create_vpc_resources( + ec2_client, region, name + ) return [subnet_id_a, subnet_id_b], security_group_id diff --git a/tests/scripts/fs_mount_setup.sh b/tests/scripts/fs_mount_setup.sh new file mode 100644 index 0000000000..a5e5eaa051 --- /dev/null +++ b/tests/scripts/fs_mount_setup.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+#
+# Mounting EFS and FSx for Lustre file systems for integration tests
+FILE_SYSTEM_EFS_ID=$1
+FILE_SYSTEM_FSX_ID=$2
+
+echo "Mounting EFS File Systems"
+sudo yum install -y amazon-efs-utils-1.10-1.amzn2.noarch  # EFS mount helper (provides "mount -t efs")
+sudo mkdir efs
+sudo mount -t efs "$FILE_SYSTEM_EFS_ID":/ efs
+sudo mkdir efs/tensorflow
+sudo mkdir efs/one_p_mnist
+
+echo "Mounting FSx for Lustre File System"
+sudo amazon-linux-extras install -y lustre2.10  # Lustre client needed to mount FSx
+sudo mkdir -p /mnt/fsx
+sudo mount -t lustre -o noatime,flock "$FILE_SYSTEM_FSX_ID".fsx.us-west-2.amazonaws.com@tcp:/fsx /mnt/fsx
+sudo mkdir /mnt/fsx/tensorflow
+sudo mkdir /mnt/fsx/one_p_mnist
+
+echo "Copying files to the mounted folders"
+sudo cp temp_tf/* efs/tensorflow
+sudo cp temp_tf/* /mnt/fsx/tensorflow
+sudo cp temp_one_p/* efs/one_p_mnist/
+sudo cp temp_one_p/* /mnt/fsx/one_p_mnist
diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py
index 19291fc7b9..67c65a776a 100644
--- a/tests/unit/test_amazon_estimator.py
+++ b/tests/unit/test_amazon_estimator.py
@@ -1,4 +1,4 @@
-# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
@@ -23,6 +23,7 @@
     _build_shards,
     registry,
     get_image_uri,
+    FileSystemRecordSet,
 )
 
 COMMON_ARGS = {"role": "myrole", "train_instance_count": 1, "train_instance_type": "ml.c4.xlarge"}
@@ -265,6 +266,154 @@ def make_all_put_calls(**kwargs):
     mock_put.assert_has_calls(make_all_put_calls(ServerSideEncryption="AES256"))
 
+
+def test_file_system_record_set_efs_default_parameters():
+    file_system_id = "fs-0a48d2a1"
+    file_system_type = "EFS"
+    directory_path = "ipinsights"
+    num_records = 1
+    feature_dim = 1
+
+    actual = FileSystemRecordSet(
+        file_system_id=file_system_id,
+        file_system_type=file_system_type,
+        directory_path=directory_path,
+        num_records=num_records,
+        feature_dim=feature_dim,
+    )
+
+    expected_input_config = {
+        "DataSource": {
+            "FileSystemDataSource": {
+                "DirectoryPath": "ipinsights",
+                "FileSystemId": "fs-0a48d2a1",
+                "FileSystemType": "EFS",
+                "FileSystemAccessMode": "ro",
+            }
+        }
+    }
+    assert actual.file_system_input.config == expected_input_config
+    assert actual.num_records == num_records
+    assert actual.feature_dim == feature_dim
+    assert actual.channel == "train"
+
+
+def test_file_system_record_set_efs_customized_parameters():
+    file_system_id = "fs-0a48d2a1"
+    file_system_type = "EFS"
+    directory_path = "ipinsights"
+    num_records = 1
+    feature_dim = 1
+
+    actual = FileSystemRecordSet(
+        file_system_id=file_system_id,
+        file_system_type=file_system_type,
+        directory_path=directory_path,
+        num_records=num_records,
+        feature_dim=feature_dim,
+        file_system_access_mode="rw",
+        channel="test",
+    )
+
+    expected_input_config = {
+        "DataSource": {
+            "FileSystemDataSource": {
+                "DirectoryPath": "ipinsights",
+                "FileSystemId": "fs-0a48d2a1",
+                "FileSystemType": "EFS",
+                "FileSystemAccessMode": "rw",
+            }
+        }
+    }
+    assert actual.file_system_input.config == expected_input_config
+    assert actual.num_records == num_records
+    assert actual.feature_dim == feature_dim
+    assert actual.channel == "test"
+
+
+def test_file_system_record_set_fsx_default_parameters():
+    file_system_id = "fs-0a48d2a1"
+    file_system_type = "FSxLustre"
+    directory_path = "ipinsights"
+    num_records = 1
+    feature_dim = 1
+
+    actual = FileSystemRecordSet(
+        file_system_id=file_system_id,
+
file_system_type=file_system_type, + directory_path=directory_path, + num_records=num_records, + feature_dim=feature_dim, + ) + expected_input_config = { + "DataSource": { + "FileSystemDataSource": { + "DirectoryPath": "ipinsights", + "FileSystemId": "fs-0a48d2a1", + "FileSystemType": "FSxLustre", + "FileSystemAccessMode": "ro", + } + } + } + assert actual.file_system_input.config == expected_input_config + assert actual.num_records == num_records + assert actual.feature_dim == feature_dim + assert actual.channel == "train" + + +def test_file_system_record_set_fsx_customized_parameters(): + file_system_id = "fs-0a48d2a1" + file_system_type = "FSxLustre" + directory_path = "ipinsights" + num_records = 1 + feature_dim = 1 + + actual = FileSystemRecordSet( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + num_records=num_records, + feature_dim=feature_dim, + file_system_access_mode="rw", + channel="test", + ) + + expected_input_config = { + "DataSource": { + "FileSystemDataSource": { + "DirectoryPath": "ipinsights", + "FileSystemId": "fs-0a48d2a1", + "FileSystemType": "FSxLustre", + "FileSystemAccessMode": "rw", + } + } + } + assert actual.file_system_input.config == expected_input_config + assert actual.num_records == num_records + assert actual.feature_dim == feature_dim + assert actual.channel == "test" + + +def test_file_system_record_set_data_channel(): + file_system_id = "fs-0a48d2a1" + file_system_type = "EFS" + directory_path = "ipinsights" + num_records = 1 + feature_dim = 1 + record_set = FileSystemRecordSet( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + num_records=num_records, + feature_dim=feature_dim, + ) + + file_system_input = Mock() + record_set.file_system_input = file_system_input + actual = record_set.data_channel() + expected = {"train": file_system_input} + assert actual == expected + + def test_get_xgboost_image_uri(): legacy_xgb_image_uri = get_image_uri(REGION, "xgboost") assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1" diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py new file mode 100644 index 0000000000..954e28bd71 --- /dev/null +++ b/tests/unit/test_inputs.py @@ -0,0 +1,141 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest + +from sagemaker import s3_input +from sagemaker.inputs import FileSystemInput + + +def test_s3_input_all_defaults(): + prefix = "pre" + actual = s3_input(s3_data=prefix) + expected = { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": prefix, + } + } + } + assert actual.config == expected + + +def test_s3_input_all_arguments(): + prefix = "pre" + distribution = "FullyReplicated" + compression = "Gzip" + content_type = "text/csv" + record_wrapping = "RecordIO" + s3_data_type = "Manifestfile" + input_mode = "Pipe" + result = s3_input( + s3_data=prefix, + distribution=distribution, + compression=compression, + input_mode=input_mode, + content_type=content_type, + record_wrapping=record_wrapping, + s3_data_type=s3_data_type, + ) + expected = { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": distribution, + "S3DataType": s3_data_type, + "S3Uri": prefix, + } + }, + "CompressionType": compression, + "ContentType": content_type, + "RecordWrapperType": record_wrapping, + "InputMode": input_mode, + } + + assert result.config == expected + + +def test_file_system_input_default_access_mode(): + file_system_id = "fs-0a48d2a1" + file_system_type = "EFS" + directory_path = "tensorflow" + actual = FileSystemInput( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + ) + expected = { + "DataSource": { + "FileSystemDataSource": { + "FileSystemId": file_system_id, + "FileSystemType": file_system_type, + "DirectoryPath": directory_path, + "FileSystemAccessMode": "ro", + } + } + } + assert actual.config == expected + + +def test_file_system_input_all_arguments(): + file_system_id = "fs-0a48d2a1" + file_system_type = "FSxLustre" + directory_path = "tensorflow" + file_system_access_mode = "rw" + actual = FileSystemInput( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + file_system_access_mode=file_system_access_mode, + ) + expected = { + "DataSource": { + "FileSystemDataSource": { + "FileSystemId": file_system_id, + "FileSystemType": file_system_type, + "DirectoryPath": directory_path, + "FileSystemAccessMode": "rw", + } + } + } + assert actual.config == expected + + +def test_file_system_input_type_invalid(): + with pytest.raises(ValueError) as excinfo: + file_system_id = "fs-0a48d2a1" + file_system_type = "ABC" + directory_path = "tensorflow" + FileSystemInput( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + ) + assert str(excinfo.value) == "Unrecognized file system type: ABC. Valid values: FSxLustre, EFS." + + +def test_file_system_input_mode_invalid(): + with pytest.raises(ValueError) as excinfo: + file_system_id = "fs-0a48d2a1" + file_system_type = "EFS" + directory_path = "tensorflow" + file_system_access_mode = "p" + FileSystemInput( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + file_system_access_mode=file_system_access_mode, + ) + assert str(excinfo.value) == "Unrecognized file system access mode: p. Valid values: ro, rw." diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 0eacd160f0..d3bd73674b 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -16,8 +16,9 @@ import os from mock import Mock -from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.amazon.amazon_estimator import RecordSet, FileSystemRecordSet from sagemaker.estimator import Estimator, Framework +from sagemaker.inputs import FileSystemInput from sagemaker.job import _Job from sagemaker.model import FrameworkModel from sagemaker.session import s3_input @@ -255,6 +256,26 @@ def test_format_inputs_to_input_config_record_set(): assert channels[0]["DataSource"]["S3DataSource"]["S3DataType"] == inputs.s3_data_type +def test_format_inputs_to_input_config_file_system_record_set(): + file_system_id = "fs-0a48d2a1" + file_system_type = "EFS" + directory_path = "ipinsights" + num_records = 1 + feature_dim = 1 + records = FileSystemRecordSet( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + num_records=num_records, + feature_dim=feature_dim, + ) + channels = _Job._format_inputs_to_input_config(records) + assert channels[0]["DataSource"]["FileSystemDataSource"]["DirectoryPath"] == directory_path + assert channels[0]["DataSource"]["FileSystemDataSource"]["FileSystemId"] == file_system_id + assert channels[0]["DataSource"]["FileSystemDataSource"]["FileSystemType"] == file_system_type + assert channels[0]["DataSource"]["FileSystemDataSource"]["FileSystemAccessMode"] == "ro" + + def test_format_inputs_to_input_config_list(): records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1) inputs = [records] @@ -265,6 +286,28 @@ def test_format_inputs_to_input_config_list(): assert channels[0]["DataSource"]["S3DataSource"]["S3DataType"] == records.s3_data_type +def test_format_record_set_list_input(): + records = FileSystemRecordSet( + file_system_id="fs-fd85e556", + file_system_type="EFS", + directory_path="ipinsights", + num_records=100, + feature_dim=1, + ) + test_records = FileSystemRecordSet( + file_system_id="fs-fd85e556", + file_system_type="EFS", + directory_path="ipinsights", + num_records=20, + feature_dim=1, + channel="validation", + ) + inputs = [records, test_records] + input_dict = _Job._format_record_set_list_input(inputs) + assert isinstance(input_dict["train"], FileSystemInput) + assert isinstance(input_dict["validation"], FileSystemInput) + + @pytest.mark.parametrize( "channel_uri, channel_name, content_type, input_mode", [ @@ -328,7 +371,7 @@ def test_format_inputs_to_input_config_list_not_all_records(): with pytest.raises(ValueError) as ex: _Job._format_inputs_to_input_config(inputs) - assert "List compatible only with RecordSets." in str(ex) + assert "List compatible only with RecordSets or FileSystemRecordSets." in str(ex) def test_format_inputs_to_input_config_list_duplicate_channel(): @@ -465,6 +508,21 @@ def test_format_string_uri_input_string(): assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs +def test_format_string_uri_file_system_input(): + file_system_id = "fs-fd85e556" + file_system_type = "EFS" + directory_path = "ipinsights" + + file_system_input = FileSystemInput( + file_system_id=file_system_id, + file_system_type=file_system_type, + directory_path=directory_path, + ) + + uri_input = _Job._format_string_uri_input(file_system_input) + assert uri_input == file_system_input + + def test_format_string_uri_input_string_exception(): inputs = "mybucket/train"