Skip to content
12 changes: 12 additions & 0 deletions smartsim/_core/launcher/dragon/dragonLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,18 @@ def stop(self, step_name: str) -> StepInfo:
step_info.launcher_status = str(JobStatus.CANCELLED)
return step_info

def stop_jobs(
self, *launched_ids: LaunchedJobID
) -> t.Mapping[LaunchedJobID, JobStatus]:
"""Take a collection of job ids and issue stop requests to the dragon
backend for each.

:param launched_ids: The ids of the launched jobs to stop.
:returns: A mapping of ids for jobs to stop to their reported status
after attempting to stop them.
"""
return {id_: self.stop(id_).status for id_ in launched_ids}

@staticmethod
def _unprefix_step_id(step_id: str) -> str:
return step_id.split("-", maxsplit=1)[1]
Expand Down
98 changes: 94 additions & 4 deletions smartsim/_core/shell/shellLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,28 @@ def impl(


class ShellLauncher:
"""Mock launcher for launching/tracking simple shell commands"""
"""A launcher for launching/tracking local shell commands"""

def __init__(self) -> None:
"""Initialize a new shell launcher."""
self._launched: dict[LaunchedJobID, sp.Popen[bytes]] = {}

def check_popen_inputs(self, shell_command: ShellLauncherCommand) -> None:
"""Validate that the contents of a shell command are valid.

:param shell_command: The command to validate
:raises ValueError: If the command is not valid
"""
if not shell_command.path.exists():
raise ValueError("Please provide a valid path to ShellLauncherCommand.")

def start(self, shell_command: ShellLauncherCommand) -> LaunchedJobID:
"""Have the shell launcher start and track the progress of a new
subprocess.

:param shell_command: The template of a subprocess to start.
:returns: An id to reference the process for status.
"""
self.check_popen_inputs(shell_command)
id_ = create_job_id()
exe, *rest = shell_command.command_tuple
Expand All @@ -143,15 +155,40 @@ def start(self, shell_command: ShellLauncherCommand) -> LaunchedJobID:
)
return id_

def _get_proc_from_job_id(self, id_: LaunchedJobID, /) -> sp.Popen[bytes]:
"""Given an issued job id, return the process represented by that id.

:param id_: The launched job id of the process
:raises: errors.LauncherJobNotFound: The id could not be mapped to a
process. This usually means that the provided id was not issued by
this launcher instance.
:returns: The process that the shell launcher started and represented
by the issued id.
"""
if (proc := self._launched.get(id_)) is None:
msg = f"Launcher `{self}` has not launched a job with id `{id_}`"
raise errors.LauncherJobNotFound(msg)
return proc

def get_status(
self, *launched_ids: LaunchedJobID
) -> t.Mapping[LaunchedJobID, JobStatus]:
"""Take a collection of job ids and return the status of the
corresponding processes started by the shell launcher.

:param launched_ids: A collection of ids of the launched jobs to get
the statuses of.
:returns: A mapping of ids for jobs to stop to their reported status.
"""
return {id_: self._get_status(id_) for id_ in launched_ids}

def _get_status(self, id_: LaunchedJobID, /) -> JobStatus:
if (proc := self._launched.get(id_)) is None:
msg = f"Launcher `{self}` has not launched a job with id `{id_}`"
raise errors.LauncherJobNotFound(msg)
"""Given an issued job id, return the process represented by that id

:param id_: The launched job id of the process to get the status of.
:returns: The status of that process represented by the given id.
"""
proc = self._get_proc_from_job_id(id_)
ret_code = proc.poll()
if ret_code is None:
status = psutil.Process(proc.pid).status()
Expand All @@ -173,6 +210,59 @@ def _get_status(self, id_: LaunchedJobID, /) -> JobStatus:
return JobStatus.COMPLETED
return JobStatus.FAILED

def stop_jobs(
self, *launched_ids: LaunchedJobID
) -> t.Mapping[LaunchedJobID, JobStatus]:
"""Take a collection of job ids and kill the corresponding processes
started by the shell launcher.

:param launched_ids: The ids of the launched jobs to stop.
:returns: A mapping of ids for jobs to stop to their reported status
after attempting to stop them.
"""
return {id_: self._stop(id_) for id_ in launched_ids}

def _stop(self, id_: LaunchedJobID, /, wait_time: float = 5.0) -> JobStatus:
"""Stop a job represented by an id

The launcher will first start by attempting to kill the process using
by sending a SIGTERM signal and then waiting for an amount of time. If
the process is not killed by the timeout time, a SIGKILL signal will be
sent and another waiting period will be started. If the period also
ends, the message will be logged and the process will be left to
continue running. The method will then get and return the status of the
job.

:param id_: The id of a launched job to stop.
:param wait: The maximum amount of time, in seconds, to wait for a
signal to stop a process.
:returns: The status of the job after sending signals to terminate the
started process.
"""
proc = self._get_proc_from_job_id(id_)
if proc.poll() is None:
msg = f"Attempting to terminate local process {proc.pid}"
logger.debug(msg)
proc.terminate()

