1919
2020from sagemaker import ModelMetrics , MetricsSource , FileSource , Predictor
2121from sagemaker .drift_check_baselines import DriftCheckBaselines
22+ from sagemaker .instance_group import InstanceGroup
2223from sagemaker .metadata_properties import MetadataProperties
2324from sagemaker .model import FrameworkModel
2425from sagemaker .parameter import IntegerParameter
@@ -233,14 +234,17 @@ def _generate_all_pipeline_vars() -> dict:
233234 )
234235
235236
237+ # TODO: we should remove the _IS_TRUE_TMP and replace its usages with IS_TRUE
238+ # Currently the `instance_groups` parameter does not work well with some estimator subclasses,
239+ # so we temporarily hard-code it to False, which disables the instance_groups
240+ _IS_TRUE_TMP = False
236241IS_TRUE = bool (getrandbits (1 ))
237242PIPELINE_SESSION = _generate_mock_pipeline_session ()
238243PIPELINE_VARIABLES = _generate_all_pipeline_vars ()
239244
240245# TODO: need to recursively assign with Pipeline Variable in later changes
241246FIXED_ARGUMENTS = dict (
242247 common = dict (
243- instance_type = INSTANCE_TYPE ,
244248 role = ROLE ,
245249 sagemaker_session = PIPELINE_SESSION ,
246250 source_dir = f"s3://{ BUCKET } /source" ,
@@ -281,6 +285,7 @@ def _generate_all_pipeline_vars() -> dict:
281285 response_types = ["application/json" ],
282286 ),
283287 processor = dict (
288+ instance_type = INSTANCE_TYPE ,
284289 estimator_cls = PyTorch ,
285290 code = f"s3://{ BUCKET } /code" ,
286291 spark_event_logs_s3_uri = f"s3://{ BUCKET } /my-spark-output-path" ,
@@ -438,13 +443,31 @@ def _generate_all_pipeline_vars() -> dict:
438443 input_mode = ParameterString (name = "train_inputs_input_mode" ),
439444 attribute_names = [ParameterString (name = "train_inputs_attribute_name" )],
440445 target_attribute_name = ParameterString (name = "train_inputs_target_attr_name" ),
446+ instance_groups = [ParameterString (name = "train_inputs_instance_groups" )],
441447 ),
442448 },
449+ instance_groups = [
450+ InstanceGroup (
451+ instance_group_name = ParameterString (name = "instance_group_name" ),
452+ # hard code the instance_type here because InstanceGroup.instance_type
453+ # would be used to retrieve image_uri if image_uri is not presented
454+ # and currently the test mechanism does not support skipping the test case
455+ # relating to bonded parameters in composite variables (i.e. the InstanceGroup)
456+ # TODO: we should support skipping tests on bonded parameters in composite vars
457+ instance_type = "ml.m5.xlarge" ,
458+ instance_count = ParameterString (name = "instance_group_instance_count" ),
459+ ),
460+ ] if _IS_TRUE_TMP else None ,
461+ instance_type = "ml.m5.xlarge" if not _IS_TRUE_TMP else None ,
462+ instance_count = 1 if not _IS_TRUE_TMP else None ,
463+ distribution = {} if not _IS_TRUE_TMP else None ,
443464 ),
444465 transformer = dict (
466+ instance_type = INSTANCE_TYPE ,
445467 data = f"s3://{ BUCKET } /data" ,
446468 ),
447469 tuner = dict (
470+ instance_type = INSTANCE_TYPE ,
448471 estimator = TensorFlow (
449472 entry_point = TENSORFLOW_ENTRY_POINT ,
450473 role = ROLE ,
@@ -475,12 +498,14 @@ def _generate_all_pipeline_vars() -> dict:
475498 include_cls_metadata = {"estimator-1" : IS_TRUE },
476499 ),
477500 model = dict (
501+ instance_type = INSTANCE_TYPE ,
478502 serverless_inference_config = ServerlessInferenceConfig (),
479503 framework_version = "1.11.0" ,
480504 py_version = "py3" ,
481505 accelerator_type = "ml.eia2.xlarge" ,
482506 ),
483507 pipelinemodel = dict (
508+ instance_type = INSTANCE_TYPE ,
484509 models = [
485510 SparkMLModel (
486511 name = "MySparkMLModel" ,
@@ -577,12 +602,17 @@ def _generate_all_pipeline_vars() -> dict:
577602 },
578603 ),
579604)
580- # A dict to keep the optional arguments which should not be None according to the logic
581- # specific to the subclass.
605+ # A dict to keep the optional arguments which should not be set to None
606+ # in the test iteration according to the logic specific to the subclass.
582607PARAMS_SHOULD_NOT_BE_NONE = dict (
583608 estimator = dict (
584609 init = dict (
585- common = {"instance_count" , "instance_type" },
610+ # TODO: we should remove the three instance_ parameters here
611+ # For mutually exclusive parameters: instance group
612+ # vs instance count/instance type, if any side is set to None during iteration,
613+ # the other side should get a not None value, instead of listing them here
614+ # and forcing them to be not None
615+ common = {"instance_count" , "instance_type" , "instance_groups" },
586616 LDA = {"mini_batch_size" },
587617 )
588618 ),
@@ -692,7 +722,10 @@ def _generate_all_pipeline_vars() -> dict:
692722 ),
693723 estimator = dict (
694724 init = dict (
695- common = dict (),
725+ common = dict (
726+ entry_point = {"enable_network_isolation" },
727+ source_dir = {"enable_network_isolation" },
728+ ),
696729 TensorFlow = dict (
697730 image_uri = {"compiler_config" },
698731 compiler_config = {"image_uri" },
@@ -701,7 +734,13 @@ def _generate_all_pipeline_vars() -> dict:
701734 image_uri = {"compiler_config" },
702735 compiler_config = {"image_uri" },
703736 ),
704- )
737+ ),
738+ fit = dict (
739+ common = dict (
740+ instance_count = {"instance_groups" },
741+ instance_type = {"instance_groups" },
742+ ),
743+ ),
705744 ),
706745)
707746
0 commit comments