Skip to content
112 changes: 112 additions & 0 deletions smartsim/_core/control/interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import annotations

import math
import time
import typing as t

Seconds = t.NewType("Seconds", float)


class SynchronousTimeInterval:
    """A utility class to represent and synchronously block the execution of a
    thread for an interval of time.

    The interval starts counting at construction time and is measured with
    ``time.perf_counter``.
    """

    def __init__(self, delta: float | None) -> None:
        """Initialize a new `SynchronousTimeInterval` interval

        :param delta: The difference in time the interval represents in
            seconds. If `None`, the interval will represent an infinite amount
            of time.
        :raises ValueError: The `delta` is negative or is not a number (NaN)
        """
        if delta is None:
            delta = float("inf")
        elif math.isnan(delta):
            # NaN compares false against everything, so it would slip past the
            # sign check below and produce an interval that never reports
            # itself as expired.
            raise ValueError("Timeout value cannot be NaN")
        elif delta < 0:
            raise ValueError("Timeout value cannot be less than 0")
        self._delta = Seconds(delta)
        """The amount of time, in seconds, the interval spans."""
        self._start = time.perf_counter()
        """The time of the creation of the interval"""

    @property
    def delta(self) -> Seconds:
        """The difference in time the interval represents

        :returns: The difference in time the interval represents
        """
        return self._delta

    @property
    def elapsed(self) -> Seconds:
        """The amount of time that has passed since the interval was created

        :returns: The amount of time that has passed since the interval was
            created
        """
        return Seconds(time.perf_counter() - self._start)

    @property
    def remaining(self) -> Seconds:
        """The amount of time remaining in the interval

        :returns: The amount of time remaining in the interval, never less
            than zero
        """
        return Seconds(max(self.delta - self.elapsed, 0))

    @property
    def expired(self) -> bool:
        """Whether the interval has run out of time

        :returns: `True` if there is no time remaining in the interval,
            `False` otherwise
        """
        return self.remaining <= 0

    @property
    def infinite(self) -> bool:
        """Return true if the timeout interval is infinitely long

        :returns: `True` if the delta is infinite, `False` otherwise
        """
        # The delta is validated to be non-negative, so an infinite delta is
        # always positive infinity and its remaining time never shrinks.
        return math.isinf(self._delta)

    def new_interval(self) -> SynchronousTimeInterval:
        """Make a new timeout with the same interval

        :returns: The new time interval, starting from the moment of the call
        """
        return type(self)(self.delta)

    def block(self) -> None:
        """Block the thread until the timeout completes

        :raises RuntimeError: The thread would be blocked forever
        """
        if self.infinite:
            raise RuntimeError("Cannot block thread forever")
        time.sleep(self.remaining)
45 changes: 43 additions & 2 deletions smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import base64
import collections.abc
import functools
import itertools
import os
import signal
Expand All @@ -40,12 +41,15 @@
import uuid
import warnings
from datetime import datetime
from functools import lru_cache
from shutil import which

if t.TYPE_CHECKING:
from types import FrameType

from typing_extensions import TypeVarTuple, Unpack

_Ts = TypeVarTuple("_Ts")


_T = t.TypeVar("_T")
_HashableT = t.TypeVar("_HashableT", bound=t.Hashable)
Expand Down Expand Up @@ -97,7 +101,7 @@ def create_lockfile_name() -> str:
return f"smartsim-{lock_suffix}.lock"


@functools.lru_cache(maxsize=20, typed=False)
def check_dev_log_level() -> bool:
    """Check whether the ``SMARTSIM_LOG_LEVEL`` environment variable is set to
    the developer level.

    The result is cached: the environment is only read on the first call, and
    later changes to the variable are not observed until
    ``check_dev_log_level.cache_clear()`` is called.

    :returns: `True` if ``SMARTSIM_LOG_LEVEL`` is exactly ``"developer"``,
        `False` otherwise (including when the variable is unset)
    """
    return os.environ.get("SMARTSIM_LOG_LEVEL", "") == "developer"
