Reject instance_count > 1 for remote functions.

The wrapper built inside the remote decorator and the __init__ that also
validates max_parallel_jobs now raise ValueError when instance_count > 1.
Both docstrings gain a matching NOTE, and two unit tests pin the message.

diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py
index 90729da8e1..1785f15892 100644
--- a/src/sagemaker/remote_function/client.py
+++ b/src/sagemaker/remote_function/client.py
@@ -185,6 +185,7 @@ def remote(
             methods that are not available via PyPI or conda. Default value is ``False``.
 
         instance_count (int): The number of instances to use. Defaults to 1.
+          NOTE: Remote function does not support instance_count > 1
 
         instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to
             run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -255,6 +256,12 @@ def _remote(func):
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            if instance_count > 1:
+                raise ValueError(
+                    "Remote function do not support training on multi instances. "
+                    + "Please provide instance_count = 1"
+                )
+
             RemoteExecutor._validate_submit_args(func, *args, **kwargs)
 
             job_settings = _JobSettings(
@@ -574,6 +581,7 @@ def __init__(
                 and methods that are not available via PyPI or conda. Default value is ``False``.
 
             instance_count (int): The number of instances to use. Defaults to 1.
+              NOTE: Remote function does not support instance_count > 1
 
             instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to
                 run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -647,6 +655,12 @@ def __init__(
         if self.max_parallel_jobs <= 0:
             raise ValueError("max_parallel_jobs must be greater than 0.")
 
+        if instance_count > 1:
+            raise ValueError(
+                "Remote function do not support training on multi instances. "
+                + "Please provide instance_count = 1"
+            )
+
         self.job_settings = _JobSettings(
             dependencies=dependencies,
             pre_execution_commands=pre_execution_commands,
diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py
index f8cd528505..fede42dab1 100644
--- a/tests/unit/sagemaker/remote_function/test_client.py
+++ b/tests/unit/sagemaker/remote_function/test_client.py
@@ -373,6 +373,17 @@ def square(x):
         square(5)
 
 
+def test_decorator_instance_count_greater_than_one():
+    @remote(image_uri=IMAGE, s3_root_uri=S3_URI, instance_count=2)
+    def square(x):
+        return x * x
+
+    with pytest.raises(
+        ValueError, match=r"Remote function do not support training on multi instances."
+    ):
+        square(5)
+
+
 @patch("sagemaker.remote_function.client._JobSettings")
 @patch("sagemaker.remote_function.client._Job.start")
 def test_decorator_underlying_job_timed_out(mock_start, mock_job_settings):
@@ -626,6 +637,14 @@ def test_executor_fails_to_start_job(mock_start, *args):
     assert future_2.done()
 
 
+def test_executor_instance_count_greater_than_one():
+    with pytest.raises(
+        ValueError, match=r"Remote function do not support training on multi instances."
+    ):
+        with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/", instance_count=2) as e:
+            e.submit(job_function, 1, 2, c=3, d=4)
+
+
 @patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
 @patch("sagemaker.remote_function.client._JobSettings")
 @patch("sagemaker.remote_function.client._Job.start")