Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def retrieve_default(
ValueError: If the combination of arguments specified is not supported.
"""
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
raise ValueError(
"Must specify `model_id` and `model_version` when retrieving environment variables."
)

return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
56 changes: 54 additions & 2 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from __future__ import absolute_import

import logging
from typing import Dict
from typing import Dict, Optional

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.enums import HyperparameterValidationMode
from sagemaker.jumpstart.validators import validate_hyperparameters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,8 +53,58 @@ def retrieve_default(
ValueError: If the combination of arguments specified is not supported.
"""
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
raise ValueError(
"Must specify `model_id` and `model_version` when retrieving hyperparameters."
)

return artifacts._retrieve_default_hyperparameters(
model_id, model_version, region, include_container_hyperparameters
)


def validate(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hyperparameters: Optional[dict] = None,
validation_mode: Optional[HyperparameterValidationMode] = None,
) -> None:
"""Validate hyperparameters for models.

Args:
region (str): Region for which to validate hyperparameters. (Default: None).
model_id (str): Model ID of the model for which to validate hyperparameters.
(Default: None)
model_version (str): Version of the model for which to validate hyperparameters.
(Default: None)
hyperparameters (dict): Hyperparameters to validate.
(Default: None)
validation_mode (HyperparameterValidationMode): Method of validation to use with
hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
to this function will be validated, the missing hyperparameters will be ignored.
If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
(Default: None)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a Raises: section in the docstring please.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: also add ValueError in Raises docsting.

Raises:
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
according to its specs in the model metadata.
ValueError: If the combination of arguments specified is not supported.

"""

if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError(
"Must specify `model_id` and `model_version` when validating hyperparameters."
)

if hyperparameters is None:
raise ValueError("Must specify hyperparameters.")

return validate_hyperparameters(
model_id=model_id,
model_version=model_version,
hyperparameters=hyperparameters,
validation_mode=validation_mode,
region=region,
)
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
ModelFramework,
VariableScope,
Expand Down
51 changes: 3 additions & 48 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
"""This module stores constants related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Set
from enum import Enum
import boto3
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo


Expand Down Expand Up @@ -118,52 +118,7 @@

JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"


INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"


class JumpStartScriptScope(str, Enum):
"""Enum class for JumpStart script scopes."""

INFERENCE = "inference"
TRAINING = "training"

INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"

SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)


class ModelFramework(str, Enum):
"""Enum class for JumpStart model framework.

The ML framework as referenced in the prefix of the model ID.
This value does not necessarily correspond to the container name.
"""

PYTORCH = "pytorch"
TENSORFLOW = "tensorflow"
MXNET = "mxnet"
HUGGINGFACE = "huggingface"
LIGHTGBM = "lightgbm"
CATBOOST = "catboost"
XGBOOST = "xgboost"
SKLEARN = "sklearn"


class VariableScope(str, Enum):
"""Possible value of the ``scope`` attribute for a hyperparameter or environment variable.

Used for hosting environment variables and training hyperparameters.
"""

CONTAINER = "container"
ALGORITHM = "algorithm"


class JumpStartTag(str, Enum):
"""Enum class for tag keys to apply to JumpStart models."""

INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"
77 changes: 77 additions & 0 deletions src/sagemaker/jumpstart/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores enums related to SageMaker JumpStart."""

from __future__ import absolute_import

from enum import Enum


class ModelFramework(str, Enum):
"""Enum class for JumpStart model framework.

The ML framework as referenced in the prefix of the model ID.
This value does not necessarily correspond to the container name.
"""

PYTORCH = "pytorch"
TENSORFLOW = "tensorflow"
MXNET = "mxnet"
HUGGINGFACE = "huggingface"
LIGHTGBM = "lightgbm"
CATBOOST = "catboost"
XGBOOST = "xgboost"
SKLEARN = "sklearn"


class VariableScope(str, Enum):
"""Possible value of the ``scope`` attribute for a hyperparameter or environment variable.

Used for hosting environment variables and training hyperparameters.
"""

CONTAINER = "container"
ALGORITHM = "algorithm"


class JumpStartScriptScope(str, Enum):
"""Enum class for JumpStart script scopes."""

INFERENCE = "inference"
TRAINING = "training"


class HyperparameterValidationMode(str, Enum):
"""Possible modes for validating hyperparameters."""

VALIDATE_PROVIDED = "validate_provided"
VALIDATE_ALGORITHM = "validate_algorithm"
VALIDATE_ALL = "validate_all"


class VariableTypes(str, Enum):
"""Possible types for hyperparameters and environment variables."""

TEXT = "text"
INT = "int"
FLOAT = "float"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add a BOOL member while we are at it?

BOOL = "bool"


class JumpStartTag(str, Enum):
"""Enum class for tag keys to apply to JumpStart models."""

INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"
13 changes: 12 additions & 1 deletion src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores exceptions related to SageMaker JumpStart."""

from __future__ import absolute_import
from typing import List, Optional

from sagemaker.jumpstart.constants import JumpStartScriptScope


class JumpStartHyperparametersError(Exception):
"""Exception raised for bad hyperparameters of a JumpStart model."""

def __init__(
self,
message: Optional[str] = None,
):
self.message = message

super().__init__(self.message)


class VulnerableJumpStartModelError(Exception):
"""Exception raised when trying to access a JumpStart model specs flagged as vulnerable.

Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
"scope",
"min",
"max",
"exclusive_min",
"exclusive_max",
}

