From f1b14010c6e6ea53c67c2c775c97611e4038c61c Mon Sep 17 00:00:00 2001 From: chiragvp-aws Date: Mon, 6 Oct 2025 06:24:44 +0000 Subject: [PATCH 1/2] feature: Added condition to allow eval recipe. --- src/sagemaker/pytorch/estimator.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 9e2f0f0dd4..5e96584309 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -163,6 +163,23 @@ def _is_nova_recipe(recipe): return bool(has_nova_model) or bool(has_distillation) +def _is_eval_recipe(recipe): + """Check if the recipe is an eval recipe. + + An eval recipe is identified by: + 1. Having a evaluation section + + Args: + recipe (OmegaConf): The loaded recipe configuration + + Returns: + bool: True if the recipe is an eval recipe, False otherwise + """ + # Check for eval model + eval_config = recipe.get("evaluation", {}) + return bool(eval_config) + + def _recipe_initialize_args(source_dir): """Initialize the arguments dictionary for recipe setup. @@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe): if "instance_type" not in kwargs: raise ValueError("Must pass instance type to estimator when using training recipes.") - if not _is_nova_recipe(recipe) and "trainer" not in recipe: + if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe): raise ValueError("Supplied recipe does not contain required field trainer.") instance_type = kwargs["instance_type"].split(".")[1] @@ -973,7 +990,7 @@ def _device_handle_instance_count(kwargs, recipe): """ # Check if instance_count is already provided in kwargs - is_nova = _is_nova_recipe(recipe) + is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe) if "instance_count" in kwargs: # Warn if there are conflicting configurations in the recipe if "num_nodes" in recipe.get("trainer", {}): @@ -981,7 +998,7 @@ def _device_handle_instance_count(kwargs, recipe): "Using instance_count argument to estimator to set number " "of nodes. Ignoring trainer -> num_nodes in recipe." ) - if is_nova and "replicas" in recipe.get("run", {}): + if is_nova_or_eval and "replicas" in recipe.get("run", {}): logger.warning( "Using instance_count argument to estimator to set number " "of nodes. Ignoring run -> replicas in recipe." @@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe): kwargs["instance_count"] = recipe["trainer"]["num_nodes"] return - if is_nova and "run" in recipe and "replicas" in recipe["run"]: + if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]: kwargs["instance_count"] = recipe["run"]["replicas"] return @@ -1137,7 +1154,7 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d # Merge with overrides recipe = OmegaConf.merge(recipe, recipe_overrides) - self.is_nova_recipe = _is_nova_recipe(recipe) + self.is_nova_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe) if self.is_nova_recipe: return self._setup_for_nova_recipe( recipe, From 28257a153de16547cf56abe3609b63428cf0a275 Mon Sep 17 00:00:00 2001 From: chiragvp-aws Date: Mon, 6 Oct 2025 19:30:07 +0000 Subject: [PATCH 2/2] change: renamed is_nova_recipe to is_nova_or_eval_recipe --- src/sagemaker/pytorch/estimator.py | 14 +++++++------- tests/unit/test_pytorch_nova.py | 22 +++++++++++----------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 5e96584309..ce8daae9d1 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -543,7 +543,7 @@ def __init__( :class:`~sagemaker.estimator.Framework` and :class:`~sagemaker.estimator.EstimatorBase`. """ - self.is_nova_recipe = False + self.is_nova_or_eval_recipe = False if training_recipe is not None: if entry_point is not None: logger.warning("Argument entry_point will be ignored with training_recipe.") @@ -555,7 +555,7 @@ def __init__( training_recipe, recipe_overrides, source_dir, kwargs ) - if self.is_nova_recipe and image_uri is None: + if self.is_nova_or_eval_recipe and image_uri is None: raise ValueError("Must supply image_uri for nova jobs.") entry_point = args["entry_point"] @@ -586,7 +586,7 @@ def __init__( source_dir, hyperparameters, image_uri=image_uri, - is_nova_job=self.is_nova_recipe, + is_nova_job=self.is_nova_or_eval_recipe, **kwargs, ) @@ -719,8 +719,8 @@ def fit( """ # Handle recipe upload and input channel creation if we have a recipe if ( - self.is_nova_recipe is not None - and self.is_nova_recipe + self.is_nova_or_eval_recipe is not None + and self.is_nova_or_eval_recipe and hasattr(self, "training_recipe_file") and self.training_recipe_file ): @@ -1154,8 +1154,8 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d # Merge with overrides recipe = OmegaConf.merge(recipe, recipe_overrides) - self.is_nova_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe) - if self.is_nova_recipe: + self.is_nova_or_eval_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe) + if self.is_nova_or_eval_recipe: return self._setup_for_nova_recipe( recipe, recipe_name, diff --git a/tests/unit/test_pytorch_nova.py b/tests/unit/test_pytorch_nova.py index b8604c2ef2..662d27e85f 100644 --- a/tests/unit/test_pytorch_nova.py +++ b/tests/unit/test_pytorch_nova.py @@ -138,7 +138,7 @@ def test_setup_for_nova_recipe_with_model_name(mock_resolve_save, sagemaker_sess ) # Check that the Nova recipe was correctly identified - assert pytorch.is_nova_recipe is True + assert pytorch.is_nova_or_eval_recipe is True # Verify _setup_for_nova_recipe was called mock_nova_setup.assert_called_once() @@ -194,7 +194,7 @@ def test_setup_for_nova_recipe_with_s3_path(mock_resolve_save, sagemaker_session ) # Check that the Nova recipe was correctly identified - assert pytorch.is_nova_recipe is True + assert pytorch.is_nova_or_eval_recipe is True # Verify _setup_for_nova_recipe was called mock_nova_setup.assert_called_once() @@ -326,7 +326,7 @@ def test_upload_recipe_to_s3(mock_time, mock_recipe_load, sagemaker_session): ) # Set Nova recipe attributes - pytorch.is_nova_recipe = True + pytorch.is_nova_or_eval_recipe = True # Create a temporary file to use as the recipe file with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: @@ -369,7 +369,7 @@ def test_recipe_resolve_and_save( ) # Set Nova recipe attributes - pytorch.is_nova_recipe = True + pytorch.is_nova_or_eval_recipe = True # Mock the temporary file mock_temp_file_instance = Mock() @@ -421,7 +421,7 @@ def test_fit_with_nova_recipe_s3_upload(mock_framework_fit, mock_recipe_load, sa ) # Set Nova recipe attributes - pytorch.is_nova_recipe = True + pytorch.is_nova_or_eval_recipe = True pytorch.training_recipe_file = temp_file # Mock the _upload_recipe_to_s3 method @@ -473,7 +473,7 @@ def test_fit_with_nova_recipe_and_inputs( ) # Set Nova recipe attributes - pytorch.is_nova_recipe = True + pytorch.is_nova_or_eval_recipe = True pytorch.training_recipe_file = temp_file # Create training inputs @@ -559,7 +559,7 @@ def test_fit_with_nova_recipe( ) # Set Nova recipe attributes - pytorch.is_nova_recipe = True + pytorch.is_nova_or_eval_recipe = True pytorch.training_recipe_file = temp_file # Mock the upload_recipe_to_s3 method @@ -642,7 +642,7 @@ def test_framework_set_hyperparameters_non_nova(): py_version="py3", image_uri=IMAGE_URI, ) - framework.is_nova_recipe = False + framework.is_nova_or_eval_recipe = False # Add hyperparameters framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True) @@ -719,7 +719,7 @@ def test_setup_for_nova_recipe_with_evaluation_lambda(mock_resolve_save, sagemak ) # Check that the Nova recipe was correctly identified - assert pytorch.is_nova_recipe is True + assert pytorch.is_nova_or_eval_recipe is True # Verify that eval_lambda_arn hyperparameter was set correctly assert ( @@ -780,7 +780,7 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se ) # Check that the Nova recipe was correctly identified - assert pytorch.is_nova_recipe is True + assert pytorch.is_nova_or_eval_recipe is True # Verify _setup_for_nova_recipe was called mock_nova_setup.assert_called_once() @@ -828,7 +828,7 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess ) # Check that the Nova recipe was correctly identified - assert pytorch.is_nova_recipe is True + assert pytorch.is_nova_or_eval_recipe is True # Verify that model_type hyperparameter was set correctly assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"