@@ -831,16 +831,10 @@ def _create_sagemaker_model(
831831 # _base_name, model_name are not needed under PipelineSession.
832832 # the model_data may be Pipeline variable
833833 # which may break the _base_name generation
834- model_uri = None
835- if isinstance (self .model_data , (str , PipelineVariable )):
836- model_uri = self .model_data
837- elif isinstance (self .model_data , dict ):
838- model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
839-
840834 self ._ensure_base_name_if_needed (
841835 image_uri = container_def ["Image" ],
842836 script_uri = self .source_dir ,
843- model_uri = model_uri ,
837+ model_uri = self . _get_model_uri () ,
844838 )
845839 self ._set_model_name_if_needed ()
846840
@@ -877,6 +871,14 @@ def _create_sagemaker_model(
877871 )
878872 self .sagemaker_session .create_model (** create_model_args )
879873
874+ def _get_model_uri (self ):
875+ model_uri = None
876+ if isinstance (self .model_data , (str , PipelineVariable )):
877+ model_uri = self .model_data
878+ elif isinstance (self .model_data , dict ):
879+ model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
880+ return model_uri
881+
880882 def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
881883 """Create a base name from the image URI if there is no model name provided.
882884
@@ -1434,7 +1436,7 @@ def deploy(
14341436 self ._ensure_base_name_if_needed (
14351437 image_uri = self .image_uri ,
14361438 script_uri = self .source_dir ,
1437- model_uri = self .model_data ,
1439+ model_uri = self ._get_model_uri () ,
14381440 )
14391441 if self ._base_name is not None :
14401442 self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
0 commit comments