try:
proc.wait(wait_time)
except sp.TimeoutExpired:
msg = f"Failed to terminate process {proc.pid}. Attempting to kill."
logger.warning(msg)
proc.kill()

try:
proc.wait(wait_time)
except sp.TimeoutExpired:
logger.error(f"Failed to kill process {proc.pid}")
return self._get_status(id_)

@classmethod
def create(cls, _: Experiment) -> Self:
"""Create a new launcher instance from an experiment instance.

:param _: <Unused> An experiment instance.
:returns: A new launcher instance.
"""
return cls()
12 changes: 12 additions & 0 deletions smartsim/_core/utils/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,15 @@ def get_status(
the ids of the `launched_ids` collection is not recognized.
:returns: A mapping of launched id to current status
"""

@abc.abstractmethod
def stop_jobs(
self, *launched_ids: LaunchedJobID
) -> t.Mapping[LaunchedJobID, JobStatus]:
"""Given a collection of launched job ids, cancel the launched jobs

:param launched_ids: The ids of the jobs to stop
:raises smartsim.error.errors.LauncherJobNotFound: If at least one of
the ids of the `launched_ids` collection is not recognized.
:returns: A mapping of launched id to status upon cancellation
"""
19 changes: 19 additions & 0 deletions smartsim/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,25 @@ def summary(self, style: str = "github") -> str:
disable_numparse=True,
)

def stop(self, *ids: LaunchedJobID) -> tuple[JobStatus | InvalidJobStatus, ...]:
"""Cancel the execution of a previously launched job.

:param ids: The ids of the launched jobs to stop.
:raises ValueError: No job ids were provided.
:returns: A tuple of job statuses upon cancellation with order
respective of the order of the calling arguments.
"""
if not ids:
raise ValueError("No job ids provided")
by_launcher = self._launch_history.group_by_launcher(set(ids), unknown_ok=True)
id_to_stop_stat = (
launcher.stop_jobs(*launched).items()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I am thinking about here is stop_jobs may raise a smartsim.error.errors.LauncherJobNotFound, and if it does, it seems like the process only partially completes. Should we complete all of the ones that can complete and then (re)raise the error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory this cannot happen if the _launch_history is never corrupted (e.g. an id is never mapped to a launcher that never issued it), but I do agree that this could become a problem, especially if users start to make their own launchers.

I moved this to a new issue where it can be discussed further because I think this may be a problem with other methods (e.g. get_status, wait) as well as potentially other errors that can arise when stopping over a collection of ids.

for launcher, launched in by_launcher.items()
)
stats_map = dict(itertools.chain.from_iterable(id_to_stop_stat))
stats = (stats_map.get(id_, InvalidJobStatus.NEVER_STARTED) for id_ in ids)
return tuple(stats)

@property
def telemetry(self) -> TelemetryConfiguration:
"""Return the telemetry configuration for this entity.
Expand Down
3 changes: 3 additions & 0 deletions tests/temp_tests/test_settings/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@ def create(cls, exp):
def get_status(self, *ids):
raise NotImplementedError

def stop_jobs(self, *ids):
raise NotImplementedError

yield _MockLauncher()
3 changes: 3 additions & 0 deletions tests/temp_tests/test_settings/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def start(self, strs):
def get_status(self, *ids):
raise NotImplementedError

def stop_jobs(self, *ids):
raise NotImplementedError


class BufferWriterLauncherSubclass(BufferWriterLauncher): ...

Expand Down
80 changes: 80 additions & 0 deletions tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from smartsim._core.control.launch_history import LaunchHistory
from smartsim._core.utils.launcher import LauncherProtocol, create_job_id
from smartsim.entity import entity
from smartsim.error import errors
from smartsim.experiment import Experiment
from smartsim.launchable import job
from smartsim.settings import launchSettings
Expand Down Expand Up @@ -145,6 +146,9 @@ def start(self, record: LaunchRecord):
def get_status(self, *ids):
raise NotImplementedError

def stop_jobs(self, *ids):
raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class LaunchRecord:
Expand Down Expand Up @@ -315,9 +319,20 @@ def create(cls, _):
def start(self, _):
raise NotImplementedError("{type(self).__name__} should not start anything")

def _assert_ids(self, ids: LaunchedJobID):
if any(id_ not in self.id_to_status for id_ in ids):
raise errors.LauncherJobNotFound

def get_status(self, *ids: LaunchedJobID):
self._assert_ids(ids)
return {id_: self.id_to_status[id_] for id_ in ids}

def stop_jobs(self, *ids: LaunchedJobID):
self._assert_ids(ids)
stopped = {id_: JobStatus.CANCELLED for id_ in ids}
self.id_to_status |= stopped
return stopped


@pytest.fixture
def make_populated_experiment(monkeypatch, experiment):
Expand Down Expand Up @@ -531,3 +546,68 @@ def test_poll_for_status_raises_if_ids_not_found_within_timeout(
timeout=1,
interval=0,
)


@pytest.mark.parametrize(
"num_launchers",
[pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)],
)
@pytest.mark.parametrize(
"select_ids",
[
pytest.param(
lambda history: history._id_to_issuer.keys(), id="All launched jobs"
),
pytest.param(
lambda history: next(iter(history.group_by_launcher().values())),
id="All from one launcher",
),
pytest.param(
lambda history: itertools.chain.from_iterable(
random.sample(tuple(ids), len(JobStatus) // 2)
for ids in history.group_by_launcher().values()
),
id="Subset per launcher",
),
pytest.param(
lambda history: random.sample(
tuple(history._id_to_issuer), len(history._id_to_issuer) // 3
),
id=f"Random subset across all launchers",
),
],
)
def test_experiment_can_stop_jobs(make_populated_experiment, num_launchers, select_ids):
exp = make_populated_experiment(num_launchers)
ids = (launcher.known_ids for launcher in exp._launch_history.iter_past_launchers())
ids = tuple(itertools.chain.from_iterable(ids))
before_stop_stats = exp.get_status(*ids)
to_cancel = tuple(select_ids(exp._launch_history))
stats = exp.stop(*to_cancel)
after_stop_stats = exp.get_status(*ids)
assert stats == (JobStatus.CANCELLED,) * len(to_cancel)
assert dict(zip(ids, before_stop_stats)) | dict(zip(to_cancel, stats)) == dict(
zip(ids, after_stop_stats)
)


def test_experiment_raises_if_asked_to_stop_no_jobs(experiment):
with pytest.raises(ValueError, match="No job ids provided"):
experiment.stop()


@pytest.mark.parametrize(
"num_launchers",
[pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)],
)
def test_experiment_stop_does_not_raise_on_unknown_job_id(
make_populated_experiment, num_launchers
):
exp = make_populated_experiment(num_launchers)
new_id = create_job_id()
all_known_ids = tuple(exp._launch_history._id_to_issuer)
before_cancel = exp.get_status(*all_known_ids)
(stat,) = exp.stop(new_id)
assert stat == InvalidJobStatus.NEVER_STARTED
after_cancel = exp.get_status(*all_known_ids)
assert before_cancel == after_cancel
13 changes: 2 additions & 11 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@

