|
50 | 50 | validate_source_code_input_against_pipeline_variables, |
51 | 51 | ) |
52 | 52 | from sagemaker.inputs import TrainingInput, FileSystemInput |
| 53 | +from sagemaker.instance_group import InstanceGroup |
53 | 54 | from sagemaker.job import _Job |
54 | 55 | from sagemaker.jumpstart.utils import ( |
55 | 56 | add_jumpstart_tags, |
@@ -149,7 +150,7 @@ def __init__( |
149 | 150 | code_location: Optional[str] = None, |
150 | 151 | entry_point: Optional[Union[str, PipelineVariable]] = None, |
151 | 152 | dependencies: Optional[List[Union[str]]] = None, |
152 | | - instance_groups: Optional[Dict[str, Union[str, int]]] = None, |
| 153 | + instance_groups: Optional[List[InstanceGroup]] = None, |
153 | 154 | **kwargs, |
154 | 155 | ): |
155 | 156 | """Initialize an ``EstimatorBase`` instance. |
@@ -1580,6 +1581,8 @@ def _get_instance_type(self): |
1580 | 1581 |
|
1581 | 1582 | for instance_group in self.instance_groups: |
1582 | 1583 | instance_type = instance_group.instance_type |
| 1584 | + if is_pipeline_variable(instance_type): |
| 1585 | + continue |
1583 | 1586 | match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) |
1584 | 1587 |
|
1585 | 1588 | if match: |
@@ -2179,7 +2182,7 @@ def __init__( |
2179 | 2182 | code_location: Optional[str] = None, |
2180 | 2183 | entry_point: Optional[Union[str, PipelineVariable]] = None, |
2181 | 2184 | dependencies: Optional[List[str]] = None, |
2182 | | - instance_groups: Optional[Dict[str, Union[str, int]]] = None, |
| 2185 | + instance_groups: Optional[List[InstanceGroup]] = None, |
2183 | 2186 | **kwargs, |
2184 | 2187 | ): |
2185 | 2188 | """Initialize an ``Estimator`` instance. |
|
0 commit comments