1919
2020from sagemaker import ModelMetrics , MetricsSource , FileSource , Predictor
2121from sagemaker .drift_check_baselines import DriftCheckBaselines
22+ from sagemaker .instance_group import InstanceGroup
2223from sagemaker .metadata_properties import MetadataProperties
2324from sagemaker .model import FrameworkModel
2425from sagemaker .parameter import IntegerParameter
@@ -233,14 +234,17 @@ def _generate_all_pipeline_vars() -> dict:
233234 )
234235
235236
237+ # TODO: remove _IS_TRUE_TMP and replace its usages with IS_TRUE.
238+ # `instance_groups` currently does not work well with some estimator subclasses,
239+ # so we temporarily hard-code the flag to False, which disables instance_groups.
240+ _IS_TRUE_TMP = False
236241IS_TRUE = bool (getrandbits (1 ))
237242PIPELINE_SESSION = _generate_mock_pipeline_session ()
238243PIPELINE_VARIABLES = _generate_all_pipeline_vars ()
239244
240245# TODO: need to recursively assign with Pipeline Variable in later changes
241246FIXED_ARGUMENTS = dict (
242247 common = dict (
243- instance_type = INSTANCE_TYPE ,
244248 role = ROLE ,
245249 sagemaker_session = PIPELINE_SESSION ,
246250 source_dir = f"s3://{ BUCKET } /source" ,
@@ -281,6 +285,7 @@ def _generate_all_pipeline_vars() -> dict:
281285 response_types = ["application/json" ],
282286 ),
283287 processor = dict (
288+ instance_type = INSTANCE_TYPE ,
284289 estimator_cls = PyTorch ,
285290 code = f"s3://{ BUCKET } /code" ,
286291 spark_event_logs_s3_uri = f"s3://{ BUCKET } /my-spark-output-path" ,
@@ -438,13 +443,33 @@ def _generate_all_pipeline_vars() -> dict:
438443 input_mode = ParameterString (name = "train_inputs_input_mode" ),
439444 attribute_names = [ParameterString (name = "train_inputs_attribute_name" )],
440445 target_attribute_name = ParameterString (name = "train_inputs_target_attr_name" ),
446+ instance_groups = [ParameterString (name = "train_inputs_instance_groups" )],
441447 ),
442448 },
449+ instance_groups = [
450+ InstanceGroup (
451+ instance_group_name = ParameterString (name = "instance_group_name" ),
452+ # Hard-code the instance_type here because InstanceGroup.instance_type
453+ # is used to retrieve image_uri when image_uri is not provided,
454+ # and the test mechanism currently does not support skipping test cases
455+ # relating to bonded parameters in composite variables (i.e. the InstanceGroup).
456+ # TODO: support skipping tests on bonded parameters in composite vars
457+ instance_type = "ml.m5.xlarge" ,
458+ instance_count = ParameterString (name = "instance_group_instance_count" ),
459+ ),
460+ ]
461+ if _IS_TRUE_TMP
462+ else None ,
463+ instance_type = "ml.m5.xlarge" if not _IS_TRUE_TMP else None ,
464+ instance_count = 1 if not _IS_TRUE_TMP else None ,
465+ distribution = {} if not _IS_TRUE_TMP else None ,
443466 ),
444467 transformer = dict (
468+ instance_type = INSTANCE_TYPE ,
445469 data = f"s3://{ BUCKET } /data" ,
446470 ),
447471 tuner = dict (
472+ instance_type = INSTANCE_TYPE ,
448473 estimator = TensorFlow (
449474 entry_point = TENSORFLOW_ENTRY_POINT ,
450475 role = ROLE ,
@@ -475,12 +500,14 @@ def _generate_all_pipeline_vars() -> dict:
475500 include_cls_metadata = {"estimator-1" : IS_TRUE },
476501 ),
477502 model = dict (
503+ instance_type = INSTANCE_TYPE ,
478504 serverless_inference_config = ServerlessInferenceConfig (),
479505 framework_version = "1.11.0" ,
480506 py_version = "py3" ,
481507 accelerator_type = "ml.eia2.xlarge" ,
482508 ),
483509 pipelinemodel = dict (
510+ instance_type = INSTANCE_TYPE ,
484511 models = [
485512 SparkMLModel (
486513 name = "MySparkMLModel" ,
@@ -577,12 +604,17 @@ def _generate_all_pipeline_vars() -> dict:
577604 },
578605 ),
579606)
580- # A dict to keep the optional arguments which should not be None according to the logic
581- # specific to the subclass.
607+ # A dict to keep the optional arguments which should not be set to None
608+ # in the test iteration according to the logic specific to the subclass.
582609PARAMS_SHOULD_NOT_BE_NONE = dict (
583610 estimator = dict (
584611 init = dict (
585- common = {"instance_count" , "instance_type" },
612+ # TODO: we should remove the three instance_* parameters here.
613+ # For mutually exclusive parameters (instance_groups
614+ # vs. instance_count/instance_type), if either side is set to None during iteration,
615+ # the other side should get a non-None value, instead of listing them here
616+ # and forcing them to be non-None.
617+ common = {"instance_count" , "instance_type" , "instance_groups" },
586618 LDA = {"mini_batch_size" },
587619 )
588620 ),
@@ -692,7 +724,10 @@ def _generate_all_pipeline_vars() -> dict:
692724 ),
693725 estimator = dict (
694726 init = dict (
695- common = dict (),
727+ common = dict (
728+ entry_point = {"enable_network_isolation" },
729+ source_dir = {"enable_network_isolation" },
730+ ),
696731 TensorFlow = dict (
697732 image_uri = {"compiler_config" },
698733 compiler_config = {"image_uri" },
@@ -701,7 +736,13 @@ def _generate_all_pipeline_vars() -> dict:
701736 image_uri = {"compiler_config" },
702737 compiler_config = {"image_uri" },
703738 ),
704- )
739+ ),
740+ fit = dict (
741+ common = dict (
742+ instance_count = {"instance_groups" },
743+ instance_type = {"instance_groups" },
744+ ),
745+ ),
705746 ),
706747)
707748
0 commit comments