@@ -1439,6 +1439,24 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14391439 Instance of the calling ``Estimator`` Class with the attached
14401440 training job.
14411441 """
1442+ return cls ._attach (
1443+ training_job_name = training_job_name ,
1444+ sagemaker_session = sagemaker_session ,
1445+ model_channel_name = model_channel_name ,
1446+ )
1447+
1448+ @classmethod
1449+ def _attach (
1450+ cls ,
1451+ training_job_name : str ,
1452+ sagemaker_session : Optional [str ] = None ,
1453+ model_channel_name : str = "model" ,
1454+ additional_kwargs : Optional [Dict [str , Any ]] = None ,
1455+ ) -> "EstimatorBase" :
1456+ """Creates an Estimator bound to an existing training job.
1457+
1458+ Additional kwargs are allowed for instantiating Estimator.
1459+ """
14421460 sagemaker_session = sagemaker_session or Session ()
14431461
14441462 job_details = sagemaker_session .sagemaker_client .describe_training_job (
@@ -1450,6 +1468,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14501468 )["Tags" ]
14511469 init_params .update (tags = tags )
14521470
1471+ if additional_kwargs :
1472+ init_params .update (additional_kwargs )
1473+
14531474 estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
14541475 estimator .latest_training_job = _TrainingJob (
14551476 sagemaker_session = sagemaker_session , job_name = training_job_name
@@ -1751,21 +1772,41 @@ def register(
17511772
@property
def model_data(self):
    """Str or dict: The model location in S3. Only set if Estimator has been ``fit()``.

    When a latest training job exists and the session is not a
    ``PipelineSession``, the training job is described and the location is
    derived from its ``ModelArtifacts`` and ``OutputDataConfig`` entries.
    Otherwise a warning is logged and a path is built from ``output_path``
    and the current job name.

    Returns:
        str or dict: The ``S3ModelArtifacts`` URI when the job output
        compression type is ``"GZIP"`` (also the default when the field is
        absent); an ``{"S3DataSource": {...}}`` dict whose ``S3Uri`` ends
        with ``/`` when the compression type is ``"NONE"``; or the joined
        ``.../output/model.tar.gz`` path when there is no usable training job.

    Raises:
        ValueError: If the training job reports a compression type other
            than ``"GZIP"`` or ``"NONE"``.
    """
    if self.latest_training_job is not None and not isinstance(
        self.sagemaker_session, PipelineSession
    ):
        job_details = self.sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=self.latest_training_job.name
        )
        model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"]
        # A missing OutputDataConfig/CompressionType is treated as GZIP.
        compression_type = job_details.get("OutputDataConfig", {}).get(
            "CompressionType", "GZIP"
        )
        if compression_type == "GZIP":
            return model_uri
        # Fail fast if we don't recognize the training output compression type.
        # GZIP has already returned above, so only "NONE" remains valid here.
        if compression_type != "NONE":
            raise ValueError(
                f'Unrecognized training job output data compression type "{compression_type}"'
            )
        # Model data is in uncompressed form.
        # NOTE: SageMaker Hosting mandates presence of a trailing forward slash
        # in the S3 model data URI, so append one if necessary.
        if not model_uri.endswith("/"):
            model_uri += "/"
        return {
            "S3DataSource": {
                "S3Uri": model_uri,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
            }
        }

    logger.warning(
        "No finished training job found associated with this estimator. Please make sure "
        "this estimator is only used for building workflow config"
    )
    # NOTE(review): os.path.join uses the platform path separator; if output_path
    # is an S3 URI this presumably should always use "/" - confirm.
    model_uri = os.path.join(
        self.output_path, self._current_job_name, "output", "model.tar.gz"
    )
    return model_uri
17701811
17711812 @abstractmethod
0 commit comments