Skip to content

Commit 225aa98

Browse files
committed
WIP: refactors termination criteria
* Refactors the bayesian_optimization class to have optional termination_criteria argument and uses this instead of iteration number * Keeps termination on number of iterations, adds objective function value, runtime, and objective function improvement termination criteria * Tests each implementation
1 parent c410d51 commit 225aa98

File tree

2 files changed

+161
-5
lines changed

2 files changed

+161
-5
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from pathlib import Path
1313
from typing import TYPE_CHECKING, Any
1414
from warnings import warn
15+
from datetime import timedelta, datetime, timezone
16+
from itertools import accumulate
1517

1618
import numpy as np
1719
from scipy.optimize import NonlinearConstraint
@@ -92,6 +94,7 @@ def __init__(
9294
verbose: int = 2,
9395
bounds_transformer: DomainTransformer | None = None,
9496
allow_duplicate_points: bool = False,
97+
termination_criteria: Mapping[str, float | Mapping[str, float]] | None = None,
9598
):
9699
self._random_state = ensure_rng(random_state)
97100
self._allow_duplicate_points = allow_duplicate_points
@@ -139,6 +142,18 @@ def __init__(
139142

140143
self._sorting_warning_already_shown = False # TODO: remove in future version
141144

145+
self._termination_criteria = termination_criteria if termination_criteria is not None else {}
146+
147+
self._initial_iterations = 0
148+
self._optimizing_iterations = 0
149+
150+
self._start_time: datetime | None = None
151+
self._timedelta: timedelta | None = None
152+
153+
# Directly instantiate timedelta if provided
154+
if termination_criteria and "time" in termination_criteria:
155+
self._timedelta = timedelta(**termination_criteria["time"])
156+
142157
# Initialize logger
143158
self.logger = ScreenLogger(verbose=self._verbose, is_constrained=self.is_constrained)
144159

@@ -295,7 +310,7 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
295310
296311
n_iter: int, optional(default=25)
297312
Number of iterations where the method attempts to find the maximum
298-
value.
313+
value. Used when other termination criteria are not provided.
299314
300315
Warning
301316
-------
@@ -309,19 +324,27 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
309324
# Log optimization start
310325
self.logger.log_optimization_start(self._space.keys)
311326

327+
if self._start_time is None and "time" in self._termination_criteria:
328+
self._start_time = datetime.now(timezone.utc)
329+
330+
# Set iterations as termination criteria if others not supplied, increment existing if it already exists.
331+
self._termination_criteria["iterations"] = max(
332+
self._termination_criteria.get("iterations", 0) + n_iter + init_points, 1
333+
)
334+
312335
# Prime the queue with random points
313336
self._prime_queue(init_points)
314337

315-
iteration = 0
316-
while self._queue or iteration < n_iter:
338+
while self._queue or not self.termination_criteria_met():
317339
try:
318340
x_probe = self._queue.popleft()
341+
self._initial_iterations += 1
319342
except IndexError:
320343
x_probe = self.suggest()
321-
iteration += 1
344+
self._optimizing_iterations += 1
322345
self.probe(x_probe, lazy=False)
323346

324-
if self._bounds_transformer and iteration > 0:
347+
if self._bounds_transformer and not self._queue:
325348
# The bounds transformer should only modify the bounds after
326349
# the init_points points (only for the true iterations)
327350
self.set_bounds(self._bounds_transformer.transform(self._space))
@@ -345,6 +368,51 @@ def set_gp_params(self, **params: Any) -> None:
345368
params["kernel"] = wrap_kernel(kernel=params["kernel"], transform=self._space.kernel_transform)
346369
self._gp.set_params(**params)
347370

371+
def termination_criteria_met(self) -> bool:
372+
"""Determine if the termination criteria have been met."""
373+
if "iterations" in self._termination_criteria:
374+
if (
375+
self._optimizing_iterations + self._initial_iterations
376+
>= self._termination_criteria["iterations"]
377+
):
378+
return True
379+
380+
if "value" in self._termination_criteria:
381+
if self.max is not None and self.max["target"] >= self._termination_criteria["value"]:
382+
return True
383+
384+
if "time" in self._termination_criteria:
385+
time_taken = datetime.now(timezone.utc) - self._start_time
386+
if time_taken >= self._timedelta:
387+
return True
388+
389+
if "convergence_tol" in self._termination_criteria and len(self._space.target) > 2:
390+
# Find the maximum value of the target function at each iteration
391+
running_max = list(accumulate(self._space.target, max))
392+
# Determine improvements that have occurred each iteration
393+
improvements = np.diff(running_max)
394+
if (
395+
self._initial_iterations + self._optimizing_iterations
396+
>= self._termination_criteria["convergence_tol"]["n_iters"]
397+
):
398+
# Check if there are improvements in the specified number of iterations
399+
relevant_improvements = (
400+
improvements
401+
if len(self._space.target) == self._termination_criteria["convergence_tol"]["n_iters"]
402+
else improvements[-self._termination_criteria["convergence_tol"]["n_iters"] :]
403+
)
404+
# There has been no improvement within the iterations specified
405+
if len(set(relevant_improvements)) == 1:
406+
return True
407+
# The improvement(s) are lower than specified
408+
if (
409+
max(relevant_improvements) - min(relevant_improvements)
410+
< self._termination_criteria["convergence_tol"]["abs_tol"]
411+
):
412+
return True
413+
414+
return False
415+
348416
def save_state(self, path: str | PathLike[str]) -> None:
349417
"""Save complete state for reconstruction of the optimizer.
350418
@@ -385,6 +453,13 @@ def save_state(self, path: str | PathLike[str]) -> None:
385453
"verbose": self._verbose,
386454
"random_state": random_state,
387455
"acquisition_params": acquisition_params,
456+
"termination_criteria": self._termination_criteria,
457+
"initial_iterations": self._initial_iterations,
458+
"optimizing_iterations": self._optimizing_iterations,
459+
"start_time": datetime.strftime(self._start_time, "%Y-%m-%dT%H:%M:%SZ")
460+
if self._start_time
461+
else "",
462+
"timedelta": self._timedelta.total_seconds() if self._timedelta else "",
388463
}
389464

390465
with Path(path).open("w") as f:
@@ -443,3 +518,14 @@ def load_state(self, path: str | PathLike[str]) -> None:
443518
state["random_state"]["cached_gaussian"],
444519
)
445520
self._random_state.set_state(random_state_tuple)
521+
522+
self._termination_criteria = state["termination_criteria"]
523+
self._initial_iterations = state["initial_iterations"]
524+
self._optimizing_iterations = state["optimizing_iterations"]
525+
# Previously saved as UTC, so explicitly parse as UTC time.
526+
self._start_time = (
527+
datetime.strptime(state["start_time"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
528+
if state["start_time"] != ""
529+
else None
530+
)
531+
self._timedelta = timedelta(seconds=state["timedelta"]) if state["timedelta"] else None

tests/test_bayesian_optimization.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
from datetime import datetime
34
import pickle
45
from pathlib import Path
56

7+
from _pytest.tmpdir import tmp_path
68
import numpy as np
79
import pytest
810
from scipy.optimize import NonlinearConstraint
@@ -585,3 +587,71 @@ def area_of_triangle(sides):
585587
suggestion1 = optimizer.suggest()
586588
suggestion2 = new_optimizer.suggest()
587589
np.testing.assert_array_almost_equal(suggestion1["sides"], suggestion2["sides"], decimal=7)
590+
591+
592+
def test_termination_criteria(tmp_path):
593+
"""Test each termination criteria individually."""
594+
595+
def target_func_trivial():
596+
# Max at 0, 1
597+
return lambda x, y: -(x**2) - ((y - 1) ** 2)
598+
599+
termination_criteria = {"iterations": 10}
600+
pbounds = {"x": [-10.0, 10.0], "y": [-10.0, 10.0]}
601+
opt = BayesianOptimization(
602+
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria
603+
)
604+
605+
# Ensure no initial points are specified.
606+
opt.maximize(init_points=0, n_iter=10)
607+
608+
assert len(opt.res) == termination_criteria["iterations"]
609+
610+
# Provide reasonable target value for objective fn
611+
termination_criteria = {"value": -0.05}
612+
opt = BayesianOptimization(
613+
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria
614+
)
615+
616+
# Call with large number of iterations, so that this is not the termination criteria
617+
opt.maximize(init_points=5, n_iter=1_000)
618+
619+
assert opt.max["target"] > termination_criteria["value"]
620+
621+
# 3 seconds of maximizing before termination
622+
termination_criteria = {"time": {"seconds": 3}}
623+
opt = BayesianOptimization(
624+
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria
625+
)
626+
627+
start = datetime.now()
628+
# Call with large number of iterations, so that this is not the termination criteria
629+
opt.maximize(n_iter=1_000, init_points=1)
630+
631+
# Allow ~200ms tolerance on timing
632+
assert abs((datetime.now() - start).total_seconds() - termination_criteria["time"]["seconds"]) < 0.2
633+
634+
# Terminate if no improvement in last 3 iterations
635+
termination_criteria = {"convergence_tol": {"n_iters": 3, "abs_tol": 0}}
636+
637+
opt = BayesianOptimization(
638+
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria
639+
)
640+
# Call with number of iterations which will not lead to termination criteria on iterations
641+
opt.maximize(n_iter=1_000, init_points=5)
642+
643+
# Check that none of the last 3 values are the maximum
644+
no_improvement_in_3 = all([value < opt._space.max()["target"] for value in opt._space.target[-3:]])
645+
assert no_improvement_in_3
646+
647+
# Converged if minimum improvement below 1 in last 10 iterations
648+
termination_criteria = {"convergence_tol": {"n_iters": 10, "abs_tol": 1}}
649+
650+
opt = BayesianOptimization(
651+
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria
652+
)
653+
opt.maximize(n_iter=1_000, init_points=5)
654+
655+
improvement_below_tol = np.max(opt._space.target[-10:] - opt._space.max()["target"]) < 1
656+
657+
assert improvement_below_tol

0 commit comments

Comments
 (0)