2121import threading
2222
2323from sagemaker .estimator import Framework
24- from sagemaker .fw_utils import create_image_uri , framework_name_from_image , framework_version_from_tag
24+ from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag
2525from sagemaker .utils import get_config_value
2626
2727from sagemaker .tensorflow .defaults import TF_VERSION
@@ -157,7 +157,7 @@ class TensorFlow(Framework):
157157 __framework_name__ = 'tensorflow'
158158
159159 def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None , py_version = 'py2' ,
160- framework_version = TF_VERSION , requirements_file = '' , ** kwargs ):
160+ framework_version = TF_VERSION , requirements_file = '' , image_name = None , ** kwargs ):
161161 """Initialize an ``TensorFlow`` estimator.
162162 Args:
163163 training_steps (int): Perform this many steps of training. `None`, the default means train forever.
@@ -171,9 +171,11 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
171171 requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
172172 relative to ``source_dir``. Details on the format can be found in the
173173 `Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
174+ image_name (str): The container image to use for training. This will override py_version and
175+ framework_version. The image is expected to be a modification of the SageMaker TensorFlow image.
174176 **kwargs: Additional kwargs passed to the Framework constructor.
175177 """
176- super (TensorFlow , self ).__init__ (** kwargs )
178+ super (TensorFlow , self ).__init__ (image_name = image_name , ** kwargs )
177179 self .checkpoint_path = checkpoint_path
178180 self .py_version = py_version
179181 self .framework_version = framework_version
@@ -257,7 +259,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
257259 if value is not None :
258260 init_params [argument ] = value
259261
260- framework , py_version , tag = framework_name_from_image (init_params .pop ('image' ))
262+ image_name = init_params .pop ('image' )
263+ framework , py_version , tag = framework_name_from_image (image_name )
264+ if not framework :
265+ # If we were unable to parse the framework name from the image it is not one of our
266+ # officially supported images, in this case just add the image to the init params.
267+ init_params ['image_name' ] = image_name
268+ return init_params
269+
261270 init_params ['py_version' ] = py_version
262271
263272 # We switched image tagging scheme from regular image version (e.g. '1.0') to more expressive
@@ -272,18 +281,6 @@ def _prepare_init_params_from_job_description(cls, job_details):
272281
273282 return init_params
274283
275- def train_image (self ):
276- """Return the Docker image to use for training.
277-
278- The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
279- find the image to use for model training.
280-
281- Returns:
282- str: The URI of the Docker image.
283- """
284- return create_image_uri (self .sagemaker_session .boto_region_name , self .__framework_name__ ,
285- self .train_instance_type , self .framework_version , py_version = self .py_version )
286-
287284 def create_model (self , model_server_workers = None ):
288285 """Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
289286
@@ -296,9 +293,9 @@ def create_model(self, model_server_workers=None):
296293 See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
297294 """
298295 env = {'SAGEMAKER_REQUIREMENTS' : self .requirements_file }
299- return TensorFlowModel (self .model_data , self .role , self .entry_point , source_dir = self .source_dir ,
300- enable_cloudwatch_metrics = self .enable_cloudwatch_metrics , env = env ,
301- name = self ._current_job_name , container_log_level = self .container_log_level ,
296+ return TensorFlowModel (self .model_data , self .role , self .entry_point , image = self .image_name ,
297+ source_dir = self .source_dir , enable_cloudwatch_metrics = self . enable_cloudwatch_metrics ,
298+ env = env , name = self ._current_job_name , container_log_level = self .container_log_level ,
302299 code_location = self .code_location , py_version = self .py_version ,
303300 framework_version = self .framework_version , model_server_workers = model_server_workers ,
304301 sagemaker_session = self .sagemaker_session )
0 commit comments