Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ application_import_names = sagemaker, tests
import-order-style = google
per-file-ignores =
tests/unit/test_tuner.py: F405
src/sagemaker/config/config_schema.py: E501
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ To run the integration tests, the following prerequisites must be met
1. AWS account credentials are available in the environment for the boto3 client to use.
2. The AWS account has an IAM role named :code:`SageMakerRole`.
It should have the AmazonSageMakerFullAccess policy attached as well as a policy with `the necessary permissions to use Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei-setup.html>`__.
3. To run the remote_function tests, a dummy ECR repository must exist in the account. Create it by running:
:code:`aws ecr create-repository --repository-name remote-function-dummy-container`

We recommend selectively running just those integration tests you'd like to run. You can filter by individual test function names with:

Expand Down
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ sagemaker-experiments==0.1.35
Jinja2==3.0.3
pandas>=1.3.5,<1.5
scikit-learn==1.0.2
cloudpickle==2.2.1
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def read_requirements(filename):
required_packages = [
"attrs>=20.3.0,<23",
"boto3>=1.26.28,<2.0",
"cloudpickle==2.2.1",
"google-pasta",
"numpy>=1.9.0,<2.0",
"protobuf>=3.1,<4.0",
Expand All @@ -62,6 +63,7 @@ def read_requirements(filename):
"PyYAML==5.4.1",
"jsonschema",
"platformdirs",
"tblib==1.7.0",
]

# Specific use case dependencies
Expand Down
102 changes: 101 additions & 1 deletion src/sagemaker/config/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@
SAGEMAKER = "SageMaker"
PYTHON_SDK = "PythonSDK"
MODULES = "Modules"
REMOTE_FUNCTION = "RemoteFunction"
DEPENDENCIES = "Dependencies"
PRE_EXECUTION_SCRIPT = "PreExecutionScript"
PRE_EXECUTION_COMMANDS = "PreExecutionCommands"
ENVIRONMENT_VARIABLES = "EnvironmentVariables"
IMAGE_URI = "ImageUri"
INCLUDE_LOCAL_WORKDIR = "IncludeLocalWorkDir"
INSTANCE_TYPE = "InstanceType"
S3_KMS_KEY_ID = "S3KmsKeyId"
S3_ROOT_URI = "S3RootUri"
JOB_CONDA_ENV = "JobCondaEnvironment"
OFFLINE_STORE_CONFIG = "OfflineStoreConfig"
ONLINE_STORE_CONFIG = "OnlineStoreConfig"
S3_STORAGE_CONFIG = "S3StorageConfig"
Expand Down Expand Up @@ -221,6 +232,49 @@ def _simple_path(*args: str):
SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES
)

# Config paths for the remote function settings. Every path is built with
# _simple_path from the same prefix (SAGEMAKER, PYTHON_SDK, MODULES,
# REMOTE_FUNCTION) followed by the key of the individual setting; the two
# VPC paths go one level deeper, through VPC_CONFIG.
REMOTE_FUNCTION_DEPENDENCIES = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES
)
REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_COMMANDS
)
REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_SCRIPT
)
REMOTE_FUNCTION_ENVIRONMENT_VARIABLES = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENVIRONMENT_VARIABLES
)
REMOTE_FUNCTION_IMAGE_URI = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, IMAGE_URI)
REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INCLUDE_LOCAL_WORKDIR
)
REMOTE_FUNCTION_INSTANCE_TYPE = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INSTANCE_TYPE
)
REMOTE_FUNCTION_JOB_CONDA_ENV = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, JOB_CONDA_ENV
)
REMOTE_FUNCTION_ROLE_ARN = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ROLE_ARN)
REMOTE_FUNCTION_S3_KMS_KEY_ID = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_KMS_KEY_ID
)
REMOTE_FUNCTION_S3_ROOT_URI = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_ROOT_URI
)
REMOTE_FUNCTION_TAGS = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, TAGS)
REMOTE_FUNCTION_VOLUME_KMS_KEY_ID = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VOLUME_KMS_KEY_ID
)
REMOTE_FUNCTION_VPC_CONFIG_SUBNETS = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SUBNETS
)
REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
)
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
)

