diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index 6afbeb3f89..c7457f6fad 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -310,7 +310,7 @@ def _get_args_from_nova_recipe( processor = recipe.get("processor", {}) lambda_arn = processor.get("lambda_arn", "") if lambda_arn: - args["hyperparameters"]["lambda_arn"] = lambda_arn + args["hyperparameters"]["eval_lambda_arn"] = lambda_arn _register_custom_resolvers() diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 208239e368..9f41b5b2b9 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -1224,6 +1224,13 @@ def _setup_for_nova_recipe( ) args["hyperparameters"]["kms_key"] = kms_key + # Handle eval custom lambda configuration + if recipe.get("evaluation", {}): + processor = recipe.get("processor", {}) + lambda_arn = processor.get("lambda_arn", "") + if lambda_arn: + args["hyperparameters"]["eval_lambda_arn"] = lambda_arn + # Resolve and save the final recipe self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"]) diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 3c3f3dc2bf..6087050171 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -463,7 +463,7 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case): "expected_args": { "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), "hyperparameters": { - "lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction", + "eval_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction", }, "training_image": None, "source_code": None, diff --git a/tests/unit/test_pytorch_nova.py 
b/tests/unit/test_pytorch_nova.py index f78bdcae7d..46d526f22e 100644 --- a/tests/unit/test_pytorch_nova.py +++ b/tests/unit/test_pytorch_nova.py @@ -684,6 +684,50 @@ def test_framework_hyperparameters_nova(): assert hyperparams["bool_param"] == "true" +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_evaluation_lambda(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly handles evaluation lambda configuration.""" + # Create a mock recipe with evaluation and processor config + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 1, + }, + "evaluation": {"task": "gen_qa", "strategy": "gen_qa", "metric": "all"}, + "processor": { + "lambda_arn": "arn:aws:lambda:us-west-2:123456789012:function:eval-function" + }, + } + ) + + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify that eval_lambda_arn hyperparameter was set correctly + assert ( + pytorch._hyperparameters.get("eval_lambda_arn") + == "arn:aws:lambda:us-west-2:123456789012:function:eval-function" + ) + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly handles distillation configurations."""