Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def remote(
methods that are not available via PyPI or conda. Default value is ``False``.

instance_count (int): The number of instances to use. Defaults to 1.
NOTE: Remote function does not support instance_count > 1

instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
Expand Down Expand Up @@ -255,6 +256,12 @@ def _remote(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):

if instance_count > 1:
raise ValueError(
"Remote function do not support training on multi instances. "
+ "Please provide instance_count = 1"
)

RemoteExecutor._validate_submit_args(func, *args, **kwargs)

job_settings = _JobSettings(
Expand Down Expand Up @@ -574,6 +581,7 @@ def __init__(
and methods that are not available via PyPI or conda. Default value is ``False``.

instance_count (int): The number of instances to use. Defaults to 1.
NOTE: Remote function does not support instance_count > 1

instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
Expand Down Expand Up @@ -647,6 +655,12 @@ def __init__(
if self.max_parallel_jobs <= 0:
raise ValueError("max_parallel_jobs must be greater than 0.")

if instance_count > 1:
raise ValueError(
"Remote function do not support training on multi instances. "
+ "Please provide instance_count = 1"
)

self.job_settings = _JobSettings(
dependencies=dependencies,
pre_execution_commands=pre_execution_commands,
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/sagemaker/remote_function/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,17 @@ def square(x):
square(5)


def test_decorator_instance_count_greater_than_one():
@remote(image_uri=IMAGE, s3_root_uri=S3_URI, instance_count=2)
def square(x):
return x * x

with pytest.raises(
ValueError, match=r"Remote function do not support training on multi instances."
):
square(5)


@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_timed_out(mock_start, mock_job_settings):
Expand Down Expand Up @@ -626,6 +637,14 @@ def test_executor_fails_to_start_job(mock_start, *args):
assert future_2.done()


def test_executor_instance_count_greater_than_one():
with pytest.raises(
ValueError, match=r"Remote function do not support training on multi instances."
):
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/", instance_count=2) as e:
e.submit(job_function, 1, 2, c=3, d=4)


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
Expand Down