# Paths for reference elsewhere in the SDK.
# Names include the schema version since the paths could change with other schema versions
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
Expand All @@ -245,7 +299,6 @@ def _simple_path(*args: str):
SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
)


SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
"$schema": "https://json-schema.org/draft/2020-12/schema",
TYPE: OBJECT,
Expand Down Expand Up @@ -377,6 +430,23 @@ def _simple_path(*args: str):
"minItems": 0,
"maxItems": 50,
},
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment
"environmentVariables": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PATTERN_PROPERTIES: {
r"([a-zA-Z_][a-zA-Z0-9_]*){1,512}": {
TYPE: "string",
"pattern": r"[\S\s]*",
"maxLength": 512,
}
},
"maxProperties": 48,
},
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
},
PROPERTIES: {
SCHEMA_VERSION: {
Expand Down Expand Up @@ -406,6 +476,36 @@ def _simple_path(*args: str):
# Any SageMaker Python SDK specific configuration will be added here.
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
REMOTE_FUNCTION: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
DEPENDENCIES: {TYPE: "string"},
PRE_EXECUTION_COMMANDS: {
TYPE: "array",
"items": {"$ref": "#/definitions/preExecutionCommand"},
},
PRE_EXECUTION_SCRIPT: {TYPE: "string"},
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {
TYPE: "boolean"
},
ENVIRONMENT_VARIABLES: {
"$ref": "#/definitions/environmentVariables"
},
IMAGE_URI: {TYPE: "string"},
INCLUDE_LOCAL_WORKDIR: {TYPE: "boolean"},
INSTANCE_TYPE: {TYPE: "string"},
JOB_CONDA_ENV: {TYPE: "string"},
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
S3_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
S3_ROOT_URI: {"$ref": "#/definitions/s3Uri"},
TAGS: {"$ref": "#/definitions/tags"},
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
},
}
},
}
},
},
Expand Down
50 changes: 30 additions & 20 deletions src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback):

self.close()

def __getstate__(self):
    """Refuse to produce pickle state for this object.

    The pickle protocol calls ``__getstate__`` to obtain an object's state,
    so raising here makes every attempt to pickle the instance fail.

    Raises:
        NotImplementedError: Always, whenever pickling is attempted.
    """
    message = "Instance of Run type is not allowed to be pickled."
    raise NotImplementedError(message)


