From 4fea25016156fad33be2f4da46fd45368dde9f14 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Fri, 23 Aug 2019 13:45:05 -0700 Subject: [PATCH 1/2] fix: add logic to use asimov image for TF 1.14 py2 --- doc/using_tf.rst | 3 --- src/sagemaker/fw_utils.py | 17 ++++++++++++++++- src/sagemaker/tensorflow/estimator.py | 4 +++- tests/unit/test_fw_utils.py | 7 +++++++ tests/unit/test_tf_estimator.py | 10 ++-------- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/doc/using_tf.rst b/doc/using_tf.rst index 98ce5d1810..db5d6ae141 100644 --- a/doc/using_tf.rst +++ b/doc/using_tf.rst @@ -8,9 +8,6 @@ models on SageMaker Hosting. For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`. -.. warning:: - The TensorFlow estimator is available only for Python 3, starting by the TensorFlow version 1.14. - .. warning:: We have added a new format of your TensorFlow training script with TensorFlow version 1.11. This new way gives the user script more flexibility. diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 07b5ba312e..a2a75875e2 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -113,7 +113,22 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew is_gov_region = region in VALID_ACCOUNTS_BY_REGION is_py3 = py_version == "py3" or py_version is None is_merged_versions = _is_merged_versions(framework, framework_version) - return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None + return ( + (not is_gov_region) + and is_merged_versions + and (is_py3 or _is_tf_14(framework, framework_version)) + and accelerator_type is None + ) + + +def _is_tf_14(framework, framework_version): + """ + Args: + framework: + framework_version: + """ + # Asimov team now owns Tensorflow 1.14.0 py2 and py3 + return framework == "tensorflow-scriptmode" and framework_version in ("1.14", "1.14.0") def _registry_id(region, framework, py_version, account, accelerator_type, framework_version): diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 07c1e682fe..553abb9bd0 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -199,6 +199,8 @@ class TensorFlow(Framework): """The latest version of TensorFlow included in the SageMaker pre-built Docker images.""" _LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13] + # 1.14.0 now supports py2 + # we will need to update this version number if future versions do not support py2 anymore _LOWEST_PYTHON_2_ONLY_VERSION = [1, 14] def __init__( @@ -343,7 +345,7 @@ def _validate_args( if py_version == "py2" and self._only_python_3_supported(): msg = ( - "Python 2 containers are only available until TensorFlow version 1.13.1. " + "Python 2 containers are only available until TensorFlow version 1.14.0. " "Please use a Python 3 container." ) raise AttributeError(msg) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 4e31e5821b..359e38866f 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -187,6 +187,13 @@ def test_create_image_uri_merged_py2(): == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2" ) + image_uri = fw_utils.create_image_uri( + "us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py2" + ) + assert ( + image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.14-gpu-py2" + ) + image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2") assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2" diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 90d613aced..85558d13e1 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -957,18 +957,12 @@ def test_script_mode_deprecated_args(sagemaker_session): def test_py2_version_deprecated(sagemaker_session): with pytest.raises(AttributeError) as e: - _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14", py_version="py2") + _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14.1", py_version="py2") - msg = "Python 2 containers are only available until TensorFlow version 1.13.1. Please use a Python 3 container." + msg = "Python 2 containers are only available until TensorFlow version 1.14.0. Please use a Python 3 container." assert msg in str(e.value) -def test_py3_is_default_version_after_tf1_14(sagemaker_session): - estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14") - - assert estimator.py_version == "py3" - - def test_py3_is_default_version_before_tf1_14(sagemaker_session): estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.13") From 60ef1b3699075b9dd18601fa704c880a09cf9d8a Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Fri, 23 Aug 2019 14:32:06 -0700 Subject: [PATCH 2/2] change logic to tf 1.14 or later --- src/sagemaker/fw_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index a2a75875e2..416127f67c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -116,19 +116,23 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew return ( (not is_gov_region) and is_merged_versions - and (is_py3 or _is_tf_14(framework, framework_version)) + and (is_py3 or _is_tf_14_or_later(framework, framework_version)) and accelerator_type is None ) -def _is_tf_14(framework, framework_version): +def _is_tf_14_or_later(framework, framework_version): """ Args: framework: framework_version: """ # Asimov team now owns Tensorflow 1.14.0 py2 and py3 - return framework == "tensorflow-scriptmode" and framework_version in ("1.14", "1.14.0") + asimov_lowest_tf_py2 = [1, 14, 0] + version = [int(s) for s in framework_version.split(".")] + return ( + framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2[0 : len(version)] + ) def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):