Expand Down Expand Up @@ -454,6 +458,43 @@ def group_by(
return dict(groups)


def pack_params(
    fn: t.Callable[[Unpack[_Ts]], _T]
) -> t.Callable[[tuple[Unpack[_Ts]]], _T]:
    r"""Adapt a function of several positional arguments so that it can be
    called with a single `tuple` holding those arguments.

    This is mostly useful when mapping a multi-argument function over an
    iterable whose items are already "pre-zipped" into tuples. E.g.

    .. highlight:: python
    .. code-block:: python

        def pretty_print_dict(d):
            fmt_pair = lambda key, value: f"{repr(key)}: {repr(value)},"
            body = "\n".join(map(pack_params(fmt_pair), d.items()))
            #                    ^^^^^^^^^^^^^^^^^^^^^
            print(f"{{\n{textwrap.indent(body, '    ')}\n}}")

        pretty_print_dict({"spam": "eggs", "foo": "bar", "hello": "world"})
        # prints:
        # {
        #     'spam': 'eggs',
        #     'foo': 'bar',
        #     'hello': 'world',
        # }

    :param fn: A callable that takes many positional parameters.
    :returns: A callable that takes a single positional parameter: a tuple
        with the same shape as the parameter list of the original callable.
    """

    def call_with_unpacked(args: tuple[Unpack[_Ts]]) -> _T:
        # Spread the packed tuple back out into positional arguments.
        return fn(*args)

    # Copy the name, docstring, etc. of `fn` onto the adapter.
    return functools.update_wrapper(call_with_unpacked, fn)


@t.final
class SignalInterceptionStack(collections.abc.Collection[_TSignalHandlerFn]):
"""Registers a stack of callables to be called when a signal is
Expand Down
84 changes: 81 additions & 3 deletions smartsim/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# pylint: disable=too-many-lines

from __future__ import annotations

import datetime
Expand All @@ -39,9 +37,11 @@

from smartsim._core import dispatch
from smartsim._core.config import CONFIG
from smartsim._core.control import interval as _interval
from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory
from smartsim._core.utils import helpers as _helpers
from smartsim.error import errors
from smartsim.status import InvalidJobStatus, JobStatus
from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus

from ._core import Generator, Manifest, previewrenderer
from .entity import TelemetryConfiguration
Expand Down Expand Up @@ -254,6 +254,84 @@ def get_status(
stats = (stats_map.get(i, InvalidJobStatus.NEVER_STARTED) for i in ids)
return tuple(stats)

def wait(
    self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True
) -> None:
    """Block execution until every launched job identified by the provided
    ids has entered a terminal status.

    :param ids: The ids of the launched jobs to wait for.
    :param timeout: The max time to wait for all of the launched jobs to
        end. If not supplied or `None`, there is no limit on the wait.
    :param verbose: Whether found statuses should be displayed in the console.
    :raises ValueError: No IDs were provided.
    :raises TimeoutError: One or more jobs failed to reach a terminal status
        before the timeout.
    """
    if not ids:
        raise ValueError("No job ids to wait on provided")

    # Only the blocking behavior matters here; the mapping of final statuses
    # returned by the poller is intentionally discarded.
    self._poll_for_statuses(
        ids,
        TERMINAL_STATUSES,
        timeout=timeout,
        verbose=verbose,
    )

def _poll_for_statuses(
    self,
    ids: t.Sequence[LaunchedJobID],
    statuses: t.Collection[JobStatus],
    timeout: float | None = None,
    interval: float = 5.0,
    verbose: bool = True,
) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]:
    """Poll the experiment's launchers for the statuses of the launched
    jobs with the provided ids, until the status of every job changes to one
    of the provided statuses.

    A job that reports any `InvalidJobStatus` is also treated as finished.
    The launchers are always polled at least once, and once more after the
    final wait, so a job that finishes right at the timeout is not reported
    as timed out.

    :param ids: The ids of the launched jobs to wait for.
    :param statuses: A collection of statuses to poll for.
    :param timeout: The maximum amount of time, in seconds, to spend waiting
        for all jobs to reach one of the supplied statuses. If not supplied
        or `None`, the experiment will poll indefinitely.
    :param interval: The minimum time, in seconds, between polling launchers.
    :param verbose: Whether or not to log polled states to the console.
    :raises ValueError: The interval between polling launchers is infinite
    :raises TimeoutError: The timeout elapsed before every job reached one of
        the supplied statuses.
    :returns: A mapping of ids to the status they entered that ended
        polling.
    """
    method_timeout = _interval.SynchronousTimeInterval(timeout)
    iter_timeout = _interval.SynchronousTimeInterval(interval)
    # Reject an unusable polling interval before doing any other work.
    if iter_timeout.infinite:
        raise ValueError("Polling interval cannot be infinite")

    terminal = frozenset(itertools.chain(statuses, InvalidJobStatus))
    log = logger.info if verbose else lambda *_, **__: None
    final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {}

    def is_finished(
        id_: LaunchedJobID, status: JobStatus | InvalidJobStatus
    ) -> bool:
        """Log the polled status of a job and report whether it is done."""
        job_title = f"Job({id_}): "
        if done := status in terminal:
            log(f"{job_title}Finished with status '{status.value}'")
        else:
            log(f"{job_title}Running with status '{status.value}'")
        return done

    while ids:
        # Restart the per-iteration clock so the time spent polling counts
        # towards the minimum interval between polls.
        iter_timeout = iter_timeout.new_interval()
        stats = zip(ids, self.get_status(*ids))
        is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats)
        final |= dict(is_done.get(True, ()))
        ids = tuple(id_ for id_, _ in is_done.get(False, ()))
        # Check the deadline only AFTER polling: this guarantees one poll
        # even when `timeout` is 0 and a last poll after the final sleep.
        if not ids or method_timeout.expired:
            break
        # Sleep for whichever ends first: the polling interval or the
        # overall timeout.
        (
            iter_timeout
            if iter_timeout.remaining < method_timeout.remaining
            else method_timeout
        ).block()
    if ids:
        raise TimeoutError(
            f"Job ID(s) {', '.join(map(str, ids))} failed to reach "
            "terminal status before timeout"
        )
    return final

@_contextualize
def _generate(
self, generator: Generator, job: Job, job_index: int
Expand Down
Loading