Skip to content

Commit 9b674ef

Browse files
committed
ft: registries now handle isolated input I/O for future concurrency-safe execution
1 parent 181f414 commit 9b674ef

File tree

9 files changed

+95
-42
lines changed

9 files changed

+95
-42
lines changed

floatcsep/experiment.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class Experiment:
7676
7777
model_config (str): Path to the models' configuration file
7878
test_config (str): Path to the evaluations' configuration file
79+
run_mode (str): 'sequential' or 'parallel'
7980
default_test_kwargs (dict): Default values for the testing
8081
(seed, number of simulations, etc.)
8182
postprocess (dict): Contains the instruction for postprocessing
@@ -100,6 +101,7 @@ def __init__(
100101
postprocess: str = None,
101102
default_test_kwargs: dict = None,
102103
rundir: str = "results",
104+
run_mode: str = "sequential",
103105
report_hook: dict = None,
104106
**kwargs,
105107
) -> None:
@@ -127,6 +129,7 @@ def __init__(
127129
self.original_config = kwargs.get("original_config", None)
128130
self.original_run_dir = kwargs.get("original_rundir", None)
129131
self.run_dir = rundir
132+
self.run_mode = run_mode
130133
self.seed = kwargs.get("seed", None)
131134
self.time_config = read_time_cfg(time_config, **kwargs)
132135
self.region_config = read_region_cfg(region_config, **kwargs)
@@ -296,7 +299,7 @@ def stage_models(self) -> None:
296299
"""
297300
log.info("Staging models")
298301
for i in self.models:
299-
i.stage(self.time_windows)
302+
i.stage(self.time_windows, run_mode=self.run_mode, run_dir=self.run_dir)
300303
self.registry.add_model_registry(i)
301304

302305
def set_tests(self, test_config: Union[str, Dict, List]) -> list:
@@ -377,7 +380,7 @@ def set_tasks(self) -> None:
377380
"""
378381

379382
# Set the file path structure
380-
self.registry.build_tree(self.time_windows, self.models, self.tests)
383+
self.registry.build_tree(self.time_windows, self.models, self.tests, self.run_mode)
381384

382385
log.debug("Pre-run forecast summary")
383386
log_models_tree(log, self.registry, self.time_windows)

floatcsep/infrastructure/registries.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from datetime import datetime
55
from os.path import join, abspath, relpath, normpath, dirname, exists
6-
from typing import Sequence, Union, TYPE_CHECKING, Any
6+
from typing import Sequence, Union, TYPE_CHECKING, Any, Optional
77

88
from floatcsep.utils.helpers import timewindow2str
99

@@ -149,6 +149,7 @@ def factory(cls, registry_type: str = 'file', **kwargs) -> "ModelRegistry":
149149
class ModelFileRegistry(ModelRegistry, FilepathMixin):
150150
def __init__(
151151
self,
152+
model_name: str,
152153
workdir: str,
153154
path: str,
154155
database: str = None,
@@ -159,19 +160,25 @@ def __init__(
159160
"""
160161
161162
Args:
163+
model_name (str): Model's identifier string
162164
workdir (str): The current working directory of the experiment.
163165
path (str): The path of the model working directory (or model filepath).
164166
database (str): The path of the database, in case forecasts are stored therein.
165167
args_file (str): The path of the arguments file (only for TimeDependentModel).
166168
input_cat (str): The path of the input catalog file (only for TimeDependentModel).
167169
"""
168170

171+
self.model_name = model_name
169172
self.workdir = workdir
170173
self.path = path
171174
self.database = database
172175
self.args_file = args_file
173176
self.input_cat = input_cat
177+
174178
self.forecasts = {}
179+
self.input_args = {}
180+
self.input_cats = {}
181+
self.input_store = None
175182
self._fmt = fmt
176183

177184
@property
@@ -227,7 +234,7 @@ def get_input_catalog_key(self, *args: Sequence[str]) -> str:
227234
Returns:
228235
The input catalog registry key from a sequence of key values
229236
"""
230-
return self.get_attr("input_cat", *args)
237+
return self.get_attr("input_cats", *args)
231238

232239
def get_forecast_key(self, *args: Sequence[str]) -> str:
233240
"""
@@ -253,15 +260,15 @@ def get_args_key(self, *args: Sequence[str]) -> str:
253260
Returns:
254261
The argument file's key(s) from a sequence of key values
255262
"""
256-
return self.get_attr("args_file", *args)
263+
return self.get_attr("input_args", *args)
257264

258265
def build_tree(
259266
self,
260267
time_windows: Sequence[Sequence[datetime]] = None,
261268
model_class: str = "TimeIndependentModel",
262269
prefix: str = None,
263-
args_file: str = None,
264-
input_cat: str = None
270+
run_mode: str = 'sequential',
271+
run_dir: Optional[str] = None
265272
) -> None:
266273
"""
267274
Creates the run directory, and reads the file structure inside.
@@ -270,33 +277,43 @@ def build_tree(
270277
time_windows (list(str)): List of time windows or strings.
271278
model_class (str): Model's class name
272279
prefix (str): prefix of the model forecast filenames if TD
273-
args_file (str, bool): input arguments path of the model if TD
274-
input_cat (str, bool): input catalog path of the model if TD
275-
fmt (str, bool): for time dependent mdoels
276-
280+
run_mode (str): if run mode is sequential, input data (args and cat) will be
281+
dynamically overwritten in 'model/input/' through time_windows. If 'parallel',
282+
input data is dynamically written anew in
283+
'results/{time_window}/input/{model_name}/'.
284+
run_dir (str): Where experiment's results are stored.
277285
"""
278286

279287
windows = timewindow2str(time_windows)
280-
281288
if model_class == "TimeIndependentModel":
282289
fname = self.database if self.database else self.path
283290
self.forecasts = {win: fname for win in windows}
284291

285292
elif model_class == "TimeDependentModel":
286293

287-
args = args_file if args_file else join("input", "args.txt")
288-
self.args_file = join(self.path, args)
289-
input_cat = input_cat if input_cat else join("input", "catalog.csv")
290-
self.input_cat = join(self.path, input_cat)
291-
# grab names for creating directories
294+
# grab names for creating model directories
292295
subfolders = ["input", "forecasts"]
293296
dirtree = {folder: self.abs(self.path, folder) for folder in subfolders}
294-
295-
# create directories if they don't exist
296297
for _, folder_ in dirtree.items():
297298
os.makedirs(folder_, exist_ok=True)
298299

299-
# set forecast names
300+
if run_mode == 'sequential':
301+
self.input_args = {
302+
win: join(self.path, 'input', self.args_file) for win in windows
303+
}
304+
self.input_cats = {
305+
win: join(self.path, 'input', self.input_cat) for win in windows
306+
}
307+
elif run_mode == 'parallel':
308+
self.input_args = {
309+
win: join(run_dir, win, 'input', self.model_name, self.args_file)
310+
for win in windows
311+
}
312+
self.input_cats = {
313+
win: join(run_dir, win, 'input', self.model_name, self.input_cat)
314+
for win in windows
315+
}
316+
300317
self.forecasts = {
301318
win: join(dirtree["forecasts"], f"{prefix}_{win}.{self.fmt}") for win in windows
302319
}
@@ -492,6 +509,7 @@ def build_tree(
492509
time_windows: Sequence[Sequence[datetime]],
493510
models: Sequence["Model"],
494511
tests: Sequence["Evaluation"],
512+
run_mode: str = 'sequential'
495513
) -> None:
496514
"""
497515
Creates the run directory and reads the file structure inside.
@@ -500,6 +518,7 @@ def build_tree(
500518
time_windows: List of time windows, or representing string.
501519
models: List of models or model names
502520
tests: List of tests or test names
521+
run_mode: 'parallel' or 'sequential'
503522
504523
"""
505524
windows = timewindow2str(time_windows)
@@ -509,6 +528,8 @@ def build_tree(
509528

510529
run_folder = self.run_dir
511530
subfolders = ["catalog", "evaluations", "figures"]
531+
if run_mode == 'parallel':
532+
subfolders.append('input')
512533
dirtree = {
513534
win: {folder: self.abs(run_folder, win, folder) for folder in subfolders}
514535
for win in windows
@@ -518,7 +539,11 @@ def build_tree(
518539
for tw, tw_folder in dirtree.items():
519540
for _, folder_ in tw_folder.items():
520541
os.makedirs(folder_, exist_ok=True)
521-
542+
print(folder_)
543+
if run_mode == 'parallel' and folder_.endswith('input'):
544+
print('a')
545+
for model in models:
546+
os.makedirs(join(folder_, model), exist_ok=True)
522547
results = {
523548
win: {
524549
test: {

floatcsep/infrastructure/repositories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def set_input_cat(self, tstring: str, model: "Model", fmt: str ='ascii') -> None
202202
fmt (str): Output catalog format
203203
"""
204204
start, end = str2timewindow(tstring)
205-
input_cat_name = model.registry.get_input_catalog_key()
205+
input_cat_name = model.registry.get_input_catalog_key(tstring)
206206
sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"])
207207

208208
writer = getattr(CatalogSerializer, fmt)

floatcsep/model.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,14 @@ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False,
210210

211211
self.forecast_unit = forecast_unit
212212
self.store_db = store_db
213-
self.registry = ModelRegistry.factory(workdir=kwargs.get("workdir", os.getcwd()),
214-
path=model_path)
213+
self.registry = ModelRegistry.factory(model_name=name,
214+
workdir=kwargs.get("workdir", os.getcwd()),
215+
path=model_path)
215216
self.repository = ForecastRepository.factory(
216217
self.registry, model_class=self.__class__.__name__, **kwargs
217218
)
218219

219-
def stage(self, time_windows: Sequence[Sequence[datetime]] = None) -> None:
220+
def stage(self, time_windows: Sequence[Sequence[datetime]] = None, **kwargs) -> None:
220221
"""
221222
Acquire the forecast data if it is not in the file system. Sets the paths internally
222223
(or database pointers) to the forecast data.
@@ -298,6 +299,8 @@ def __init__(
298299
model_path: str,
299300
func: Union[str, Callable] = None,
300301
func_kwargs: dict = None,
302+
args_file: str = "args.txt",
303+
input_cat: str = "catalog.csv",
301304
fmt: str = 'csv',
302305
**kwargs,
303306
) -> None:
@@ -310,6 +313,8 @@ def __init__(
310313
func: A function/command that runs the model.
311314
func_kwargs: The keyword arguments to run the model. They are usually (over)written
312315
into the file `{model_path}/input/{args_file}`
316+
args_file: Name of the arguments file that will be used to create forecasts
317+
input_cat: Name of the file that will be used as input catalog to create forecasts
313318
**kwargs: Additional keyword parameters, such as a ``prefix`` (str) for the
314319
resulting forecast file paths, ``args_file`` (str) as the path for the model
315320
arguments file or ``input_cat`` that indicates where the input catalog will be
@@ -321,9 +326,12 @@ def __init__(
321326
self.func = func
322327
self.func_kwargs = func_kwargs or {}
323328

324-
self.registry = ModelRegistry.factory(workdir=kwargs.get("workdir", os.getcwd()),
325-
path=model_path,
326-
fmt=fmt)
329+
self.registry = ModelRegistry.factory(model_name=name,
330+
workdir=kwargs.get("workdir", os.getcwd()),
331+
path=model_path,
332+
fmt=fmt,
333+
args_file=args_file,
334+
input_cat=input_cat)
327335
self.repository = ForecastRepository.factory(
328336
self.registry, model_class=self.__class__.__name__, **kwargs
329337
)
@@ -334,7 +342,7 @@ def __init__(
334342
self.build, self.name, self.registry.abs(model_path)
335343
)
336344

337-
def stage(self, time_windows=None) -> None:
345+
def stage(self, time_windows=None, run_mode='sequential', run_dir='') -> None:
338346
"""
339347
Core method to interface a model with the experiment.
340348
@@ -355,8 +363,8 @@ def stage(self, time_windows=None) -> None:
355363
time_windows=time_windows,
356364
model_class=self.__class__.__name__,
357365
prefix=self.__dict__.get("prefix", self.name),
358-
args_file=self.__dict__.get("args_file", None),
359-
input_cat=self.__dict__.get("input_cat", None),
366+
run_mode=run_mode,
367+
run_dir=run_dir
360368
)
361369

362370
def get_forecast(
@@ -407,7 +415,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None:
407415
f"Running {self.name} using {self.environment.__class__.__name__}:"
408416
f" {timewindow2str([start_date, end_date])}"
409417
)
410-
self.environment.run_command(f"{self.func} {self.registry.get_args_key()}")
418+
self.environment.run_command(f"{self.func} {self.registry.get_args_key(tstring)}")
411419

412420
def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None:
413421
"""
@@ -422,7 +430,9 @@ def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None:
422430
**kwargs: represents additional model arguments (name/value pair)
423431
424432
"""
425-
filepath = self.registry.get_args_key()
433+
window_str = timewindow2str([start, end])
434+
435+
filepath = self.registry.get_args_key(window_str)
426436
fmt = os.path.splitext(filepath)[1]
427437

428438
if fmt == ".txt":
@@ -477,8 +487,14 @@ def nested_update(dest: dict, src: dict, max_depth: int = 3, _level: int = 1):
477487
else:
478488
dest[key] = val
479489

490+
if not os.path.exists(filepath):
491+
template_file = os.path.join(self.registry.path,
492+
'input',
493+
self.registry.args_file)
494+
else:
495+
template_file = filepath
480496

481-
with open(filepath, "r") as file_:
497+
with open(template_file, "r") as file_:
482498
args = yaml.safe_load(file_)
483499
args["start_date"] = start.isoformat()
484500
args["end_date"] = end.isoformat()

tests/unit/test_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,15 @@ def setUp(self):
198198
self.model = TimeDependentModel(
199199
name=self.name, model_path=self.model_path, func=self.func
200200
)
201-
202201
def tearDown(self):
203202
patch.stopall()
204203

205204
def test_init(self):
206205
# Assertions to check if the components were instantiated correctly
207206
self.mock_registry_factory.assert_called_once_with(
208-
workdir=os.getcwd(), path=self.model_path, fmt='csv'
207+
model_name=self.name,
208+
workdir=os.getcwd(), path=self.model_path, fmt='csv',
209+
args_file='args.txt', input_cat='catalog.csv'
209210
) # Ensure the registry is initialized correctly
210211
self.mock_repository_factory.assert_called_once_with(
211212
self.mock_registry_instance, model_class="TimeDependentModel"
@@ -233,8 +234,8 @@ def test_stage(self, mk):
233234
time_windows=["2020-01-01_2020-12-31"],
234235
model_class="TimeDependentModel",
235236
prefix=self.model.__dict__.get("prefix", self.name),
236-
args_file=self.model.__dict__.get("args_file", None),
237-
input_cat=self.model.__dict__.get("input_cat", None),
237+
run_mode='sequential',
238+
run_dir=''
238239
)
239240
self.mock_environment_instance.create_environment.assert_called_once()
240241

tests/unit/test_registry.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@ class TestModelFileRegistry(unittest.TestCase):
99

1010
def setUp(self):
1111
self.registry_for_filebased_model = ModelFileRegistry(
12-
workdir="/test/workdir", path="/test/workdir/model.txt"
12+
model_name='test',
13+
workdir="/test/workdir",
14+
path="/test/workdir/model.txt"
1315
)
1416
self.registry_for_folderbased_model = ModelFileRegistry(
15-
workdir="/test/workdir", path="/test/workdir/model"
17+
model_name='test',
18+
workdir="/test/workdir",
19+
path="/test/workdir/model",
20+
args_file="args.txt",
21+
input_cat="catalog.csv"
1622
)
1723

1824
def test_call(self):
@@ -83,6 +89,7 @@ def test_build_tree_time_dependent(self, mock_listdir, mock_makedirs):
8389
[datetime(2023, 1, 1), datetime(2023, 1, 2)],
8490
[datetime(2023, 1, 2), datetime(2023, 1, 3)],
8591
]
92+
print(self.registry_for_folderbased_model.__dict__)
8693
self.registry_for_folderbased_model.build_tree(
8794
time_windows=time_windows, model_class="TimeDependentModel", prefix="forecast"
8895
)

tutorials/case_h/models.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
giturl: https://git.gfz-potsdam.de/csep/it_experiment/models/vetas.git
33
repo_hash: v3.2
44
path: models/etas
5-
args_file: input/args.json
5+
args_file: args.json
66
func: etas-run
77
func_kwargs:
88
n_sims: 100

tutorials/case_i/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ region_config:
1414
depth_min: 0
1515
depth_max: 70
1616

17+
run_mode: parallel
1718
force_rerun: True
1819
catalog: catalog.csep
1920
model_config: models.yml

0 commit comments

Comments
 (0)