Skip to content

Commit 7eca440

Browse files
author
Namrata Madan
committed
Revert "fix: make RemoteExecutor context manager non-blocking on pending futures (#3822)"
This reverts commit 5f40087.
1 parent 4844aa1 commit 7eca440

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ def map(self, func, *iterables):
731731
futures = map(self.submit, itertools.repeat(func), *iterables)
732732
return [future.result() for future in futures]
733733

734-
def shutdown(self, wait=True):
734+
def shutdown(self):
735735
"""Prevent more function executions to be submitted to this executor."""
736736
with self._state_condition:
737737
self._shutdown = True
@@ -742,15 +742,15 @@ def shutdown(self, wait=True):
742742
self._state_condition.notify_all()
743743

744744
if self._workers is not None:
745-
self._workers.shutdown(wait)
745+
self._workers.shutdown(wait=True)
746746

747747
def __enter__(self):
748748
"""Create an executor instance and return it"""
749749
return self
750750

751751
def __exit__(self, exc_type, exc_val, exc_tb):
752752
"""Make sure the executor instance is shutdown."""
753-
self.shutdown(wait=False)
753+
self.shutdown()
754754
return False
755755

756756
@staticmethod

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -507,11 +507,6 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
507507
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
508508
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
509509

510-
future_1.wait()
511-
future_2.wait()
512-
future_3.wait()
513-
future_4.wait()
514-
515510
mock_start.assert_has_calls(
516511
[
517512
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None),
@@ -520,6 +515,10 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
520515
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None),
521516
]
522517
)
518+
mock_job_1.describe.assert_called()
519+
mock_job_2.describe.assert_called()
520+
mock_job_3.describe.assert_called()
521+
mock_job_4.describe.assert_called()
523522

524523
assert future_1.done()
525524
assert future_2.done()
@@ -544,15 +543,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
544543
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
545544
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
546545

547-
future_1.wait()
548-
future_2.wait()
549-
550546
mock_start.assert_has_calls(
551547
[
552548
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info),
553549
call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info),
554550
]
555551
)
552+
mock_job_1.describe.assert_called()
553+
mock_job_2.describe.assert_called()
556554

557555
assert future_1.done()
558556
assert future_2.done()
@@ -562,15 +560,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
562560
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
563561
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
564562

565-
future_3.wait()
566-
future_4.wait()
567-
568563
mock_start.assert_has_calls(
569564
[
570565
call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info),
571566
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info),
572567
]
573568
)
569+
mock_job_3.describe.assert_called()
570+
mock_job_4.describe.assert_called()
574571

575572
assert future_3.done()
576573
assert future_4.done()
@@ -622,7 +619,7 @@ def test_executor_fails_to_start_job(mock_start, *args):
622619

623620
with pytest.raises(TypeError):
624621
future_1.result()
625-
future_2.wait()
622+
print(future_2._state)
626623
assert future_2.done()
627624

628625

@@ -679,8 +676,6 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args):
679676
# submit second job
680677
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
681678

682-
future_1.wait()
683-
future_2.wait()
684679
assert future_1.done()
685680
assert future_2.done()
686681

@@ -700,9 +695,9 @@ def test_executor_describe_job_failed_permanently(mock_start, *args):
700695
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
701696

702697
with pytest.raises(RuntimeError):
703-
future_1.result()
698+
future_1.done()
704699
with pytest.raises(RuntimeError):
705-
future_2.result()
700+
future_2.done()
706701

707702

708703
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)