diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index 727bedbf2d..398596049e 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -263,6 +263,18 @@ def stop(self, step_name: str) -> StepInfo: step_info.launcher_status = str(JobStatus.CANCELLED) return step_info + def stop_jobs( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Take a collection of job ids and issue stop requests to the dragon + backend for each. + + :param launched_ids: The ids of the launched jobs to stop. + :returns: A mapping of ids for jobs to stop to their reported status + after attempting to stop them. + """ + return {id_: self.stop(id_).status for id_ in launched_ids} + @staticmethod def _unprefix_step_id(step_id: str) -> str: return step_id.split("-", maxsplit=1)[1] diff --git a/smartsim/_core/shell/shellLauncher.py b/smartsim/_core/shell/shellLauncher.py index 77dc0a10e2..c22ba6ba83 100644 --- a/smartsim/_core/shell/shellLauncher.py +++ b/smartsim/_core/shell/shellLauncher.py @@ -119,16 +119,28 @@ def impl( class ShellLauncher: - """Mock launcher for launching/tracking simple shell commands""" + """A launcher for launching/tracking local shell commands""" def __init__(self) -> None: + """Initialize a new shell launcher.""" self._launched: dict[LaunchedJobID, sp.Popen[bytes]] = {} def check_popen_inputs(self, shell_command: ShellLauncherCommand) -> None: + """Validate that the contents of a shell command are valid. + + :param shell_command: The command to validate + :raises ValueError: If the command is not valid + """ if not shell_command.path.exists(): raise ValueError("Please provide a valid path to ShellLauncherCommand.") def start(self, shell_command: ShellLauncherCommand) -> LaunchedJobID: + """Have the shell launcher start and track the progress of a new + subprocess. 
+
+        :param shell_command: The template of a subprocess to start.
+        :returns: An id to reference the process for status.
+        """
         self.check_popen_inputs(shell_command)
         id_ = create_job_id()
         exe, *rest = shell_command.command_tuple
@@ -143,15 +155,40 @@ def start(self, shell_command: ShellLauncherCommand) -> LaunchedJobID:
         )
         return id_
 
+    def _get_proc_from_job_id(self, id_: LaunchedJobID, /) -> sp.Popen[bytes]:
+        """Given an issued job id, return the process represented by that id.
+
+        :param id_: The launched job id of the process
+        :raises errors.LauncherJobNotFound: The id could not be mapped to a
+            process. This usually means that the provided id was not issued by
+            this launcher instance.
+        :returns: The process that the shell launcher started and represented
+            by the issued id.
+        """
+        if (proc := self._launched.get(id_)) is None:
+            msg = f"Launcher `{self}` has not launched a job with id `{id_}`"
+            raise errors.LauncherJobNotFound(msg)
+        return proc
+
     def get_status(
         self, *launched_ids: LaunchedJobID
     ) -> t.Mapping[LaunchedJobID, JobStatus]:
+        """Take a collection of job ids and return the status of the
+        corresponding processes started by the shell launcher.
+
+        :param launched_ids: A collection of ids of the launched jobs to get
+            the statuses of.
+        :returns: A mapping of ids to their reported status.
+        """
         return {id_: self._get_status(id_) for id_ in launched_ids}
 
     def _get_status(self, id_: LaunchedJobID, /) -> JobStatus:
-        if (proc := self._launched.get(id_)) is None:
-            msg = f"Launcher `{self}` has not launched a job with id `{id_}`"
-            raise errors.LauncherJobNotFound(msg)
+        """Given an issued job id, return the status of the process
+        represented by that id.
+
+        :param id_: The launched job id of the process to get the status of.
+        :returns: The status of that process represented by the given id.
+ """ + proc = self._get_proc_from_job_id(id_) ret_code = proc.poll() if ret_code is None: status = psutil.Process(proc.pid).status() @@ -173,6 +210,59 @@ def _get_status(self, id_: LaunchedJobID, /) -> JobStatus: return JobStatus.COMPLETED return JobStatus.FAILED + def stop_jobs( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Take a collection of job ids and kill the corresponding processes + started by the shell launcher. + + :param launched_ids: The ids of the launched jobs to stop. + :returns: A mapping of ids for jobs to stop to their reported status + after attempting to stop them. + """ + return {id_: self._stop(id_) for id_ in launched_ids} + + def _stop(self, id_: LaunchedJobID, /, wait_time: float = 5.0) -> JobStatus: + """Stop a job represented by an id + + The launcher will first start by attempting to kill the process using + by sending a SIGTERM signal and then waiting for an amount of time. If + the process is not killed by the timeout time, a SIGKILL signal will be + sent and another waiting period will be started. If the period also + ends, the message will be logged and the process will be left to + continue running. The method will then get and return the status of the + job. + + :param id_: The id of a launched job to stop. + :param wait: The maximum amount of time, in seconds, to wait for a + signal to stop a process. + :returns: The status of the job after sending signals to terminate the + started process. + """ + proc = self._get_proc_from_job_id(id_) + if proc.poll() is None: + msg = f"Attempting to terminate local process {proc.pid}" + logger.debug(msg) + proc.terminate() + + try: + proc.wait(wait_time) + except sp.TimeoutExpired: + msg = f"Failed to terminate process {proc.pid}. Attempting to kill." 
+ logger.warning(msg) + proc.kill() + + try: + proc.wait(wait_time) + except sp.TimeoutExpired: + logger.error(f"Failed to kill process {proc.pid}") + return self._get_status(id_) + @classmethod def create(cls, _: Experiment) -> Self: + """Create a new launcher instance from an experiment instance. + + :param _: An experiment instance. + :returns: A new launcher instance. + """ return cls() diff --git a/smartsim/_core/utils/launcher.py b/smartsim/_core/utils/launcher.py index 5191a21f80..7cb0a440b9 100644 --- a/smartsim/_core/utils/launcher.py +++ b/smartsim/_core/utils/launcher.py @@ -85,3 +85,15 @@ def get_status( the ids of the `launched_ids` collection is not recognized. :returns: A mapping of launched id to current status """ + + @abc.abstractmethod + def stop_jobs( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Given a collection of launched job ids, cancel the launched jobs + + :param launched_ids: The ids of the jobs to stop + :raises smartsim.error.errors.LauncherJobNotFound: If at least one of + the ids of the `launched_ids` collection is not recognized. + :returns: A mapping of launched id to status upon cancellation + """ diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 3865ba7088..24709ccfd0 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -420,6 +420,25 @@ def summary(self, style: str = "github") -> str: disable_numparse=True, ) + def stop(self, *ids: LaunchedJobID) -> tuple[JobStatus | InvalidJobStatus, ...]: + """Cancel the execution of a previously launched job. + + :param ids: The ids of the launched jobs to stop. + :raises ValueError: No job ids were provided. + :returns: A tuple of job statuses upon cancellation with order + respective of the order of the calling arguments. 
+ """ + if not ids: + raise ValueError("No job ids provided") + by_launcher = self._launch_history.group_by_launcher(set(ids), unknown_ok=True) + id_to_stop_stat = ( + launcher.stop_jobs(*launched).items() + for launcher, launched in by_launcher.items() + ) + stats_map = dict(itertools.chain.from_iterable(id_to_stop_stat)) + stats = (stats_map.get(id_, InvalidJobStatus.NEVER_STARTED) for id_ in ids) + return tuple(stats) + @property def telemetry(self) -> TelemetryConfiguration: """Return the telemetry configuration for this entity. diff --git a/tests/temp_tests/test_settings/conftest.py b/tests/temp_tests/test_settings/conftest.py index 6ec60dd14e..90ffdd1416 100644 --- a/tests/temp_tests/test_settings/conftest.py +++ b/tests/temp_tests/test_settings/conftest.py @@ -55,4 +55,7 @@ def create(cls, exp): def get_status(self, *ids): raise NotImplementedError + def stop_jobs(self, *ids): + raise NotImplementedError + yield _MockLauncher() diff --git a/tests/temp_tests/test_settings/test_dispatch.py b/tests/temp_tests/test_settings/test_dispatch.py index f1545f58ee..89303b5a37 100644 --- a/tests/temp_tests/test_settings/test_dispatch.py +++ b/tests/temp_tests/test_settings/test_dispatch.py @@ -263,6 +263,9 @@ def start(self, strs): def get_status(self, *ids): raise NotImplementedError + def stop_jobs(self, *ids): + raise NotImplementedError + class BufferWriterLauncherSubclass(BufferWriterLauncher): ... 
diff --git a/tests/test_experiment.py b/tests/test_experiment.py index b0e0136144..aff32604c0 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -42,6 +42,7 @@ from smartsim._core.control.launch_history import LaunchHistory from smartsim._core.utils.launcher import LauncherProtocol, create_job_id from smartsim.entity import entity +from smartsim.error import errors from smartsim.experiment import Experiment from smartsim.launchable import job from smartsim.settings import launchSettings @@ -145,6 +146,9 @@ def start(self, record: LaunchRecord): def get_status(self, *ids): raise NotImplementedError + def stop_jobs(self, *ids): + raise NotImplementedError + @dataclasses.dataclass(frozen=True) class LaunchRecord: @@ -315,9 +319,20 @@ def create(cls, _): def start(self, _): raise NotImplementedError("{type(self).__name__} should not start anything") + def _assert_ids(self, ids: LaunchedJobID): + if any(id_ not in self.id_to_status for id_ in ids): + raise errors.LauncherJobNotFound + def get_status(self, *ids: LaunchedJobID): + self._assert_ids(ids) return {id_: self.id_to_status[id_] for id_ in ids} + def stop_jobs(self, *ids: LaunchedJobID): + self._assert_ids(ids) + stopped = {id_: JobStatus.CANCELLED for id_ in ids} + self.id_to_status |= stopped + return stopped + @pytest.fixture def make_populated_experiment(monkeypatch, experiment): @@ -531,3 +546,68 @@ def test_poll_for_status_raises_if_ids_not_found_within_timeout( timeout=1, interval=0, ) + + +@pytest.mark.parametrize( + "num_launchers", + [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], +) +@pytest.mark.parametrize( + "select_ids", + [ + pytest.param( + lambda history: history._id_to_issuer.keys(), id="All launched jobs" + ), + pytest.param( + lambda history: next(iter(history.group_by_launcher().values())), + id="All from one launcher", + ), + pytest.param( + lambda history: itertools.chain.from_iterable( + random.sample(tuple(ids), len(JobStatus) // 2) + for 
ids in history.group_by_launcher().values() + ), + id="Subset per launcher", + ), + pytest.param( + lambda history: random.sample( + tuple(history._id_to_issuer), len(history._id_to_issuer) // 3 + ), + id=f"Random subset across all launchers", + ), + ], +) +def test_experiment_can_stop_jobs(make_populated_experiment, num_launchers, select_ids): + exp = make_populated_experiment(num_launchers) + ids = (launcher.known_ids for launcher in exp._launch_history.iter_past_launchers()) + ids = tuple(itertools.chain.from_iterable(ids)) + before_stop_stats = exp.get_status(*ids) + to_cancel = tuple(select_ids(exp._launch_history)) + stats = exp.stop(*to_cancel) + after_stop_stats = exp.get_status(*ids) + assert stats == (JobStatus.CANCELLED,) * len(to_cancel) + assert dict(zip(ids, before_stop_stats)) | dict(zip(to_cancel, stats)) == dict( + zip(ids, after_stop_stats) + ) + + +def test_experiment_raises_if_asked_to_stop_no_jobs(experiment): + with pytest.raises(ValueError, match="No job ids provided"): + experiment.stop() + + +@pytest.mark.parametrize( + "num_launchers", + [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], +) +def test_experiment_stop_does_not_raise_on_unknown_job_id( + make_populated_experiment, num_launchers +): + exp = make_populated_experiment(num_launchers) + new_id = create_job_id() + all_known_ids = tuple(exp._launch_history._id_to_issuer) + before_cancel = exp.get_status(*all_known_ids) + (stat,) = exp.stop(new_id) + assert stat == InvalidJobStatus.NEVER_STARTED + after_cancel = exp.get_status(*all_known_ids) + assert before_cancel == after_cancel diff --git a/tests/test_generator.py b/tests/test_generator.py index 2e6b8a4ad7..ff24018ca7 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -20,15 +20,11 @@ pytestmark = pytest.mark.group_a -ids = set() +_ID_GENERATOR = (str(i) for i in itertools.count()) def random_id(): - while True: - num = str(random.randint(1, 100)) - if num not in ids: - ids.add(num) - 
return num + return next(_ID_GENERATOR) @pytest.fixture @@ -278,7 +274,6 @@ def test_generate_ensemble_directory_start(test_dir, wlmutils, monkeypatch): log_path = os.path.join(jobs_dir, ensemble_dir, "log") assert osp.isdir(run_path) assert osp.isdir(log_path) - ids.clear() def test_generate_ensemble_copy(test_dir, wlmutils, monkeypatch, get_gen_copy_dir): @@ -299,7 +294,6 @@ def test_generate_ensemble_copy(test_dir, wlmutils, monkeypatch, get_gen_copy_di for ensemble_dir in job_dir: copy_folder_path = os.path.join(jobs_dir, ensemble_dir, "run", "to_copy_dir") assert osp.isdir(copy_folder_path) - ids.clear() def test_generate_ensemble_symlink( @@ -327,7 +321,6 @@ def test_generate_ensemble_symlink( assert osp.isdir(sym_file_path) assert sym_file_path.is_symlink() assert os.fspath(sym_file_path.resolve()) == osp.realpath(get_gen_symlink_dir) - ids.clear() def test_generate_ensemble_configure( @@ -351,7 +344,6 @@ def test_generate_ensemble_configure( job_list = ensemble.as_jobs(launch_settings) exp = Experiment(name="exp_name", exp_path=test_dir) id = exp.start(*job_list) - print(id) run_dir = listdir(test_dir) jobs_dir = os.path.join(test_dir, run_dir[0], "jobs") @@ -372,4 +364,3 @@ def _check_generated(param_0, param_1, dir): _check_generated(1, 2, os.path.join(jobs_dir, "ensemble-name-2-2", "run")) _check_generated(1, 3, os.path.join(jobs_dir, "ensemble-name-3-3", "run")) _check_generated(0, 2, os.path.join(jobs_dir, "ensemble-name-0-0", "run")) - ids.clear() diff --git a/tests/test_launch_history.py b/tests/test_launch_history.py index 9d3bb31ac4..3b4cd5bcc5 100644 --- a/tests/test_launch_history.py +++ b/tests/test_launch_history.py @@ -48,6 +48,9 @@ def start(self, _): def get_status(self, *_): raise NotImplementedError + def stop_jobs(self, *_): + raise NotImplementedError + LAUNCHER_INSTANCE_A = MockLancher() LAUNCHER_INSTANCE_B = MockLancher() diff --git a/tests/test_shell_launcher.py b/tests/test_shell_launcher.py index 432fa7e675..95e8847108 100644 --- 
a/tests/test_shell_launcher.py +++ b/tests/test_shell_launcher.py @@ -24,9 +24,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import contextlib import os import pathlib import subprocess +import sys +import textwrap import unittest.mock import psutil @@ -83,20 +88,37 @@ def shell_launcher(): launcher = ShellLauncher() yield launcher if any(proc.poll() is None for proc in launcher._launched.values()): - raise ("Test leaked processes") + raise RuntimeError("Test leaked processes") @pytest.fixture -def shell_cmd(test_dir: str) -> ShellLauncherCommand: - """Fixture to create an instance of Generator.""" - run_dir, out_file, err_file = generate_directory(test_dir) - with ( - open(out_file, "w", encoding="utf-8") as out, - open(err_file, "w", encoding="utf-8") as err, +def make_shell_command(test_dir): + run_dir, out_file_, err_file_ = generate_directory(test_dir) + + @contextlib.contextmanager + def impl( + args: t.Sequence[str], + working_dir: str | os.PathLike[str] = run_dir, + env: dict[str, str] | None = None, + out_file: str | os.PathLike[str] = out_file_, + err_file: str | os.PathLike[str] = err_file_, ): - yield ShellLauncherCommand( - {}, run_dir, out, err, EchoHelloWorldEntity().as_executable_sequence() - ) + with ( + open(out_file, "w", encoding="utf-8") as out, + open(err_file, "w", encoding="utf-8") as err, + ): + yield ShellLauncherCommand( + env or {}, pathlib.Path(working_dir), out, err, tuple(args) + ) + + yield impl + + +@pytest.fixture +def shell_cmd(make_shell_command) -> ShellLauncherCommand: + """Fixture to create an instance of Generator.""" + with make_shell_command(EchoHelloWorldEntity().as_executable_sequence()) as hello: + yield hello # UNIT TESTS @@ -310,3 +332,61 @@ def test_get_status_maps_correctly( value = shell_launcher.get_status(id) assert value.get(id) == job_status assert proc.wait() == 0 
+ + +@pytest.mark.parametrize( + "args", + ( + pytest.param(("sleep", "60"), id="Sleep for a minute"), + *( + pytest.param( + ( + sys.executable, + "-c", + textwrap.dedent(f"""\ + import signal, time + signal.signal(signal.{signal_name}, + lambda n, f: print("Ignoring")) + time.sleep(60) + """), + ), + id=f"Process Swallows {signal_name}", + ) + for signal_name in ("SIGINT", "SIGTERM") + ), + ), +) +def test_launcher_can_stop_processes(shell_launcher, make_shell_command, args): + with make_shell_command(args) as cmd: + start = time.perf_counter() + id_ = shell_launcher.start(cmd) + time.sleep(0.1) + assert {id_: JobStatus.RUNNING} == shell_launcher.get_status(id_) + assert JobStatus.FAILED == shell_launcher._stop(id_, wait_time=0.25) + end = time.perf_counter() + assert {id_: JobStatus.FAILED} == shell_launcher.get_status(id_) + proc = shell_launcher._launched[id_] + assert proc.poll() is not None + assert proc.poll() != 0 + assert 0.1 < end - start < 1 + + +def test_launcher_can_stop_many_processes( + make_shell_command, shell_launcher, shell_cmd +): + with ( + make_shell_command(("sleep", "60")) as sleep_60, + make_shell_command(("sleep", "45")) as sleep_45, + make_shell_command(("sleep", "30")) as sleep_30, + ): + id_60 = shell_launcher.start(sleep_60) + id_45 = shell_launcher.start(sleep_45) + id_30 = shell_launcher.start(sleep_30) + id_short = shell_launcher.start(shell_cmd) + time.sleep(0.1) + assert { + id_60: JobStatus.FAILED, + id_45: JobStatus.FAILED, + id_30: JobStatus.FAILED, + id_short: JobStatus.COMPLETED, + } == shell_launcher.stop_jobs(id_30, id_45, id_60, id_short)