Skip to content

Commit 19a4e9a

Browse files
authored
Merge pull request #18 from cseptesting/17-improve-experiment-self-discovery-persisting-inputresults-description
refac: more control on the output of dicts from Model.as_dict() and experiment.as_dict()
2 parents f352add + ce77d10 commit 19a4e9a

File tree

6 files changed

+54
-61
lines changed

6 files changed

+54
-61
lines changed

floatcsep/experiment.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import datetime
2-
import json
32
import logging
43
import os
54
import shutil
65
import warnings
76
from os.path import join, abspath, relpath, dirname, isfile, split, exists
8-
from typing import Union, List, Dict, Callable, Sequence
7+
from typing import Union, List, Dict, Sequence
98

10-
import csep
119
import numpy
1210
import yaml
1311
from cartopy import crs as ccrs
@@ -23,13 +21,11 @@
2321
from floatcsep.repository import ResultsRepository, CatalogRepository
2422
from floatcsep.utils import (
2523
NoAliasLoader,
26-
parse_csep_func,
2724
read_time_cfg,
2825
read_region_cfg,
2926
Task,
3027
TaskGraph,
3128
timewindow2str,
32-
str2timewindow,
3329
magnitude_vs_time,
3430
parse_nested_dicts,
3531
)
@@ -184,7 +180,7 @@ def __init__(
184180
self.postproc_config = postproc_config if postproc_config else {}
185181
self.default_test_kwargs = default_test_kwargs
186182

187-
self.catalog_repo.set_catalog(catalog, self.time_config, self.region_config)
183+
self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config)
188184

