2525 parse_nested_dicts ,
2626)
2727from floatcsep .infrastructure .engine import Task , TaskGraph
28+ from floatcsep .infrastructure .logger import log_models_tree , log_results_tree
2829
2930log = logging .getLogger ("floatLogger" )
3031
@@ -52,8 +53,8 @@ class Experiment:
5253 - growth (:class:`str`): `incremental` or `cumulative`
5354 - offset (:class:`float`): recurrence of forecast creation.
5455
55- For further details, see :func:`~floatcsep.utils.timewindows_ti `
56- and :func:`~floatcsep.utils.timewindows_td `
56+ For further details, see :func:`~floatcsep.utils.time_windows_ti `
57+ and :func:`~floatcsep.utils.time_windows_td `
5758
5859 region_config (dict): Contains all the spatial and magnitude
5960 specifications. It must contain the following keys:
@@ -75,6 +76,7 @@ class Experiment:
7576
7677 model_config (str): Path to the models' configuration file
7778 test_config (str): Path to the evaluations' configuration file
79+ run_mode (str): 'sequential' or 'parallel'
7880 default_test_kwargs (dict): Default values for the testing
7981 (seed, number of simulations, etc.)
8082 postprocess (dict): Contains the instruction for postprocessing
@@ -99,6 +101,7 @@ def __init__(
99101 postprocess : str = None ,
100102 default_test_kwargs : dict = None ,
101103 rundir : str = "results" ,
104+ run_mode : str = "sequential" ,
102105 report_hook : dict = None ,
103106 ** kwargs ,
104107 ) -> None :
@@ -118,14 +121,15 @@ def __init__(
118121 os .makedirs (os .path .join (workdir , rundir ), exist_ok = True )
119122
120123 self .name = name if name else "floatingExp"
121- self .registry = ExperimentRegistry (workdir , rundir )
124+ self .registry = ExperimentRegistry . factory (workdir = workdir , run_dir = rundir )
122125 self .results_repo = ResultsRepository (self .registry )
123126 self .catalog_repo = CatalogRepository (self .registry )
124127
125128 self .config_file = kwargs .get ("config_file" , None )
126129 self .original_config = kwargs .get ("original_config" , None )
127130 self .original_run_dir = kwargs .get ("original_rundir" , None )
128131 self .run_dir = rundir
132+ self .run_mode = run_mode
129133 self .seed = kwargs .get ("seed" , None )
130134 self .time_config = read_time_cfg (time_config , ** kwargs )
131135 self .region_config = read_region_cfg (region_config , ** kwargs )
@@ -143,7 +147,7 @@ def __init__(
143147 log .info (f"Setting up experiment { self .name } :" )
144148 log .info (f"\t Start: { self .start_date } " )
145149 log .info (f"\t End: { self .end_date } " )
146- log .info (f"\t Time windows: { len (self .timewindows )} " )
150+ log .info (f"\t Time windows: { len (self .time_windows )} " )
147151 log .info (f"\t Region: { self .region .name if self .region else None } " )
148152 log .info (
149153 f"\t Magnitude range: [{ numpy .min (self .magnitudes )} ,"
@@ -175,7 +179,7 @@ def __getattr__(self, item: str) -> object:
175179 Override built-in method to return the experiment attributes by also using the command
176180 ``experiment.{attr}``. Adds also to the experiment scope the keys of
177181 :attr:`region_config` or :attr:`time_config`. These are: ``start_date``, ``end_date``,
178- ``timewindows ``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``,
182+ ``time_windows ``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``,
179183 `mag_max``, ``mag_bin``, ``depth_min`` depth_max .
180184 """
181185
@@ -295,8 +299,8 @@ def stage_models(self) -> None:
295299 """
296300 log .info ("Staging models" )
297301 for i in self .models :
298- i .stage (self .timewindows )
299- self .registry .add_forecast_registry (i )
302+ i .stage (self .time_windows , run_mode = self . run_mode , run_dir = self . run_dir )
303+ self .registry .add_model_registry (i )
300304
301305 def set_tests (self , test_config : Union [str , Dict , List ]) -> list :
302306 """
@@ -376,17 +380,17 @@ def set_tasks(self) -> None:
376380 """
377381
378382 # Set the file path structure
379- self .registry .build_tree (self .timewindows , self .models , self .tests )
383+ self .registry .build_tree (self .time_windows , self .models , self .tests , self . run_mode )
380384
381385 log .debug ("Pre-run forecast summary" )
382- self .registry . log_forecast_trees ( self .timewindows )
386+ log_models_tree ( log , self .registry , self .time_windows )
383387 log .debug ("Pre-run result summary" )
384- self .registry . log_results_tree ( )
388+ log_results_tree ( log , self .registry )
385389
386390 log .info ("Setting up experiment's tasks" )
387391
388392 # Get the time windows strings
389- tw_strings = timewindow2str (self .timewindows )
393+ tw_strings = timewindow2str (self .time_windows )
390394
391395 # Prepare the testing catalogs
392396 task_graph = TaskGraph ()
@@ -481,7 +485,7 @@ def set_tasks(self) -> None:
481485 )
482486 # Set up the Sequential_Comparative Scores
483487 elif test_k .type == "sequential_comparative" :
484- tw_strs = timewindow2str (self .timewindows )
488+ tw_strs = timewindow2str (self .time_windows )
485489 for model_j in self .models :
486490 task_k = Task (
487491 instance = test_k ,
@@ -504,7 +508,7 @@ def set_tasks(self) -> None:
504508 )
505509 # Set up the Batch comparative Scores
506510 elif test_k .type == "batch" :
507- time_str = timewindow2str (self .timewindows [- 1 ])
511+ time_str = timewindow2str (self .time_windows [- 1 ])
508512 for model_j in self .models :
509513 task_k = Task (
510514 instance = test_k ,
@@ -540,9 +544,9 @@ def run(self) -> None:
540544 self .task_graph .run ()
541545 log .info ("Calculation completed" )
542546 log .debug ("Post-run forecast registry" )
543- self .registry . log_forecast_trees ( self .timewindows )
547+ log_models_tree ( log , self .registry , self .time_windows )
544548 log .debug ("Post-run result summary" )
545- self .registry . log_results_tree ( )
549+ log_results_tree ( log , self .registry )
546550
547551 def read_results (self , test : Evaluation , window : str ) -> List :
548552 """
@@ -559,7 +563,7 @@ def make_repr(self) -> None:
559563
560564 """
561565 log .info ("Creating reproducibility config file" )
562- repr_config = self .registry .get ("repr_config" )
566+ repr_config = self .registry .get_attr ("repr_config" )
563567
564568 # Dropping region to results folder if it is a file
565569 region_path = self .region_config .get ("path" , False )
@@ -604,7 +608,7 @@ def as_dict(self, extra: Sequence = (), extended=False) -> dict:
604608 "time_config" : {
605609 i : j
606610 for i , j in self .time_config .items ()
607- if (i not in ("timewindows " ,) or extended )
611+ if (i not in ("time_windows " ,) or extended )
608612 },
609613 "region_config" : {
610614 i : j
@@ -731,7 +735,7 @@ def test_stat(test_orig, test_repr):
731735
732736 def get_results (self ):
733737
734- win_orig = timewindow2str (self .original .timewindows )
738+ win_orig = timewindow2str (self .original .time_windows )
735739
736740 tests_orig = self .original .tests
737741
@@ -787,7 +791,7 @@ def get_hash(filename):
787791
788792 def get_filecomp (self ):
789793
790- win_orig = timewindow2str (self .original .timewindows )
794+ win_orig = timewindow2str (self .original .time_windows )
791795
792796 tests_orig = self .original .tests
793797
@@ -801,8 +805,8 @@ def get_filecomp(self):
801805 for tw in win_orig :
802806 results [test .name ][tw ] = dict .fromkeys (models_orig )
803807 for model in models_orig :
804- orig_path = self .original .registry .get_result (tw , test , model )
805- repr_path = self .reproduced .registry .get_result (tw , test , model )
808+ orig_path = self .original .registry .get_result_key (tw , test , model )
809+ repr_path = self .reproduced .registry .get_result_key (tw , test , model )
806810
807811 results [test .name ][tw ][model ] = {
808812 "hash" : (self .get_hash (orig_path ) == self .get_hash (repr_path )),
@@ -811,8 +815,8 @@ def get_filecomp(self):
811815 else :
812816 results [test .name ] = dict .fromkeys (models_orig )
813817 for model in models_orig :
814- orig_path = self .original .registry .get_result (win_orig [- 1 ], test , model )
815- repr_path = self .reproduced .registry .get_result (win_orig [- 1 ], test , model )
818+ orig_path = self .original .registry .get_result_key (win_orig [- 1 ], test , model )
819+ repr_path = self .reproduced .registry .get_result_key (win_orig [- 1 ], test , model )
816820 results [test .name ][model ] = {
817821 "hash" : (self .get_hash (orig_path ) == self .get_hash (repr_path )),
818822 "byte2byte" : filecmp .cmp (orig_path , repr_path ),
0 commit comments