def load_run(
run_name: Optional[str] = None,
Expand Down Expand Up @@ -787,36 +795,38 @@ def load_run(
Returns:
Run: The loaded Run object.
"""
sagemaker_session = sagemaker_session or _utils.default_session()
environment = _RunEnvironment.load()

verify_load_input_names(run_name=run_name, experiment_name=experiment_name)

if run_name or environment:
if run_name:
logger.warning(
"run_name is explicitly supplied in load_run, "
"which will be prioritized to load the Run object. "
"In other words, the run name in the experiment config, fetched from the "
"job environment or the current run context, will be ignored."
)
else:
exp_config = get_tc_and_exp_config_from_job_env(
environment=environment, sagemaker_session=sagemaker_session
)
run_name = Run._extract_run_name_from_tc_name(
trial_component_name=exp_config[RUN_NAME],
experiment_name=exp_config[EXPERIMENT_NAME],
)
experiment_name = exp_config[EXPERIMENT_NAME]

if run_name:
logger.warning(
"run_name is explicitly supplied in load_run, "
"which will be prioritized to load the Run object. "
"In other words, the run name in the experiment config, fetched from the "
"job environment or the current run context, will be ignored."
)
run_instance = Run(
experiment_name=experiment_name,
run_name=run_name,
sagemaker_session=sagemaker_session,
sagemaker_session=sagemaker_session or _utils.default_session(),
)
elif _RunContext.get_current_run():
run_instance = _RunContext.get_current_run()
elif environment:
exp_config = get_tc_and_exp_config_from_job_env(
environment=environment, sagemaker_session=sagemaker_session or _utils.default_session()
)
run_name = Run._extract_run_name_from_tc_name(
trial_component_name=exp_config[RUN_NAME],
experiment_name=exp_config[EXPERIMENT_NAME],
)
experiment_name = exp_config[EXPERIMENT_NAME]
run_instance = Run(
experiment_name=experiment_name,
run_name=run_name,
sagemaker_session=sagemaker_session or _utils.default_session(),
)
else:
raise RuntimeError(
"Failed to load a Run object. "
Expand Down
36 changes: 36 additions & 0 deletions src/sagemaker/image_uri_config/sagemaker-base-python.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"versions": {
"1.0": {
"registries": {
"us-east-2": "429704687514",
"me-south-1": "117516905037",
"us-west-2": "236514542706",
"ca-central-1": "310906938811",
"ap-east-1": "493642496378",
"us-east-1": "081325390199",
"ap-northeast-2": "806072073708",
"eu-west-2": "712779665605",
"ap-southeast-2": "52832661640",
"cn-northwest-1": "390780980154",
"eu-north-1": "243637512696",
"cn-north-1": "390048526115",
"ap-south-1": "394103062818",
"eu-west-3": "615547856133",
"ap-southeast-3": "276181064229",
"af-south-1": "559312083959",
"eu-west-1": "470317259841",
"eu-central-1": "936697816551",
"sa-east-1": "782484402741",
"ap-northeast-3": "792733760839",
"eu-south-1": "592751261982",
"ap-northeast-1": "102112518831",
"us-west-1": "742091327244",
"ap-southeast-1": "492261229750",
"me-central-1": "103105715889",
"us-gov-east-1": "107072934176",
"us-gov-west-1": "107173498710"
},
"repository": "sagemaker-base-python"
}
}
}
26 changes: 26 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,29 @@ def get_training_image_uri(
container_version=container_version,
training_compiler_config=compiler_config,
)


def get_base_python_image_uri(region, py_version="310") -> str:
    """Retrieves the image URI for the base python image.

    The URI is assembled from the ``sagemaker-base-python`` image config
    (version ``1.0``): the registry account for ``region``, the ECR hostname
    for ``region``, and the repository name suffixed with ``py_version``.

    Args:
        region (str): The AWS region to use for the image URI.
        py_version (str): The python version to use for the image. Can be 310 or 38.
            Defaults to 310. It is appended to the repository name as-is.

    Returns:
        str: The image URI string.

    Raises:
        ValueError: If no ECR endpoint can be resolved for ``region``.
    """
    framework = "sagemaker-base-python"
    version = "1.0"

    # The botocore resolver yields no endpoint data when the region matches no
    # known partition. Fail with an explicit error instead of letting the
    # subscript below raise an opaque TypeError on None.
    endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
    if not endpoint_data:
        raise ValueError(
            "Unable to resolve an ECR endpoint for region '{}'.".format(region)
        )
    hostname = endpoint_data["hostname"]

    config = config_for_framework(framework)
    version_config = config["versions"][_version_for_config(version, config)]

    registry = _registry_from_region(region, version_config["registries"])

    # Repository is "<repository>-<py_version>", tagged with the image version.
    repo = version_config["repository"] + "-" + py_version
    repo_and_tag = repo + ":" + version

    return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)
12 changes: 12 additions & 0 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,10 +674,22 @@ def _initialize(
self.sagemaker_client = LocalSagemakerClient(self)
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.local_mode = True
sagemaker_config = kwargs.get("sagemaker_config", None)
if sagemaker_config:
validate_sagemaker_config(sagemaker_config)

if self.s3_endpoint_url is not None:
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
self.sagemaker_config = (
sagemaker_config
if sagemaker_config
else load_sagemaker_config(s3_resource=self.s3_resource)
)
else:
self.sagemaker_config = (
sagemaker_config if sagemaker_config else load_sagemaker_config()
)

sagemaker_config = kwargs.get("sagemaker_config", None)
if sagemaker_config:
Expand Down
16 changes: 16 additions & 0 deletions src/sagemaker/remote_function/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Defines classes and helper methods used in remote function executions."""
from __future__ import absolute_import

from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401
Loading