Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,23 @@ def _is_nova_recipe(recipe):
return bool(has_nova_model) or bool(has_distillation)


def _is_eval_recipe(recipe):
"""Check if the recipe is an eval recipe.

An eval recipe is identified by:
1. Having a evaluation section

Args:
recipe (OmegaConf): The loaded recipe configuration

Returns:
bool: True if the recipe is an eval recipe, False otherwise
"""
# Check for eval model
eval_config = recipe.get("evaluation", {})
return bool(eval_config)


def _recipe_initialize_args(source_dir):
"""Initialize the arguments dictionary for recipe setup.

Expand Down Expand Up @@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe):
if "instance_type" not in kwargs:
raise ValueError("Must pass instance type to estimator when using training recipes.")

if not _is_nova_recipe(recipe) and "trainer" not in recipe:
if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe):
raise ValueError("Supplied recipe does not contain required field trainer.")

instance_type = kwargs["instance_type"].split(".")[1]
Expand All @@ -973,15 +990,15 @@ def _device_handle_instance_count(kwargs, recipe):
"""
# Check if instance_count is already provided in kwargs

is_nova = _is_nova_recipe(recipe)
is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
if "instance_count" in kwargs:
# Warn if there are conflicting configurations in the recipe
if "num_nodes" in recipe.get("trainer", {}):
logger.warning(
"Using instance_count argument to estimator to set number "
"of nodes. Ignoring trainer -> num_nodes in recipe."
)
if is_nova and "replicas" in recipe.get("run", {}):
if is_nova_or_eval and "replicas" in recipe.get("run", {}):
logger.warning(
"Using instance_count argument to estimator to set number "
"of nodes. Ignoring run -> replicas in recipe."
Expand All @@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe):
kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
return

if is_nova and "run" in recipe and "replicas" in recipe["run"]:
if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]:
kwargs["instance_count"] = recipe["run"]["replicas"]
return

Expand Down