def __init__(self, spec: Dict[str, Any]):
Expand Down Expand Up @@ -215,6 +217,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if max_val is not None:
self.max = max_val

exclusive_min_val = json_obj.get("exclusive_min")
if exclusive_min_val is not None:
self.exclusive_min = exclusive_min_val

exclusive_max_val = json_obj.get("exclusive_max")
if exclusive_max_val is not None:
self.exclusive_max = exclusive_max_val

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartHyperparameter object."""
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
Expand Down
16 changes: 8 additions & 8 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from urllib.parse import urlparse
from packaging.version import Version
import sagemaker
from sagemaker.jumpstart import constants
from sagemaker.jumpstart import constants, enums
from sagemaker.jumpstart import accessors
from sagemaker.s3 import parse_s3_url
from sagemaker.jumpstart.exceptions import (
Expand Down Expand Up @@ -200,13 +200,13 @@ def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:


def add_single_jumpstart_tag(
uri: str, tag_key: constants.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
uri: str, tag_key: enums.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
) -> Optional[List]:
"""Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.

Args:
uri (str): URI which may correspond to a JumpStart model.
tag_key (constants.JumpStartTag): Custom tag to apply to current tags if the URI
tag_key (enums.JumpStartTag): Custom tag to apply to current tags if the URI
corresponds to a JumpStart model.
curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
"""
Expand Down Expand Up @@ -249,22 +249,22 @@ def add_jumpstart_tags(

if inference_model_uri:
tags = add_single_jumpstart_tag(
inference_model_uri, constants.JumpStartTag.INFERENCE_MODEL_URI, tags
inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags
)

if inference_script_uri:
tags = add_single_jumpstart_tag(
inference_script_uri, constants.JumpStartTag.INFERENCE_SCRIPT_URI, tags
inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags
)

if training_model_uri:
tags = add_single_jumpstart_tag(
training_model_uri, constants.JumpStartTag.TRAINING_MODEL_URI, tags
training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags
)

if training_script_uri:
tags = add_single_jumpstart_tag(
training_script_uri, constants.JumpStartTag.TRAINING_SCRIPT_URI, tags
training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags
)

return tags
Expand All @@ -280,7 +280,7 @@ def update_inference_tags_with_jumpstart_training_tags(
training_tags (Optional[List[Dict[str, str]]]): Tags from training job.
"""
if training_tags:
for tag_key in constants.JumpStartTag:
for tag_key in enums.JumpStartTag:
if tag_key_in_array(tag_key, training_tags):
tag_value = get_tag_value(tag_key, training_tags)
if inference_tags is None:
Expand Down
Loading