pytestmark = pytest.mark.group_a

ids = set()
_ID_GENERATOR = (str(i) for i in itertools.count())


def random_id():
while True:
num = str(random.randint(1, 100))
if num not in ids:
ids.add(num)
return num
return next(_ID_GENERATOR)


@pytest.fixture
Expand Down Expand Up @@ -278,7 +274,6 @@ def test_generate_ensemble_directory_start(test_dir, wlmutils, monkeypatch):
log_path = os.path.join(jobs_dir, ensemble_dir, "log")
assert osp.isdir(run_path)
assert osp.isdir(log_path)
ids.clear()


def test_generate_ensemble_copy(test_dir, wlmutils, monkeypatch, get_gen_copy_dir):
Expand All @@ -299,7 +294,6 @@ def test_generate_ensemble_copy(test_dir, wlmutils, monkeypatch, get_gen_copy_di
for ensemble_dir in job_dir:
copy_folder_path = os.path.join(jobs_dir, ensemble_dir, "run", "to_copy_dir")
assert osp.isdir(copy_folder_path)
ids.clear()


def test_generate_ensemble_symlink(
Expand Down Expand Up @@ -327,7 +321,6 @@ def test_generate_ensemble_symlink(
assert osp.isdir(sym_file_path)
assert sym_file_path.is_symlink()
assert os.fspath(sym_file_path.resolve()) == osp.realpath(get_gen_symlink_dir)
ids.clear()


def test_generate_ensemble_configure(
Expand All @@ -351,7 +344,6 @@ def test_generate_ensemble_configure(
job_list = ensemble.as_jobs(launch_settings)
exp = Experiment(name="exp_name", exp_path=test_dir)
id = exp.start(*job_list)
print(id)
run_dir = listdir(test_dir)
jobs_dir = os.path.join(test_dir, run_dir[0], "jobs")

Expand All @@ -372,4 +364,3 @@ def _check_generated(param_0, param_1, dir):
_check_generated(1, 2, os.path.join(jobs_dir, "ensemble-name-2-2", "run"))
_check_generated(1, 3, os.path.join(jobs_dir, "ensemble-name-3-3", "run"))
_check_generated(0, 2, os.path.join(jobs_dir, "ensemble-name-0-0", "run"))
ids.clear()
3 changes: 3 additions & 0 deletions tests/test_launch_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def start(self, _):
def get_status(self, *_):
raise NotImplementedError

def stop_jobs(self, *_):
raise NotImplementedError


LAUNCHER_INSTANCE_A = MockLancher()
LAUNCHER_INSTANCE_B = MockLancher()
Expand Down
Loading