1212from  pathlib  import  Path 
1313from  typing  import  TYPE_CHECKING , Any 
1414from  warnings  import  warn 
15+ from  datetime  import  timedelta , datetime , timezone 
16+ from  itertools  import  accumulate 
1517
1618import  numpy  as  np 
1719from  scipy .optimize  import  NonlinearConstraint 
@@ -92,6 +94,7 @@ def __init__(
9294        verbose : int  =  2 ,
9395        bounds_transformer : DomainTransformer  |  None  =  None ,
9496        allow_duplicate_points : bool  =  False ,
97+         termination_criteria : Mapping [str , float  |  Mapping [str , float ]] |  None  =  None ,
9598    ):
9699        self ._random_state  =  ensure_rng (random_state )
97100        self ._allow_duplicate_points  =  allow_duplicate_points 
@@ -139,6 +142,18 @@ def __init__(
139142
140143        self ._sorting_warning_already_shown  =  False   # TODO: remove in future version 
141144
145+         self ._termination_criteria  =  termination_criteria  if  termination_criteria  is  not   None  else  {}
146+ 
147+         self ._initial_iterations  =  0 
148+         self ._optimizing_iterations  =  0 
149+ 
150+         self ._start_time : datetime  |  None  =  None 
151+         self ._timedelta : timedelta  |  None  =  None 
152+ 
153+         # Directly instantiate timedelta if provided 
154+         if  termination_criteria  and  "time"  in  termination_criteria :
155+             self ._timedelta  =  timedelta (** termination_criteria ["time" ])
156+ 
142157        # Initialize logger 
143158        self .logger  =  ScreenLogger (verbose = self ._verbose , is_constrained = self .is_constrained )
144159
@@ -295,7 +310,7 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
295310
296311        n_iter: int, optional(default=25) 
297312            Number of iterations where the method attempts to find the maximum 
298-             value. 
313+             value. Used when other termination criteria are not provided.  
299314
300315        Warning 
301316        ------- 
@@ -309,19 +324,27 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
309324        # Log optimization start 
310325        self .logger .log_optimization_start (self ._space .keys )
311326
327+         if  self ._start_time  is  None  and  "time"  in  self ._termination_criteria :
328+             self ._start_time  =  datetime .now (timezone .utc )
329+ 
330+         # Set iterations as termination criteria if others not supplied, increment existing if it already exists. 
331+         self ._termination_criteria ["iterations" ] =  max (
332+             self ._termination_criteria .get ("iterations" , 0 ) +  n_iter  +  init_points , 1 
333+         )
334+ 
312335        # Prime the queue with random points 
313336        self ._prime_queue (init_points )
314337
315-         iteration  =  0 
316-         while  self ._queue  or  iteration  <  n_iter :
338+         while  self ._queue  or  not  self .termination_criteria_met ():
317339            try :
318340                x_probe  =  self ._queue .popleft ()
341+                 self ._initial_iterations  +=  1 
319342            except  IndexError :
320343                x_probe  =  self .suggest ()
321-                 iteration  +=  1 
344+                 self . _optimizing_iterations  +=  1 
322345            self .probe (x_probe , lazy = False )
323346
324-             if  self ._bounds_transformer  and  iteration   >   0 :
347+             if  self ._bounds_transformer  and  not   self . _queue :
325348                # The bounds transformer should only modify the bounds after 
326349                # the init_points points (only for the true iterations) 
327350                self .set_bounds (self ._bounds_transformer .transform (self ._space ))
@@ -345,6 +368,51 @@ def set_gp_params(self, **params: Any) -> None:
345368            params ["kernel" ] =  wrap_kernel (kernel = params ["kernel" ], transform = self ._space .kernel_transform )
346369        self ._gp .set_params (** params )
347370
371+     def  termination_criteria_met (self ) ->  bool :
372+         """Determine if the termination criteria have been met.""" 
373+         if  "iterations"  in  self ._termination_criteria :
374+             if  (
375+                 self ._optimizing_iterations  +  self ._initial_iterations 
376+                 >=  self ._termination_criteria ["iterations" ]
377+             ):
378+                 return  True 
379+ 
380+         if  "value"  in  self ._termination_criteria :
381+             if  self .max  is  not   None  and  self .max ["target" ] >=  self ._termination_criteria ["value" ]:
382+                 return  True 
383+ 
384+         if  "time"  in  self ._termination_criteria :
385+             time_taken  =  datetime .now (timezone .utc ) -  self ._start_time 
386+             if  time_taken  >=  self ._timedelta :
387+                 return  True 
388+ 
389+         if  "convergence_tol"  in  self ._termination_criteria  and  len (self ._space .target ) >  2 :
390+             # Find the maximum value of the target function at each iteration 
391+             running_max  =  list (accumulate (self ._space .target , max ))
392+             # Determine improvements that have occurred each iteration 
393+             improvements  =  np .diff (running_max )
394+             if  (
395+                 self ._initial_iterations  +  self ._optimizing_iterations 
396+                 >=  self ._termination_criteria ["convergence_tol" ]["n_iters" ]
397+             ):
398+                 # Check if there are improvements in the specified number of iterations 
399+                 relevant_improvements  =  (
400+                     improvements 
401+                     if  len (self ._space .target ) ==  self ._termination_criteria ["convergence_tol" ]["n_iters" ]
402+                     else  improvements [- self ._termination_criteria ["convergence_tol" ]["n_iters" ] :]
403+                 )
404+                 # There has been no improvement within the iterations specified 
405+                 if  len (set (relevant_improvements )) ==  1 :
406+                     return  True 
407+                 # The improvement(s) are lower than specified 
408+                 if  (
409+                     max (relevant_improvements ) -  min (relevant_improvements )
410+                     <  self ._termination_criteria ["convergence_tol" ]["abs_tol" ]
411+                 ):
412+                     return  True 
413+ 
414+         return  False 
415+ 
348416    def  save_state (self , path : str  |  PathLike [str ]) ->  None :
349417        """Save complete state for reconstruction of the optimizer. 
350418
@@ -385,6 +453,13 @@ def save_state(self, path: str | PathLike[str]) -> None:
385453            "verbose" : self ._verbose ,
386454            "random_state" : random_state ,
387455            "acquisition_params" : acquisition_params ,
456+             "termination_criteria" : self ._termination_criteria ,
457+             "initial_iterations" : self ._initial_iterations ,
458+             "optimizing_iterations" : self ._optimizing_iterations ,
459+             "start_time" : datetime .strftime (self ._start_time , "%Y-%m-%dT%H:%M:%SZ" )
460+             if  self ._start_time 
461+             else  "" ,
462+             "timedelta" : self ._timedelta .total_seconds () if  self ._timedelta  else  "" ,
388463        }
389464
390465        with  Path (path ).open ("w" ) as  f :
@@ -443,3 +518,14 @@ def load_state(self, path: str | PathLike[str]) -> None:
443518                state ["random_state" ]["cached_gaussian" ],
444519            )
445520            self ._random_state .set_state (random_state_tuple )
521+ 
522+         self ._termination_criteria  =  state ["termination_criteria" ]
523+         self ._initial_iterations  =  state ["initial_iterations" ]
524+         self ._optimizing_iterations  =  state ["optimizing_iterations" ]
525+         # Previously saved as UTC, so explicitly parse as UTC time. 
526+         self ._start_time  =  (
527+             datetime .strptime (state ["start_time" ], "%Y-%m-%dT%H:%M:%SZ" ).replace (tzinfo = timezone .utc )
528+             if  state ["start_time" ] !=  "" 
529+             else  None 
530+         )
531+         self ._timedelta  =  timedelta (seconds = state ["timedelta" ]) if  state ["timedelta" ] else  None 
0 commit comments