|
26 | 26 | from sagemaker.tensorflow.defaults import TF_VERSION |
27 | 27 | from sagemaker.tensorflow.model import TensorFlowModel |
28 | 28 | from sagemaker.tensorflow.serving import Model |
29 | | -from sagemaker.utils import get_config_value |
| 29 | +from sagemaker.utils import get_config_value, get_short_version |
30 | 30 | from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT |
31 | 31 |
|
32 | 32 | logger = logging.getLogger('sagemaker') |
@@ -171,9 +171,11 @@ class TensorFlow(Framework): |
171 | 171 |
|
172 | 172 | __framework_name__ = 'tensorflow' |
173 | 173 |
|
174 | | - LATEST_VERSION = '1.12' |
| 174 | + LATEST_VERSION = '1.13' |
175 | 175 | """The latest version of TensorFlow included in the SageMaker pre-built Docker images.""" |
176 | 176 |
|
| 177 | + _LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13] |
| 178 | + |
177 | 179 | def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2', |
178 | 180 | framework_version=None, model_dir=None, requirements_file='', image_name=None, |
179 | 181 | script_mode=False, distributions=None, **kwargs): |
@@ -276,6 +278,13 @@ def _validate_args(self, py_version, script_mode, framework_version, training_st |
276 | 278 | .format(', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(found_args)) |
277 | 279 | ) |
278 | 280 |
|
| 281 | + if (not self._script_mode_enabled()) and \ |
| 282 | + [int(s) for s in self.framework_version.split('.')] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION: |
| 283 | + raise AttributeError( |
| 284 | + 'Legacy mode is deprecated in versions 1.13 and higher.' |
| 285 | + 'Please set the script_mode argument to True to use Script Mode' |
| 286 | + ) |
| 287 | + |
279 | 288 | def _validate_requirements_file(self, requirements_file): |
280 | 289 | if not requirements_file: |
281 | 290 | return |
@@ -427,7 +436,7 @@ def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): |
427 | 436 | image=self.image_name, |
428 | 437 | name=self._current_job_name, |
429 | 438 | container_log_level=self.container_log_level, |
430 | | - framework_version=self.framework_version, |
| 439 | + framework_version=get_short_version(self.framework_version), |
431 | 440 | sagemaker_session=self.sagemaker_session, |
432 | 441 | vpc_config=self.get_vpc_config(vpc_config_override)) |
433 | 442 |
|
|
0 commit comments