Skip to content

Commit 095329e

Browse files
committed
feat: added endpoint_name to clarify.ModelConfig
1 parent 76eb325 commit 095329e

File tree

2 files changed

+121
-18
lines changed

2 files changed

+121
-18
lines changed

src/sagemaker/clarify.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -277,26 +277,30 @@ class ModelConfig:
277277

278278
def __init__(
279279
self,
280-
model_name,
281-
instance_count,
282-
instance_type,
283-
accept_type=None,
284-
content_type=None,
285-
content_template=None,
286-
custom_attributes=None,
287-
accelerator_type=None,
288-
endpoint_name_prefix=None,
289-
target_model=None,
280+
model_name: str = None,
281+
instance_count: int = None,
282+
instance_type: str = None,
283+
accept_type: str = None,
284+
content_type: str = None,
285+
content_template: str = None,
286+
custom_attributes: str = None,
287+
accelerator_type: str = None,
288+
endpoint_name_prefix: str = None,
289+
target_model: str = None,
290+
endpoint_name: str = None,
290291
):
291292
r"""Initializes a configuration of a model and the endpoint to be created for it.
292293
293294
Args:
294295
model_name (str): Model name (as created by
295296
`CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_.
297+
Cannot be set when ``endpoint_name`` is set. Must be set with ``instance_count``, ``instance_type``
296298
instance_count (int): The number of instances of a new endpoint for model inference.
299+
Cannot be set when ``endpoint_name`` is set. Must be set with ``model_name``, ``instance_type``
297300
instance_type (str): The type of
298301
`EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
299302
to use for model inference; for example, ``"ml.c5.xlarge"``.
303+
Cannot be set when ``endpoint_name`` is set. Must be set with ``instance_count``, ``model_name``
300304
accept_type (str): The model output format to be used for getting inferences with the
301305
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
302306
``"application/jsonlines"``. Default is the same as ``content_type``.
@@ -326,17 +330,39 @@ def __init__(
326330
target_model (str): Sets the target model name when using a multi-model endpoint. For
327331
more information about multi-model endpoints, see
328332
https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html
333+
endpoint_name (str): Sets the endpoint_name when re-uses an existing endpoint. Cannot be set
334+
when ``model_name``, ``instance_count``, and ``instance_type`` set
329335
330336
Raises:
331-
ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid,
332-
``content_type`` is invalid, or ``content_template`` has no placeholder "features"
337+
ValueError: when the
338+
- ``endpoint_name_prefix`` is invalid,
339+
- ``accept_type`` is invalid,
340+
- ``content_type`` is invalid,
341+
- ``content_template`` has no placeholder "features"
342+
- both [``endpoint_name``] AND [``model_name``, ``instance_count``, ``instance_type``] are set
343+
- both [``endpoint_name``] AND [``endpoint_name_prefix``] are set
333344
"""
334-
self.predictor_config = {
335-
"model_name": model_name,
336-
"instance_type": instance_type,
337-
"initial_instance_count": instance_count,
338-
}
339-
if endpoint_name_prefix is not None:
345+
346+
# validation
347+
_model_endpoint_config_rule = (
348+
all([model_name, instance_count, instance_type]),
349+
all([endpoint_name]),
350+
)
351+
assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
352+
if endpoint_name:
353+
assert not endpoint_name_prefix
354+
355+
# main init logic
356+
self.predictor_config = (
357+
{
358+
"model_name": model_name,
359+
"instance_type": instance_type,
360+
"initial_instance_count": instance_count,
361+
}
362+
if not endpoint_name
363+
else {"endpoint_name": endpoint_name}
364+
)
365+
if endpoint_name_prefix:
340366
if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
341367
raise ValueError(
342368
"Invalid endpoint_name_prefix."

tests/unit/test_clarify.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,42 @@ def pdp_config():
757757
return PDPConfig(features=["F1", "F2"], grid_resolution=20)
758758

759759

760+
def test_model_config_validations():
761+
new_model_endpoint_definition = {
762+
"model_name": "xgboost-model",
763+
"instance_type": "ml.c5.xlarge",
764+
"instance_count": 1,
765+
}
766+
existing_endpoint_definition = {"endpoint_name": "existing_endpoint"}
767+
768+
with pytest.raises(AssertionError):
769+
# should be one of them
770+
ModelConfig(
771+
**new_model_endpoint_definition,
772+
**existing_endpoint_definition,
773+
)
774+
775+
with pytest.raises(AssertionError):
776+
# should be one of them
777+
ModelConfig(
778+
endpoint_name_prefix="prefix",
779+
**existing_endpoint_definition,
780+
)
781+
782+
# success path for new model
783+
assert ModelConfig(**new_model_endpoint_definition).predictor_config == {
784+
"initial_instance_count": 1,
785+
"instance_type": "ml.c5.xlarge",
786+
"model_name": "xgboost-model",
787+
}
788+
789+
# success path for existing endpoint
790+
assert (
791+
ModelConfig(**existing_endpoint_definition).predictor_config
792+
== existing_endpoint_definition
793+
)
794+
795+
760796
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
761797
def test_pre_training_bias(
762798
name_from_base,
@@ -1442,6 +1478,47 @@ def test_analysis_config_generator_for_bias_explainability(
14421478
assert actual == expected
14431479

14441480

1481+
def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint(
1482+
data_config, data_bias_config
1483+
):
1484+
model_config = ModelConfig(endpoint_name="existing_endpoint_name")
1485+
model_predicted_label_config = ModelPredictedLabelConfig(
1486+
probability="pr",
1487+
label_headers=["success"],
1488+
)
1489+
actual = _AnalysisConfigGenerator.bias_and_explainability(
1490+
data_config,
1491+
model_config,
1492+
model_predicted_label_config,
1493+
[SHAPConfig(), PDPConfig()],
1494+
data_bias_config,
1495+
pre_training_methods="all",
1496+
post_training_methods="all",
1497+
)
1498+
expected = {
1499+
"dataset_type": "text/csv",
1500+
"facet": [{"name_or_index": "F1"}],
1501+
"group_variable": "F2",
1502+
"headers": ["Label", "F1", "F2", "F3", "F4"],
1503+
"joinsource_name_or_index": "F4",
1504+
"label": "Label",
1505+
"label_values_or_threshold": [1],
1506+
"methods": {
1507+
"pdp": {"grid_resolution": 15, "top_k_features": 10},
1508+
"post_training_bias": {"methods": "all"},
1509+
"pre_training_bias": {"methods": "all"},
1510+
"report": {"name": "report", "title": "Analysis Report"},
1511+
"shap": {"save_local_shap_values": True, "use_logit": False},
1512+
},
1513+
"predictor": {
1514+
"label_headers": ["success"],
1515+
"endpoint_name": "existing_endpoint_name",
1516+
"probability": "pr",
1517+
},
1518+
}
1519+
assert actual == expected
1520+
1521+
14451522
def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
14461523
actual = _AnalysisConfigGenerator.bias_pre_training(
14471524
data_config, data_bias_config, methods="all"

0 commit comments

Comments
 (0)