Skip to content
21 changes: 21 additions & 0 deletions smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,39 @@
from datetime import datetime
from shutil import which

from typing_extensions import TypeAlias

if t.TYPE_CHECKING:
from types import FrameType

from typing_extensions import TypeVarTuple, Unpack

from smartsim.launchable.job import Job

Check warning on line 53 in smartsim/_core/utils/helpers.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/utils/helpers.py#L53

Added line #L53 was not covered by tests

_Ts = TypeVarTuple("_Ts")


_T = t.TypeVar("_T")
_HashableT = t.TypeVar("_HashableT", bound=t.Hashable)
_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object]

_NestedJobSequenceType: TypeAlias = "t.Sequence[Job | _NestedJobSequenceType]"


def unpack(value: _NestedJobSequenceType) -> t.Generator[Job, None, None]:
"""Unpack any iterable input in order to obtain a
single sequence of values

:param value: Sequence containing elements of type Job or other
sequences that are also of type _NestedJobSequenceType
:return: flattened list of Jobs"""

for item in value:
if isinstance(item, t.Iterable):
yield from unpack(item)
else:
yield item


def check_name(name: str) -> None:
"""
Expand Down
7 changes: 3 additions & 4 deletions smartsim/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,18 @@ def __init__(self, name: str, exp_path: str | None = None):
experiment
"""

def start(self, *jobs: Job) -> tuple[LaunchedJobID, ...]:
def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
"""Execute a collection of `Job` instances.

:param jobs: A collection of other job instances to start
:returns: A sequence of ids with order corresponding to the sequence of
jobs that can be used to query or alter the status of that
particular execution of the job.
"""
# Create the run id
jobs_ = list(_helpers.unpack(jobs))
run_id = datetime.datetime.now().replace(microsecond=0).isoformat()
# Generate the root path
root = pathlib.Path(self.exp_path, run_id)
return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs)
return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_)

def _dispatch(
self,
Expand Down
22 changes: 21 additions & 1 deletion tests/_legacy/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,32 @@
import pytest

from smartsim._core.utils import helpers
from smartsim._core.utils.helpers import cat_arg_and_value
from smartsim._core.utils.helpers import cat_arg_and_value, unpack
from smartsim.entity.application import Application
from smartsim.launchable.job import Job
from smartsim.settings.launch_settings import LaunchSettings

# The tests in this file belong to the group_a group
pytestmark = pytest.mark.group_a


def test_unpack_iterates_over_nested_jobs_in_expected_order(wlmutils):
launch_settings = LaunchSettings(wlmutils.get_test_launcher())
app = Application("app_name", exe="python")
job_1 = Job(app, launch_settings)
job_2 = Job(app, launch_settings)
job_3 = Job(app, launch_settings)
job_4 = Job(app, launch_settings)
job_5 = Job(app, launch_settings)

assert (
[job_1, job_2, job_3, job_4, job_5]
== list(unpack([job_1, [job_2, job_3], job_4, [job_5]]))
== list(unpack([job_1, job_2, [job_3, job_4], job_5]))
== list(unpack([job_1, [job_2, [job_3, job_4], job_5]]))
)


def test_double_dash_concat():
result = cat_arg_and_value("--foo", "FOO")
assert result == "--foo=FOO"
Expand Down
119 changes: 118 additions & 1 deletion tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,32 @@
import time
import typing as t
import uuid
from os import path as osp

import pytest

from smartsim._core import dispatch
from smartsim._core.control.interval import SynchronousTimeInterval
from smartsim._core.control.launch_history import LaunchHistory
from smartsim._core.utils.launcher import LauncherProtocol, create_job_id
from smartsim.builders.ensemble import Ensemble
from smartsim.entity import entity
from smartsim.entity.application import Application
from smartsim.error import errors
from smartsim.experiment import Experiment
from smartsim.launchable import job
from smartsim.settings import launch_settings
from smartsim.settings.arguments import launch_arguments
from smartsim.status import InvalidJobStatus, JobStatus
from smartsim.types import LaunchedJobID

pytestmark = pytest.mark.group_a

_ID_GENERATOR = (str(i) for i in itertools.count())


def random_id():
return next(_ID_GENERATOR)


@pytest.fixture
def experiment(monkeypatch, test_dir, dispatcher):
Expand Down Expand Up @@ -611,3 +620,111 @@ def test_experiment_stop_does_not_raise_on_unknown_job_id(
assert stat == InvalidJobStatus.NEVER_STARTED
after_cancel = exp.get_status(*all_known_ids)
assert before_cancel == after_cancel


@pytest.mark.parametrize(
"job_list",
(
pytest.param(
[
(
job.Job(
Application(
"test_name",
exe="echo",
exe_args=["spam", "eggs"],
),
launch_settings.LaunchSettings("local"),
),
Ensemble("ensemble-name", "echo", replicas=2).build_jobs(
launch_settings.LaunchSettings("local")
),
)
],
id="(job1, (job2, job_3))",
),
pytest.param(
[
(
Ensemble("ensemble-name", "echo", replicas=2).build_jobs(
launch_settings.LaunchSettings("local")
),
(
job.Job(
Application(
"test_name",
exe="echo",
exe_args=["spam", "eggs"],
),
launch_settings.LaunchSettings("local"),
),
job.Job(
Application(
"test_name_2",
exe="echo",
exe_args=["spam", "eggs"],
),
launch_settings.LaunchSettings("local"),
),
),
)
],
id="((job1, job2), (job3, job4))",
),
pytest.param(
[
(
job.Job(
Application(
"test_name",
exe="echo",
exe_args=["spam", "eggs"],
),
launch_settings.LaunchSettings("local"),
),
)
],
id="(job,)",
),
pytest.param(
[
[
job.Job(
Application(
"test_name",
exe="echo",
exe_args=["spam", "eggs"],
),
launch_settings.LaunchSettings("local"),
),
(
Ensemble("ensemble-name", "echo", replicas=2).build_jobs(
launch_settings.LaunchSettings("local")
),
job.Job(
Application(
"test_name_2",
exe="echo",
exe_args=["spam", "eggs"],
),
launch_settings.LaunchSettings("local"),
),
),
]
],
id="[job_1, ((job_2, job_3), job_4)]",
),
),
)
def test_start_unpack(
test_dir: str, wlmutils, monkeypatch: pytest.MonkeyPatch, job_list: job.Job
):
"""Test unpacking a sequences of jobs"""

monkeypatch.setattr(
"smartsim._core.dispatch._LauncherAdapter.start",
lambda launch, exe, job_execution_path, env, out, err: random_id(),
)

exp = Experiment(name="exp_name", exp_path=test_dir)
exp.start(*job_list)
Loading