diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index b6523e14dd..6afbeb3f89 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -305,6 +305,13 @@ def _get_args_from_nova_recipe( ) args["hyperparameters"]["kms_key"] = kms_key + # Handle eval custom lambda configuration + if recipe.get("evaluation", {}): + processor = recipe.get("processor", {}) + lambda_arn = processor.get("lambda_arn", "") + if lambda_arn: + args["hyperparameters"]["lambda_arn"] = lambda_arn + _register_custom_resolvers() # Resolve Final Recipe diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 17cfda55b0..3c3f3dc2bf 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -446,3 +446,35 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case): _get_args_from_nova_recipe( recipe=recipe, compute=test_case["compute"], role=test_case.get("role") ) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "evaluation": {"task:": "gen_qa", "strategy": "gen_qa", "metric": "all"}, + "processor": { + "lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction" + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": { + "lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction", + }, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def test_get_args_from_nova_recipe_with_evaluation(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_nova_recipe( + recipe=recipe, compute=test_case["compute"], role=test_case["role"] + ) + assert args == test_case["expected_args"]