Skip to content

Commit 46bba68

Browse files
author
Dewen Qi
committed
update TM as per latest estimator changes
1 parent bcb4e4c commit 46bba68

File tree

4 files changed

+64
-13
lines changed

4 files changed

+64
-13
lines changed

src/sagemaker/instance_group.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
"""Defines the InstanceGroup class that configures a heterogeneous cluster."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class InstanceGroup(object):
1822
"""The class to create instance groups for a heterogeneous cluster."""
1923

2024
def __init__(
2125
self,
22-
instance_group_name=None,
23-
instance_type=None,
24-
instance_count=None,
26+
instance_group_name: Optional[Union[str, PipelineVariable]] = None,
27+
instance_type: Optional[Union[str, PipelineVariable]] = None,
28+
instance_count: Optional[Union[int, PipelineVariable]] = None,
2529
):
2630
"""It initializes an ``InstanceGroup`` instance.
2731

tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker import ModelMetrics, MetricsSource, FileSource, Predictor
2121
from sagemaker.drift_check_baselines import DriftCheckBaselines
22+
from sagemaker.instance_group import InstanceGroup
2223
from sagemaker.metadata_properties import MetadataProperties
2324
from sagemaker.model import FrameworkModel
2425
from sagemaker.parameter import IntegerParameter
@@ -233,14 +234,17 @@ def _generate_all_pipeline_vars() -> dict:
233234
)
234235

235236

237+
# TODO: we should remove the _IS_TRUE_TMP and replace its usages with IS_TRUE
238+
# Because `instance_groups` currently does not work well with some estimator subclasses,
239+
# we temporarily hard code it to False, which disables the instance_groups
240+
_IS_TRUE_TMP = False
236241
IS_TRUE = bool(getrandbits(1))
237242
PIPELINE_SESSION = _generate_mock_pipeline_session()
238243
PIPELINE_VARIABLES = _generate_all_pipeline_vars()
239244

240245
# TODO: need to recursively assign with Pipeline Variable in later changes
241246
FIXED_ARGUMENTS = dict(
242247
common=dict(
243-
instance_type=INSTANCE_TYPE,
244248
role=ROLE,
245249
sagemaker_session=PIPELINE_SESSION,
246250
source_dir=f"s3://{BUCKET}/source",
@@ -281,6 +285,7 @@ def _generate_all_pipeline_vars() -> dict:
281285
response_types=["application/json"],
282286
),
283287
processor=dict(
288+
instance_type=INSTANCE_TYPE,
284289
estimator_cls=PyTorch,
285290
code=f"s3://{BUCKET}/code",
286291
spark_event_logs_s3_uri=f"s3://{BUCKET}/my-spark-output-path",
@@ -438,13 +443,33 @@ def _generate_all_pipeline_vars() -> dict:
438443
input_mode=ParameterString(name="train_inputs_input_mode"),
439444
attribute_names=[ParameterString(name="train_inputs_attribute_name")],
440445
target_attribute_name=ParameterString(name="train_inputs_target_attr_name"),
446+
instance_groups=[ParameterString(name="train_inputs_instance_groups")],
441447
),
442448
},
449+
instance_groups=[
450+
InstanceGroup(
451+
instance_group_name=ParameterString(name="instance_group_name"),
452+
# hard code the instance_type here because InstanceGroup.instance_type
453+
# would be used to retrieve image_uri if image_uri is not provided,
454+
# and currently the test mechanism does not support skipping the test cases
455+
# relating to bonded parameters in composite variables (i.e. the InstanceGroup)
456+
# TODO: we should support skipping tests on bonded parameters in composite vars
457+
instance_type="ml.m5.xlarge",
458+
instance_count=ParameterString(name="instance_group_instance_count"),
459+
),
460+
]
461+
if _IS_TRUE_TMP
462+
else None,
463+
instance_type="ml.m5.xlarge" if not _IS_TRUE_TMP else None,
464+
instance_count=1 if not _IS_TRUE_TMP else None,
465+
distribution={} if not _IS_TRUE_TMP else None,
443466
),
444467
transformer=dict(
468+
instance_type=INSTANCE_TYPE,
445469
data=f"s3://{BUCKET}/data",
446470
),
447471
tuner=dict(
472+
instance_type=INSTANCE_TYPE,
448473
estimator=TensorFlow(
449474
entry_point=TENSORFLOW_ENTRY_POINT,
450475
role=ROLE,
@@ -475,12 +500,14 @@ def _generate_all_pipeline_vars() -> dict:
475500
include_cls_metadata={"estimator-1": IS_TRUE},
476501
),
477502
model=dict(
503+
instance_type=INSTANCE_TYPE,
478504
serverless_inference_config=ServerlessInferenceConfig(),
479505
framework_version="1.11.0",
480506
py_version="py3",
481507
accelerator_type="ml.eia2.xlarge",
482508
),
483509
pipelinemodel=dict(
510+
instance_type=INSTANCE_TYPE,
484511
models=[
485512
SparkMLModel(
486513
name="MySparkMLModel",
@@ -577,12 +604,17 @@ def _generate_all_pipeline_vars() -> dict:
577604
},
578605
),
579606
)
580-
# A dict to keep the optional arguments which should not be None according to the logic
581-
# specific to the subclass.
607+
# A dict to keep the optional arguments which should not be set to None
608+
# in the test iteration according to the logic specific to the subclass.
582609
PARAMS_SHOULD_NOT_BE_NONE = dict(
583610
estimator=dict(
584611
init=dict(
585-
common={"instance_count", "instance_type"},
612+
# TODO: we should remove the three instance_ parameters here.
613+
# For mutually exclusive parameters (instance groups
614+
# vs instance count/instance type), if either side is set to None during iteration,
615+
# the other side should get a non-None value, instead of listing them here
616+
# and forcing them to be non-None
617+
common={"instance_count", "instance_type", "instance_groups"},
586618
LDA={"mini_batch_size"},
587619
)
588620
),
@@ -692,7 +724,10 @@ def _generate_all_pipeline_vars() -> dict:
692724
),
693725
estimator=dict(
694726
init=dict(
695-
common=dict(),
727+
common=dict(
728+
entry_point={"enable_network_isolation"},
729+
source_dir={"enable_network_isolation"},
730+
),
696731
TensorFlow=dict(
697732
image_uri={"compiler_config"},
698733
compiler_config={"image_uri"},
@@ -701,7 +736,13 @@ def _generate_all_pipeline_vars() -> dict:
701736
image_uri={"compiler_config"},
702737
compiler_config={"image_uri"},
703738
),
704-
)
739+
),
740+
fit=dict(
741+
common=dict(
742+
instance_count={"instance_groups"},
743+
instance_type={"instance_groups"},
744+
),
745+
),
705746
),
706747
)
707748

tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import json
1616

1717
from random import getrandbits
18-
from typing import Optional
18+
from typing import Optional, List
1919
from typing_extensions import get_origin
2020

2121
from sagemaker import Model, PipelineModel, AlgorithmEstimator
@@ -368,14 +368,14 @@ def _verify_composite_object_against_pipeline_var(
368368
self,
369369
param_with_none: str,
370370
step_dsl: str,
371-
step_dsl_obj: object,
371+
step_dsl_obj: List[dict],
372372
):
373373
"""verify pipeline definition regarding composite objects against pipeline variables
374374
375375
Args:
376376
param_with_none (str): The name of the parameter with None value.
377377
step_dsl (str): The step definition retrieved from the pipeline definition DSL.
378-
step_dsl_obj (objet): The json load object of the step definition.
378+
step_dsl_obj (List[dict]): The json load object of the step definition.
379379
"""
380380
# TODO: remove the following hard code assertion once recursive assignment is added
381381
if issubclass(self.clazz, Processor):
@@ -398,6 +398,12 @@ def _verify_composite_object_against_pipeline_var(
398398
assert '{"Get": "Parameters.proc_input_s3_data_type"}' in step_dsl
399399
assert '{"Get": "Parameters.proc_input_app_managed"}' in step_dsl
400400
elif issubclass(self.clazz, EstimatorBase):
401+
if (
402+
param_with_none != "instance_groups"
403+
and self.default_args[CLAZZ_ARGS]["instance_groups"]
404+
):
405+
assert '{"Get": "Parameters.instance_group_name"}' in step_dsl
406+
assert '{"Get": "Parameters.instance_group_instance_count"}' in step_dsl
401407
if issubclass(self.clazz, AmazonAlgorithmEstimatorBase):
402408
# AmazonAlgorithmEstimatorBase's input is records
403409
if param_with_none != "records":
@@ -415,6 +421,7 @@ def _verify_composite_object_against_pipeline_var(
415421
assert '{"Get": "Parameters.train_inputs_input_mode"}' in step_dsl
416422
assert '{"Get": "Parameters.train_inputs_attribute_name"}' in step_dsl
417423
assert '{"Get": "Parameters.train_inputs_target_attr_name"}' in step_dsl
424+
assert '{"Get": "Parameters.train_inputs_instance_groups"}' in step_dsl
418425
if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch, AlgorithmEstimator)):
419426
# debugger_hook_config may be disabled for these first 3 frameworks
420427
# AlgorithmEstimator ignores the kwargs

tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def test_sklearn_estimator_compatibility():
291291
clazz_args=dict(
292292
py_version="py3",
293293
instance_count=1,
294-
instance_type="ml.m5.xlarge",
295294
framework_version="0.20.0",
296295
),
297296
func_args=dict(),

0 commit comments

Comments
 (0)