diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py
index 38e2e1e07c..78d6788a61 100644
--- a/src/sagemaker/algorithm.py
+++ b/src/sagemaker/algorithm.py
@@ -553,3 +553,36 @@ def _algorithm_training_input_modes(self, training_channels):
             current_input_modes = current_input_modes & supported_input_modes
 
         return current_input_modes
+
+    @classmethod
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
+        """Convert the job description to init params for the class constructor.
+
+        Delegates to the parent implementation, then drops the hyperparameter
+        that Amazon SageMaker Automatic Model Tuning adds to its training jobs.
+
+        Args:
+            job_details (dict): the returned job details from a DescribeTrainingJob
+                API call.
+            model_channel_name (str): Name of the channel where pre-trained
+                model data will be downloaded.
+
+        Returns:
+            dict: The transformed init_params
+        """
+        init_params = super(AlgorithmEstimator, cls)._prepare_init_params_from_job_description(
+            job_details, model_channel_name
+        )
+
+        # This hyperparameter is added by Amazon SageMaker Automatic Model Tuning.
+        # It cannot be set through instantiating an estimator.
+        # Build a filtered copy rather than deleting in place: the parent may hand
+        # back the very dict held by ``job_details`` (not visible from here), and
+        # the caller's job description should not change as a side effect.
+        init_params["hyperparameters"] = {
+            name: value
+            for name, value in init_params["hyperparameters"].items()
+            if name != "_tuning_objective_metric"
+        }
+
+        return init_params
diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py
index f77e4c16e8..e2d1198bd7 100644
--- a/tests/unit/test_algorithm.py
+++ b/tests/unit/test_algorithm.py
@@ -943,3 +943,76 @@ def test_algorithm_no_required_hyperparameters(session):
         train_instance_count=1,
         sagemaker_session=session,
     )
+
+
+def test_algorithm_attach_from_hyperparameter_tuning():
+    session = Mock()
+    job_name = "training-job-that-is-part-of-a-tuning-job"
+    algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees"
+    role_arn = "arn:aws:iam::123412341234:role/SageMakerRole"
+    instance_count = 1
+    instance_type = "ml.m4.xlarge"
+    train_volume_size = 30
+    input_mode = "File"
+
+    session.sagemaker_client.list_tags.return_value = {"Tags": []}
+    session.sagemaker_client.describe_algorithm.return_value = DESCRIBE_ALGORITHM_RESPONSE
+    session.sagemaker_client.describe_training_job.return_value = {
+        "TrainingJobName": job_name,
+        "TrainingJobArn": "arn:aws:sagemaker:us-east-2:123412341234:training-job/%s" % job_name,
+        "TuningJobArn": "arn:aws:sagemaker:us-east-2:123412341234:hyper-parameter-tuning-job/%s"
+        % job_name,
+        "ModelArtifacts": {
+            "S3ModelArtifacts": "s3://sagemaker-us-east-2-123412341234/output/model.tar.gz"
+        },
+        "TrainingJobOutput": {
+            "S3TrainingJobOutput": "s3://sagemaker-us-east-2-123412341234/output/output.tar.gz"
+        },
+        "TrainingJobStatus": "Succeeded",
+        "HyperParameters": {
+            "_tuning_objective_metric": "validation:accuracy",
+            "max_leaf_nodes": 1,
+            "free_text_hp1": "foo",
+        },
+        "AlgorithmSpecification": {"AlgorithmName": algo_arn, "TrainingInputMode": input_mode},
+        "MetricDefinitions": [
+            {"Name": "validation:accuracy", "Regex": "validation-accuracy: (\\S+)"}
+        ],
+        "RoleArn": role_arn,
+        "InputDataConfig": [
+            {
+                "ChannelName": "training",
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3DataType": "S3Prefix",
+                        "S3Uri": "s3://sagemaker-us-east-2-123412341234/input/training.csv",
+                        "S3DataDistributionType": "FullyReplicated",
+                    }
+                },
+                "CompressionType": "None",
+                "RecordWrapperType": "None",
+            }
+        ],
+        "OutputDataConfig": {
+            "KmsKeyId": "",
+            "S3OutputPath": "s3://sagemaker-us-east-2-123412341234/output",
+            "RemoveJobNameFromS3OutputPath": False,
+        },
+        "ResourceConfig": {
+            "InstanceType": instance_type,
+            "InstanceCount": instance_count,
+            "VolumeSizeInGB": train_volume_size,
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
+    }
+
+    estimator = AlgorithmEstimator.attach(job_name, sagemaker_session=session)
+    # The tuning-only "_tuning_objective_metric" entry must not reach the estimator.
+    assert estimator.hyperparameters() == {"max_leaf_nodes": 1, "free_text_hp1": "foo"}
+    assert estimator.algorithm_arn == algo_arn
+    assert estimator.role == role_arn
+    assert estimator.train_instance_count == instance_count
+    assert estimator.train_instance_type == instance_type
+    assert estimator.train_volume_size == train_volume_size
+    assert estimator.input_mode == input_mode
+    assert estimator.sagemaker_session == session