diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py
index 8f74f0cd77..16f6bf07d3 100644
--- a/src/sagemaker/tuner.py
+++ b/src/sagemaker/tuner.py
@@ -19,7 +19,7 @@
 from enum import Enum
 
 import sagemaker
-from sagemaker.amazon.amazon_estimator import RecordSet
+from sagemaker.amazon.amazon_estimator import RecordSet, AmazonAlgorithmEstimatorBase
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.analytics import HyperparameterTuningJobAnalytics
 from sagemaker.estimator import Framework
@@ -358,7 +358,7 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
             estimator_cls, job_details["TrainingJobDefinition"]
         )
         estimator = cls._prepare_estimator_from_job_description(
-            estimator_cls, job_details["TrainingJobDefinition"], sagemaker_session
+            estimator_cls, job_details, sagemaker_session
         )
 
         init_params = cls._prepare_init_params_from_job_description(job_details)
@@ -497,9 +497,9 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):
         )
 
     @classmethod
-    def _prepare_estimator_from_job_description(
-        cls, estimator_cls, training_details, sagemaker_session
-    ):
+    def _prepare_estimator_from_job_description(cls, estimator_cls, job_details, sagemaker_session):
+        training_details = job_details["TrainingJobDefinition"]
+
         # Swap name for static hyperparameters to what an estimator would expect
         training_details["HyperParameters"] = training_details["StaticHyperParameters"]
         del training_details["StaticHyperParameters"]
@@ -507,6 +507,15 @@ def _prepare_estimator_from_job_description(
         # Remove hyperparameter reserved by SageMaker for tuning jobs
         del training_details["HyperParameters"]["_tuning_objective_metric"]
 
+        # Add missing hyperparameters defined in the hyperparameter ranges,
+        # as potentially required in the Amazon algorithm estimator's constructor
+        if issubclass(estimator_cls, AmazonAlgorithmEstimatorBase):
+            parameter_ranges = job_details["HyperParameterTuningJobConfig"]["ParameterRanges"]
+            additional_hyperparameters = cls._extract_hyperparameters_from_parameter_ranges(
+                parameter_ranges
+            )
+            training_details["HyperParameters"].update(additional_hyperparameters)
+
         # Add items expected by the estimator (but aren't needed otherwise)
         training_details["TrainingJobName"] = ""
         if "KmsKeyId" not in training_details["OutputDataConfig"]:
@@ -559,6 +568,21 @@ def _prepare_parameter_ranges(cls, parameter_ranges):
         return ranges
 
+    @classmethod
+    def _extract_hyperparameters_from_parameter_ranges(cls, parameter_ranges):
+        hyperparameters = {}
+
+        for parameter in parameter_ranges["CategoricalParameterRanges"]:
+            hyperparameters[parameter["Name"]] = parameter["Values"][0]
+
+        for parameter in parameter_ranges["ContinuousParameterRanges"]:
+            hyperparameters[parameter["Name"]] = float(parameter["MinValue"])
+
+        for parameter in parameter_ranges["IntegerParameterRanges"]:
+            hyperparameters[parameter["Name"]] = int(parameter["MinValue"])
+
+        return hyperparameters
+
     def hyperparameter_ranges(self):
         """Return the hyperparameter ranges in a dictionary to be used as part of a
         request for creating a hyperparameter tuning job.
         """
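As a quick reference for reviewers, here is a minimal standalone sketch of the mapping that the new _extract_hyperparameters_from_parameter_ranges helper performs, run against a made-up ParameterRanges payload in the shape returned by DescribeHyperParameterTuningJob. The parameter names and bounds below are illustrative, not taken from a real tuning job:

# Sketch only: mirrors the logic of the helper added above on an illustrative payload.
parameter_ranges = {
    "CategoricalParameterRanges": [{"Name": "mini_batch_size", "Values": ["32", "64"]}],
    "ContinuousParameterRanges": [{"Name": "alpha0", "MinValue": "1.0", "MaxValue": "10.0"}],
    "IntegerParameterRanges": [{"Name": "num_topics", "MinValue": "1", "MaxValue": "50"}],
}

hyperparameters = {}
for parameter in parameter_ranges["CategoricalParameterRanges"]:
    # Categorical ranges fall back to their first listed value
    hyperparameters[parameter["Name"]] = parameter["Values"][0]
for parameter in parameter_ranges["ContinuousParameterRanges"]:
    # Continuous ranges fall back to their lower bound
    hyperparameters[parameter["Name"]] = float(parameter["MinValue"])
for parameter in parameter_ranges["IntegerParameterRanges"]:
    # Integer ranges fall back to their lower bound
    hyperparameters[parameter["Name"]] = int(parameter["MinValue"])

print(hyperparameters)  # {'mini_batch_size': '32', 'alpha0': 1.0, 'num_topics': 1}

The point of these placeholder values is to satisfy hyperparameters that the Amazon algorithm estimator's constructor may require (such as num_topics for LDA) when the estimator is rebuilt from a describe call, since tuned hyperparameters are not present in StaticHyperParameters.
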
diff --git a/tests/integ/test_tuner.py b/tests/integ/test_tuner.py
index 1d74acbe06..b823484efd 100644
--- a/tests/integ/test_tuner.py
+++ b/tests/integ/test_tuner.py
@@ -460,12 +460,15 @@ def test_tuning_lda(sagemaker_session):
         time.sleep(15)
         tuner.wait()
 
-    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-        HyperParameterTuningJobName=latest_tuning_job_name
+    attached_tuner = HyperparameterTuner.attach(
+        tuning_job_name, sagemaker_session=sagemaker_session
     )
-    assert desc["HyperParameterTuningJobConfig"]["TrainingJobEarlyStoppingType"] == "Auto"
+    assert attached_tuner.early_stopping_type == "Auto"
+    assert attached_tuner.estimator.alpha0 == 1.0
+    assert attached_tuner.estimator.num_topics == 1
+
+    best_training_job = attached_tuner.best_training_job()
 
-    best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = tuner.deploy(1, "ml.c4.xlarge")
         predict_input = np.random.rand(1, feature_num)
diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py
index 73010c3afb..6accb3ba95 100644
--- a/tests/unit/test_tuner.py
+++ b/tests/unit/test_tuner.py
@@ -78,7 +78,7 @@
         "IntegerParameterRanges": [
             {
                 "MaxValue": "100",
-                "Name": "mini_batch_size",
+                "Name": "num_components",
                 "MinValue": "10",
                 "ScalingType": "Auto",
             }
@@ -416,7 +416,7 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.estimator.output_kms_key == ""
 
     assert "_tuning_objective_metric" not in tuner.estimator.hyperparameters()
-    assert tuner.estimator.hyperparameters()["num_components"] == "1"
+    assert tuner.estimator.hyperparameters()["num_components"] == "10"
 
 
 def test_attach_tuning_job_with_estimator_from_hyperparameters_with_early_stopping(
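For context, this is roughly how the attach flow exercised by the updated integration test looks from user code. The tuning job name is a placeholder, and the snippet assumes AWS credentials, a default region, and a completed tuning job that was created from an Amazon algorithm estimator such as LDA:

import sagemaker
from sagemaker.tuner import HyperparameterTuner

session = sagemaker.Session()

# "my-lda-tuning-job" is a placeholder name for a completed tuning job.
attached_tuner = HyperparameterTuner.attach("my-lda-tuning-job", sagemaker_session=session)

# With this change, hyperparameters that only appear in the tuning job's parameter
# ranges are backfilled onto the re-created Amazon algorithm estimator.
print(attached_tuner.early_stopping_type)
print(attached_tuner.estimator.hyperparameters())

best_training_job = attached_tuner.best_training_job()
print(best_training_job)
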