Skip to content

Commit dcdf66f

Browse files
author
Dewen Qi
committed
go with model base and tf
1 parent 03cd5ad commit dcdf66f

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
validate_source_code_input_against_pipeline_variables,
5151
)
5252
from sagemaker.inputs import TrainingInput, FileSystemInput
53+
from sagemaker.instance_group import InstanceGroup
5354
from sagemaker.job import _Job
5455
from sagemaker.jumpstart.utils import (
5556
add_jumpstart_tags,
@@ -149,7 +150,7 @@ def __init__(
149150
code_location: Optional[str] = None,
150151
entry_point: Optional[Union[str, PipelineVariable]] = None,
151152
dependencies: Optional[List[Union[str]]] = None,
152-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
153+
instance_groups: Optional[List[InstanceGroup]] = None,
153154
**kwargs,
154155
):
155156
"""Initialize an ``EstimatorBase`` instance.
@@ -1580,6 +1581,8 @@ def _get_instance_type(self):
15801581

15811582
for instance_group in self.instance_groups:
15821583
instance_type = instance_group.instance_type
1584+
if is_pipeline_variable(instance_type):
1585+
continue
15831586
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
15841587

15851588
if match:
@@ -2179,7 +2182,7 @@ def __init__(
21792182
code_location: Optional[str] = None,
21802183
entry_point: Optional[Union[str, PipelineVariable]] = None,
21812184
dependencies: Optional[List[str]] = None,
2182-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
2185+
instance_groups: Optional[List[InstanceGroup]] = None,
21832186
**kwargs,
21842187
):
21852188
"""Initialize an ``Estimator`` instance.

src/sagemaker/fw_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,14 @@ def validate_distribution_instance(sagemaker_session, distribution, instance_typ
871871
# Strategy modelparallel is not enabled
872872
return
873873

874+
if is_pipeline_variable(instance_type):
875+
logger.warning(
876+
"instance_type is a pipeline variable, which is only interpreted in "
877+
"pipeline execution time. As modelparallel only runs on GPU-enabled "
878+
"instances, in execution time, the specified instance type has to support GPU."
879+
)
880+
return
881+
874882
instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
875883
InstanceTypes=[f"{instance_type}"]
876884
)

0 commit comments

Comments
 (0)