189185
self.models = self.set_models(
190186
models or kwargs.get("model_config"), kwargs.get("order", None)
@@ -717,7 +713,7 @@ def make_repr(self):
717713

718714
# Dropping region to results folder if it is a file
719715
region_path = self.region_config.get("path", False)
720-
if region_path:
716+
if isinstance(region_path, str):
721717
if isfile(region_path) and region_path:
722718
new_path = join(self.registry.rundir, self.region_config["path"])
723719
shutil.copy2(region_path, new_path)
@@ -726,10 +722,10 @@ def make_repr(self):
726722

727723
# Dropping catalog to results folder
728724
target_cat = join(
729-
self.registry.workdir, self.registry.rundir, split(self.catalog_repo._catpath)[-1]
725+
self.registry.workdir, self.registry.rundir, split(self.catalog_repo.cat_path)[-1]
730726
)
731727
if not exists(target_cat):
732-
shutil.copy2(self.registry.abs(self.catalog_repo._catpath), target_cat)
728+
shutil.copy2(self.registry.abs(self.catalog_repo.cat_path), target_cat)
733729
self._catpath = self.registry.rel(target_cat)
734730

735731
relative_path = os.path.relpath(
@@ -738,41 +734,41 @@ def make_repr(self):
738734
self.registry.workdir = relative_path
739735
self.to_yml(repr_config, extended=True)
740736

741-
def as_dict(
742-
self,
743-
exclude: Sequence = (
744-
"magnitudes",
745-
"depths",
746-
"timewindows",
747-
"filetree",
748-
"task_graph",
749-
"tasks",
750-
"models",
751-
"tests",
752-
"results_repo",
753-
"catalog_repo",
754-
),
755-
extended: bool = False,
756-
) -> dict:
737+
def as_dict(self, extra: Sequence = (), extended=False) -> dict:
757738
"""
758739
Converts an Experiment instance into a dictionary.
759740
760741
Args:
761-
exclude (tuple, list): Attributes, or attribute keys, to ignore
762-
extended (bool): Verbose representation of pycsep objects
742+
extra: additional key/value pairs (a mapping or an iterable of pairs) merged into the dictionary.
743+
extended: if True, also include "timewindows" in time_config and "magnitudes"/"depths" in region_config
763744
764745
Returns:
765746
A dictionary with serialized instance's attributes, which are
766747
floatCSEP readable
767748
"""
768749

769-
listwalk = [(i, j) for i, j in self.__dict__.items() if not i.startswith("_") and j]
770-
listwalk.insert(6, ("catalog", self.catalog_repo._catpath))
771-
772-
dictwalk = {i: j for i, j in listwalk}
773-
dictwalk["path"] = dictwalk.pop("registry").workdir
750+
dict_walk = {
751+
"name": self.name,
752+
"config_file": self.config_file,
753+
"path": self.registry.workdir,
754+
"run_dir": self.registry.rundir,
755+
"time_config": {
756+
i: j
757+
for i, j in self.time_config.items()
758+
if (i not in ("timewindows",) or extended)
759+
},
760+
"region_config": {
761+
i: j
762+
for i, j in self.region_config.items()
763+
if (i not in ("magnitudes", "depths") or extended)
764+
},
765+
"catalog": self.catalog_repo.cat_path,
766+
"models": [i.as_dict() for i in self.models],
767+
"tests": [i.as_dict() for i in self.tests],
768+
}
769+
dict_walk.update(extra)
774770

775-
return parse_nested_dicts(dictwalk, excluded=exclude, extended=extended)
771+
return parse_nested_dicts(dict_walk)
776772

777773
def to_yml(self, filename: str, **kwargs) -> None:
778774
"""

floatcsep/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import os
44
from abc import ABC, abstractmethod
55
from datetime import datetime
6-
from typing import List, Callable, Union, Mapping, Sequence
6+
from typing import List, Callable, Union, Sequence
77

88
import git
9-
import numpy
109
from csep.core.forecasts import GriddedForecast, CatalogForecast
1110

1211
from floatcsep.accessors import from_zenodo, from_git
@@ -133,10 +132,10 @@ def as_dict(self, excluded=("name", "repository", "workdir")):
133132
(i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j
134133
]
135134

136-
dict_walk = {i: j for i, j in list_walk}
135+
dict_walk = {i: j for i, j in list_walk if i not in excluded}
137136
dict_walk["path"] = dict_walk.pop("registry").path
138137

139-
return {self.name: parse_nested_dicts(dict_walk, excluded=excluded)}
138+
return {self.name: parse_nested_dicts(dict_walk)}
140139

141140
@classmethod
142141
def from_dict(cls, record: dict, **kwargs):

floatcsep/repository.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ class ResultsRepository:
176176

177177
def __init__(self, registry: ExperimentRegistry):
178178
self.registry = registry
179-
self.a = 1
180179

181180
def _load_result(
182181
self,
@@ -236,6 +235,8 @@ def default(self, obj):
236235
class CatalogRepository:
237236

238237
def __init__(self, registry: ExperimentRegistry):
238+
self.cat_path = None
239+
self._catalog = None
239240
self.registry = registry
240241
self.time_config = {}
241242
self.region_config = {}
@@ -270,7 +271,7 @@ def __getattr__(self, item: str) -> object:
270271
def as_dict(self):
271272
return
272273

273-
def set_catalog(
274+
def set_main_catalog(
274275
self, catalog: Union[str, Callable, CSEPCatalog], time_config: dict, region_config: dict
275276
):
276277
"""
@@ -291,11 +292,11 @@ def catalog(self) -> CSEPCatalog:
291292
Returns a CSEP catalog loaded from the given query function or a stored file if it
292293
exists.
293294
"""
294-
cat_path = self.registry.abs(self._catpath)
295+
cat_path = self.registry.abs(self.cat_path)
295296

296297
if callable(self._catalog):
297-
if isfile(self._catpath):
298-
return CSEPCatalog.load_json(self._catpath)
298+
if isfile(self.cat_path):
299+
return CSEPCatalog.load_json(self.cat_path)
299300
bounds = {
300301
"start_time": min([item for sublist in self.timewindows for item in sublist]),
301302
"end_time": max([item for sublist in self.timewindows for item in sublist]),
@@ -318,7 +319,7 @@ def catalog(self) -> CSEPCatalog:
318319
if self.region:
319320
catalog.filter_spatial(region=self.region, in_place=True)
320321
catalog.region = None
321-
catalog.write_json(self._catpath)
322+
catalog.write_json(self.cat_path)
322323

323324
return catalog
324325

@@ -333,19 +334,19 @@ def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None:
333334

334335
if cat is None:
335336
self._catalog = None
336-
self._catpath = None
337+
self.cat_path = None
337338

338339
elif isfile(self.registry.abs(cat)):
339340
log.info(f"\tCatalog: '{cat}'")
340341
self._catalog = self.registry.rel(cat)
341-
self._catpath = self.registry.rel(cat)
342+
self.cat_path = self.registry.rel(cat)
342343

343344
else:
344345
# catalog can be a function
345346
self._catalog = parse_csep_func(cat)
346-
self._catpath = self.registry.abs("catalog.json")
347-
if isfile(self._catpath):
348-
log.info(f"\tCatalog: stored " f"'{self._catpath}' " f"from '{cat}'")
347+
self.cat_path = self.registry.abs("catalog.json")
348+
if isfile(self.cat_path):
349+
log.info(f"\tCatalog: stored " f"'{self.cat_path}' " f"from '{cat}'")
349350
else:
350351
log.info(f"\tCatalog: '{cat}'")
351352

@@ -363,7 +364,7 @@ def get_test_cat(self, tstring: str = None) -> CSEPCatalog:
363364
else:
364365
start = self.start_date
365366
end = self.end_date
366-
print(self.catalog)
367+
367368
sub_cat = self.catalog.filter(
368369
[
369370
f"origin_time < {end.timestamp() * 1000}",

floatcsep/utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,9 @@ def timewindows_td(
407407
# return timewindows
408408

409409

410-
def parse_nested_dicts(
411-
nested_dict: dict, excluded: Sequence = (), extended: bool = False
412-
) -> dict:
410+
def parse_nested_dicts(nested_dict: dict) -> dict:
413411
"""
414-
Parses nested dictionaries to flatten them
412+
Recursively traverses nested dictionaries and sequences, parsing each element they contain
415413
"""
416414

417415
def _get_value(x):
@@ -435,11 +433,7 @@ def _get_value(x):
435433
def iter_attr(val):
436434
# recursive iter through nested dicts/lists
437435
if isinstance(val, Mapping):
438-
return {
439-
item: iter_attr(val_)
440-
for item, val_ in val.items()
441-
if ((item not in excluded) and val_) or extended
442-
}
436+
return {item: iter_attr(val_) for item, val_ in val.items()}
443437
elif isinstance(val, Sequence) and not isinstance(val, str):
444438
return [iter_attr(i) for i in val]
445439
else:

tests/unit/test_experiment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ def test_to_dict(self):
6666
"name": "test",
6767
"path": os.getcwd(),
6868
"run_dir": "results",
69+
"config_file": None,
70+
"models": [],
71+
"tests": [],
6972
"time_config": {
7073
"exp_class": "ti",
7174
"start_date": datetime(2020, 1, 1),
@@ -109,7 +112,7 @@ def test_to_yml(self):
109112
self.assertEqualExperiment(exp_a, exp_b)
110113

111114
file_ = tempfile.mkstemp()[1]
112-
exp_a.to_yml(file_, extended=True)
115+
exp_a.to_yml(file_)
113116
exp_c = Experiment.from_yml(file_)
114117
self.assertEqualExperiment(exp_a, exp_c)
115118

tests/unit/test_repositories.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,10 @@ def test_set_catalog(self, mock_isfile):
204204
# Mock the registry's rel method to return the same path for simplicity
205205
self.mock_registry.rel.return_value = "catalog_path"
206206

207-
self.catalog_repo.set_catalog("catalog_path", {}, {})
207+
self.catalog_repo.set_main_catalog("catalog_path", {}, {})
208208

209209
# Check if cat_path is set correctly
210-
self.assertEqual(self.catalog_repo._catpath, "catalog_path")
210+
self.assertEqual(self.catalog_repo.cat_path, "catalog_path")
211211

212212
# Check if _catalog is set correctly
213213
self.assertEqual(self.catalog_repo._catalog, "catalog_path")

0 commit comments

Comments
 (0)