diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 1785f15892..ecfa67533b 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -745,7 +745,7 @@ def map(self, func, *iterables): futures = map(self.submit, itertools.repeat(func), *iterables) return [future.result() for future in futures] - def shutdown(self, wait=True): + def shutdown(self): """Prevent more function executions to be submitted to this executor.""" with self._state_condition: self._shutdown = True @@ -756,7 +756,7 @@ def shutdown(self, wait=True): self._state_condition.notify_all() if self._workers is not None: - self._workers.shutdown(wait) + self._workers.shutdown(wait=True) def __enter__(self): """Create an executor instance and return it""" @@ -764,7 +764,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Make sure the executor instance is shutdown.""" - self.shutdown(wait=False) + self.shutdown() return False @staticmethod diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index fede42dab1..fb6e9caf94 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -518,11 +518,6 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism): future_3 = e.submit(job_function, 9, 10, c=11, d=12) future_4 = e.submit(job_function, 13, 14, c=15, d=16) - future_1.wait() - future_2.wait() - future_3.wait() - future_4.wait() - mock_start.assert_has_calls( [ call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None), @@ -531,6 +526,10 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism): call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None), ] ) + mock_job_1.describe.assert_called() + mock_job_2.describe.assert_called() + mock_job_3.describe.assert_called() + mock_job_4.describe.assert_called() assert future_1.done() assert future_2.done() @@ -555,15 +554,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj): future_1 = e.submit(job_function, 1, 2, c=3, d=4) future_2 = e.submit(job_function, 5, 6, c=7, d=8) - future_1.wait() - future_2.wait() - mock_start.assert_has_calls( [ call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info), call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info), ] ) + mock_job_1.describe.assert_called() + mock_job_2.describe.assert_called() assert future_1.done() assert future_2.done() @@ -573,15 +571,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj): future_3 = e.submit(job_function, 9, 10, c=11, d=12) future_4 = e.submit(job_function, 13, 14, c=15, d=16) - future_3.wait() - future_4.wait() - mock_start.assert_has_calls( [ call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info), call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info), ] ) + mock_job_3.describe.assert_called() + mock_job_4.describe.assert_called() assert future_3.done() assert future_4.done() @@ -633,7 +630,7 @@ def test_executor_fails_to_start_job(mock_start, *args): with pytest.raises(TypeError): future_1.result() - future_2.wait() + print(future_2._state) assert future_2.done() @@ -698,8 +695,6 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args): # submit second job future_2 = e.submit(job_function, 5, 6, c=7, d=8) - future_1.wait() - future_2.wait() assert future_1.done() assert future_2.done() @@ -719,9 +714,9 @@ def test_executor_describe_job_failed_permanently(mock_start, *args): future_2 = e.submit(job_function, 5, 6, c=7, d=8) with pytest.raises(RuntimeError): - future_1.result() + future_1.done() with pytest.raises(RuntimeError): - future_2.result() + future_2.done() @pytest.mark.parametrize(