Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions doc/using_tf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ models on SageMaker Hosting.

For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`.

.. warning::
The TensorFlow estimator is available only for Python 3, starting by the TensorFlow version 1.14.

.. warning::
We have added a new format of your TensorFlow training script with TensorFlow version 1.11.
This new way gives the user script more flexibility.
Expand Down
21 changes: 20 additions & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,26 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
is_py3 = py_version == "py3" or py_version is None
is_merged_versions = _is_merged_versions(framework, framework_version)
return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None
return (
(not is_gov_region)
and is_merged_versions
and (is_py3 or _is_tf_14_or_later(framework, framework_version))
and accelerator_type is None
)


def _is_tf_14_or_later(framework, framework_version):
"""
Args:
framework:
framework_version:
"""
# Asimov team now owns Tensorflow 1.14.0 py2 and py3
asimov_lowest_tf_py2 = [1, 14, 0]
version = [int(s) for s in framework_version.split(".")]
return (
framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2[0 : len(version)]
)


def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ class TensorFlow(Framework):
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""

_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
# 1.14.0 now supports py2
# we will need to update this version number if future versions do not support py2 anymore
_LOWEST_PYTHON_2_ONLY_VERSION = [1, 14]

def __init__(
Expand Down Expand Up @@ -343,7 +345,7 @@ def _validate_args(

if py_version == "py2" and self._only_python_3_supported():
msg = (
"Python 2 containers are only available until TensorFlow version 1.13.1. "
"Python 2 containers are only available until TensorFlow version 1.14.0. "

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change the warning to say until Jan 1st 2020

"Please use a Python 3 container."
)
raise AttributeError(msg)
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def test_create_image_uri_merged_py2():
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2"
)

image_uri = fw_utils.create_image_uri(
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py2"
)
assert (
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.14-gpu-py2"
)

image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2")
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2"

Expand Down
10 changes: 2 additions & 8 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,18 +957,12 @@ def test_script_mode_deprecated_args(sagemaker_session):

def test_py2_version_deprecated(sagemaker_session):
with pytest.raises(AttributeError) as e:
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14", py_version="py2")
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14.1", py_version="py2")

msg = "Python 2 containers are only available until TensorFlow version 1.13.1. Please use a Python 3 container."
msg = "Python 2 containers are only available until TensorFlow version 1.14.0. Please use a Python 3 container."
assert msg in str(e.value)


def test_py3_is_default_version_after_tf1_14(sagemaker_session):
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14")

assert estimator.py_version == "py3"


def test_py3_is_default_version_before_tf1_14(sagemaker_session):
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.13")

Expand Down