@@ -101,8 +101,9 @@ def __init__(
101101 tests : str = None ,
102102 postprocess : str = None ,
103103 default_test_kwargs : dict = None ,
104- rundir : str = "results" ,
104+ run_dir : str = "results" ,
105105 run_mode : str = "serial" ,
106+ stage_dir : ... = "results" ,
106107 report_hook : dict = None ,
107108 ** kwargs ,
108109 ) -> None :
@@ -116,20 +117,21 @@ def __init__(
116117
117118 workdir = Path (kwargs .get ("path" , os .getcwd ())).resolve ()
118119 if kwargs .get ("timestamp" , False ):
119- rundir = Path (rundir , f"run_{ datetime .datetime .utcnow ().date ().isoformat ()} " )
120- os .makedirs (Path (workdir , rundir ), exist_ok = True )
120+ run_dir = Path (run_dir , f"run_{ datetime .datetime .utcnow ().date ().isoformat ()} " )
121+ os .makedirs (Path (workdir , run_dir ), exist_ok = True )
121122
122123 self .name = name if name else "floatingExp"
123- self .registry = ExperimentRegistry .factory (workdir = workdir , run_dir = rundir )
124+ self .registry = ExperimentRegistry .factory (workdir = workdir , run_dir = run_dir )
124125 self .results_repo = ResultsRepository (self .registry )
125126 self .catalog_repo = CatalogRepository (self .registry )
126127 self .run_id = "run"
127128
128129 self .config_file = kwargs .get ("config_file" , None )
129130 self .original_config = kwargs .get ("original_config" , None )
130- self .original_run_dir = kwargs .get ("original_rundir " , None )
131- self .run_dir = rundir
131+ self .original_run_dir = kwargs .get ("original_run_dir " , None )
132+ self .run_dir = run_dir
132133 self .run_mode = run_mode
134+ self .stage_dir = stage_dir
133135 self .seed = kwargs .get ("seed" , None )
134136 self .time_config = read_time_cfg (time_config , ** kwargs )
135137 self .region_config = read_region_cfg (region_config , ** kwargs )
@@ -139,7 +141,7 @@ def __init__(
139141 logger = kwargs .get ("logging" , False )
140142 if logger :
141143 filename = "experiment.log" if logger is True else logger
142- self .registry .logger = os .path .join (workdir , rundir , filename )
144+ self .registry .logger = os .path .join (workdir , run_dir , filename )
143145 log .info (f"Logging at { self .registry .logger } " )
144146 add_fhandler (self .registry .logger )
145147
@@ -304,7 +306,7 @@ def stage_models(self) -> None:
304306 i .stage (
305307 self .time_windows ,
306308 run_mode = self .run_mode ,
307- stage_dir = self .registry . run_dir ,
309+ stage_dir = self .stage_dir ,
308310 run_id = self .run_id ,
309311 )
310312 self .registry .add_model_registry (i )
@@ -587,8 +589,6 @@ def make_repr(self) -> None:
587589 if not exists (target_cat ):
588590 shutil .copy2 (self .registry .abs (self .catalog_repo .cat_path ), target_cat )
589591
590- # relative_path = self.registry.rel(self.registry.run_dir)
591- # print(self.registry.workdir.__class__, self.registry.run_dir.__class__)
592592 relative_path = Path (
593593 os .path .relpath (self .registry .workdir .as_posix (), self .registry .run_dir .as_posix ())
594594 )
@@ -687,14 +687,13 @@ def from_yml(cls, config_yml: str, repr_dir=None, **kwargs):
687687 # Only ABSOLUTE PATH
688688 _dict ["path" ] = abspath (join (_dir_yml , _dict .get ("path" , "" )))
689689
690- # replaces rundir case reproduce option is used
690+ # replaces run_dir case reproduce option is used
691691 if repr_dir :
692- _dict ["original_rundir " ] = _dict .get ("rundir " , "results" )
693- _dict ["rundir " ] = relpath (join (_dir_yml , repr_dir ), _dict ["path" ])
692+ _dict ["original_run_dir " ] = _dict .get ("run_dir " , "results" )
693+ _dict ["run_dir " ] = relpath (join (_dir_yml , repr_dir ), _dict ["path" ])
694694 _dict ["original_config" ] = abspath (join (_dict ["path" ], _dict ["config_file" ]))
695695 else :
696-
697- _dict ["rundir" ] = _dict .get ("rundir" , kwargs .pop ("rundir" , "results" ))
696+ _dict ["run_dir" ] = _dict .get ("run_dir" , kwargs .pop ("run_dir" , "results" ))
698697 _dict ["config_file" ] = relpath (config_yml , _dir_yml )
699698 if "logging" in _dict :
700699 kwargs .pop ("logging" )
0 commit comments