39 changes: 28 additions & 11 deletions src/sagemaker/pytorch/estimator.py
@@ -163,6 +163,23 @@ def _is_nova_recipe(recipe):
return bool(has_nova_model) or bool(has_distillation)


def _is_eval_recipe(recipe):
"""Check if the recipe is an eval recipe.

An eval recipe is identified by:
1. Having an evaluation section

Args:
recipe (OmegaConf): The loaded recipe configuration

Returns:
bool: True if the recipe is an eval recipe, False otherwise
"""
# Check for an evaluation section
eval_config = recipe.get("evaluation", {})
return bool(eval_config)


def _recipe_initialize_args(source_dir):
"""Initialize the arguments dictionary for recipe setup.

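For reference, a minimal sketch of how the new helper behaves, assuming a recipe loaded with OmegaConf; the recipe contents below are hypothetical, and only the presence of an "evaluation" section matters:

from omegaconf import OmegaConf

# Hypothetical recipes used only to illustrate the check.
eval_recipe = OmegaConf.create({"evaluation": {"task": "my_eval_task"}})
train_recipe = OmegaConf.create({"trainer": {"num_nodes": 2}})

assert _is_eval_recipe(eval_recipe) is True
assert _is_eval_recipe(train_recipe) is False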
@@ -526,7 +543,7 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
self.is_nova_recipe = False
self.is_nova_or_eval_recipe = False
if training_recipe is not None:
if entry_point is not None:
logger.warning("Argument entry_point will be ignored with training_recipe.")
@@ -538,7 +555,7 @@ def __init__(
training_recipe, recipe_overrides, source_dir, kwargs
)

if self.is_nova_recipe and image_uri is None:
if self.is_nova_or_eval_recipe and image_uri is None:
raise ValueError("Must supply image_uri for nova jobs.")

entry_point = args["entry_point"]
@@ -569,7 +586,7 @@ def __init__(
source_dir,
hyperparameters,
image_uri=image_uri,
is_nova_job=self.is_nova_recipe,
is_nova_job=self.is_nova_or_eval_recipe,
**kwargs,
)

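As a usage sketch (role, recipe path, and image URI below are all hypothetical), an eval recipe is now treated like a Nova recipe at construction time, so an explicit image_uri is required:

from sagemaker.pytorch import PyTorch

# Hypothetical values; shown only to illustrate that eval recipes, like Nova
# recipes, must be given an image_uri because no default image is resolved.
estimator = PyTorch(
    role="arn:aws:iam::111122223333:role/MySageMakerRole",
    training_recipe="recipes/my_eval_recipe.yaml",
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/my-eval-image:latest",
    instance_type="ml.p5.48xlarge",
    instance_count=1,
)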
@@ -702,8 +719,8 @@ def fit(
"""
# Handle recipe upload and input channel creation if we have a recipe
if (
self.is_nova_recipe is not None
and self.is_nova_recipe
self.is_nova_or_eval_recipe is not None
and self.is_nova_or_eval_recipe
and hasattr(self, "training_recipe_file")
and self.training_recipe_file
):
@@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe):
if "instance_type" not in kwargs:
raise ValueError("Must pass instance type to estimator when using training recipes.")

if not _is_nova_recipe(recipe) and "trainer" not in recipe:
if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe):
raise ValueError("Supplied recipe does not contain required field trainer.")

instance_type = kwargs["instance_type"].split(".")[1]
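A minimal sketch of the relaxed validation above, using a hypothetical evaluation-only recipe (and assuming _is_nova_recipe returns False for it):

from omegaconf import OmegaConf

eval_only = OmegaConf.create({"evaluation": {"task": "my_eval_task"}})

# Before this change, an evaluation-only recipe hit
# "Supplied recipe does not contain required field trainer."; now the guard no longer fires.
should_raise = (
    not _is_nova_recipe(eval_only)
    and "trainer" not in eval_only
    and not _is_eval_recipe(eval_only)
)
assert should_raise is False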
@@ -973,15 +990,15 @@ def _device_handle_instance_count(kwargs, recipe):
"""
# Check if instance_count is already provided in kwargs

is_nova = _is_nova_recipe(recipe)
is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
if "instance_count" in kwargs:
# Warn if there are conflicting configurations in the recipe
if "num_nodes" in recipe.get("trainer", {}):
logger.warning(
"Using instance_count argument to estimator to set number "
"of nodes. Ignoring trainer -> num_nodes in recipe."
)
if is_nova and "replicas" in recipe.get("run", {}):
if is_nova_or_eval and "replicas" in recipe.get("run", {}):
logger.warning(
"Using instance_count argument to estimator to set number "
"of nodes. Ignoring run -> replicas in recipe."
@@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe):
kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
return

if is_nova and "run" in recipe and "replicas" in recipe["run"]:
if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]:
kwargs["instance_count"] = recipe["run"]["replicas"]
return

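A sketch of the resulting node-count resolution for an eval recipe; the recipe values are hypothetical and the helper is assumed to be callable directly with (kwargs, recipe):

from omegaconf import OmegaConf

# Hypothetical eval recipe that requests 4 replicas.
recipe = OmegaConf.create({
    "evaluation": {"task": "my_eval_task"},
    "run": {"replicas": 4},
})

# With no instance_count passed by the caller, run -> replicas now applies to
# eval recipes exactly as it already did to Nova recipes.
kwargs = {"instance_type": "ml.p5.48xlarge"}
_device_handle_instance_count(kwargs, recipe)
assert kwargs["instance_count"] == 4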
@@ -1137,8 +1154,8 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d
# Merge with overrides
recipe = OmegaConf.merge(recipe, recipe_overrides)

self.is_nova_recipe = _is_nova_recipe(recipe)
if self.is_nova_recipe:
self.is_nova_or_eval_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
if self.is_nova_or_eval_recipe:
return self._setup_for_nova_recipe(
recipe,
recipe_name,
22 changes: 11 additions & 11 deletions tests/unit/test_pytorch_nova.py
@@ -138,7 +138,7 @@ def test_setup_for_nova_recipe_with_model_name(mock_resolve_save, sagemaker_sess
)

# Check that the Nova recipe was correctly identified
assert pytorch.is_nova_recipe is True
assert pytorch.is_nova_or_eval_recipe is True

# Verify _setup_for_nova_recipe was called
mock_nova_setup.assert_called_once()
@@ -194,7 +194,7 @@ def test_setup_for_nova_recipe_with_s3_path(mock_resolve_save, sagemaker_session
)

# Check that the Nova recipe was correctly identified
assert pytorch.is_nova_recipe is True
assert pytorch.is_nova_or_eval_recipe is True

# Verify _setup_for_nova_recipe was called
mock_nova_setup.assert_called_once()
@@ -326,7 +326,7 @@ def test_upload_recipe_to_s3(mock_time, mock_recipe_load, sagemaker_session):
)

# Set Nova recipe attributes
pytorch.is_nova_recipe = True
pytorch.is_nova_or_eval_recipe = True

# Create a temporary file to use as the recipe file
with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file:
@@ -369,7 +369,7 @@ def test_recipe_resolve_and_save(
)

# Set Nova recipe attributes
pytorch.is_nova_recipe = True
pytorch.is_nova_or_eval_recipe = True

# Mock the temporary file
mock_temp_file_instance = Mock()
@@ -421,7 +421,7 @@ def test_fit_with_nova_recipe_s3_upload(mock_framework_fit, mock_recipe_load, sa
)

# Set Nova recipe attributes
pytorch.is_nova_recipe = True
pytorch.is_nova_or_eval_recipe = True
pytorch.training_recipe_file = temp_file

# Mock the _upload_recipe_to_s3 method
@@ -473,7 +473,7 @@ def test_fit_with_nova_recipe_and_inputs(
)

# Set Nova recipe attributes
pytorch.is_nova_recipe = True
pytorch.is_nova_or_eval_recipe = True
pytorch.training_recipe_file = temp_file

# Create training inputs
@@ -559,7 +559,7 @@ def test_fit_with_nova_recipe(
)

# Set Nova recipe attributes
pytorch.is_nova_recipe = True
pytorch.is_nova_or_eval_recipe = True
pytorch.training_recipe_file = temp_file

# Mock the upload_recipe_to_s3 method
@@ -642,7 +642,7 @@ def test_framework_set_hyperparameters_non_nova():
py_version="py3",
image_uri=IMAGE_URI,
)
framework.is_nova_recipe = False
framework.is_nova_or_eval_recipe = False

# Add hyperparameters
framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True)
@@ -719,7 +719,7 @@ def test_setup_for_nova_recipe_with_evaluation_lambda(mock_resolve_save, sagemak
)

# Check that the Nova recipe was correctly identified
assert pytorch.is_nova_recipe is True
assert pytorch.is_nova_or_eval_recipe is True

# Verify that eval_lambda_arn hyperparameter was set correctly
assert (
@@ -780,7 +780,7 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se
)

# Check that the Nova recipe was correctly identified
assert pytorch.is_nova_recipe is True
assert pytorch.is_nova_or_eval_recipe is True

# Verify _setup_for_nova_recipe was called
mock_nova_setup.assert_called_once()
@@ -828,7 +828,7 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess
)

# Check that the Nova recipe was correctly identified
assert pytorch.is_nova_recipe is True
assert pytorch.is_nova_or_eval_recipe is True

# Verify that model_type hyperparameter was set correctly
assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"