From ce77d10849aa81e2ee373c28a6e812e0e359292b Mon Sep 17 00:00:00 2001 From: pciturri Date: Sun, 11 Aug 2024 00:17:35 +0200 Subject: [PATCH] refac: more control on the output of dicts from Model.as_dict() and Experiment.as_dict() --- floatcsep/experiment.py | 62 +++++++++++++++------------------ floatcsep/model.py | 7 ++-- floatcsep/repository.py | 25 ++++++------- floatcsep/utils.py | 12 ++----- tests/unit/test_experiment.py | 5 ++- tests/unit/test_repositories.py | 4 +-- 6 files changed, 54 insertions(+), 61 deletions(-) diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index 80c3487..6693d2a 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -1,13 +1,11 @@ import datetime -import json import logging import os import shutil import warnings from os.path import join, abspath, relpath, dirname, isfile, split, exists -from typing import Union, List, Dict, Callable, Sequence +from typing import Union, List, Dict, Sequence -import csep import numpy import yaml from cartopy import crs as ccrs @@ -23,13 +21,11 @@ from floatcsep.repository import ResultsRepository, CatalogRepository from floatcsep.utils import ( NoAliasLoader, - parse_csep_func, read_time_cfg, read_region_cfg, Task, TaskGraph, timewindow2str, - str2timewindow, magnitude_vs_time, parse_nested_dicts, ) @@ -184,7 +180,7 @@ def __init__( self.postproc_config = postproc_config if postproc_config else {} self.default_test_kwargs = default_test_kwargs - self.catalog_repo.set_catalog(catalog, self.time_config, self.region_config) + self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config) self.models = self.set_models( models or kwargs.get("model_config"), kwargs.get("order", None) @@ -717,7 +713,7 @@ def make_repr(self): # Dropping region to results folder if it is a file region_path = self.region_config.get("path", False) - if region_path: + if isinstance(region_path, str): if isfile(region_path) and region_path: new_path = join(self.registry.rundir, self.region_config["path"]) shutil.copy2(region_path, new_path) @@ -726,10 +722,10 @@ def make_repr(self): # Dropping catalog to results folder target_cat = join( - self.registry.workdir, self.registry.rundir, split(self.catalog_repo._catpath)[-1] + self.registry.workdir, self.registry.rundir, split(self.catalog_repo.cat_path)[-1] ) if not exists(target_cat): - shutil.copy2(self.registry.abs(self.catalog_repo._catpath), target_cat) + shutil.copy2(self.registry.abs(self.catalog_repo.cat_path), target_cat) self._catpath = self.registry.rel(target_cat) relative_path = os.path.relpath( @@ -738,41 +734,41 @@ def make_repr(self): self.registry.workdir = relative_path self.to_yml(repr_config, extended=True) - def as_dict( - self, - exclude: Sequence = ( - "magnitudes", - "depths", - "timewindows", - "filetree", - "task_graph", - "tasks", - "models", - "tests", - "results_repo", - "catalog_repo", - ), - extended: bool = False, - ) -> dict: + def as_dict(self, extra: Sequence = (), extended=False) -> dict: """ Converts an Experiment instance into a dictionary. Args: - exclude (tuple, list): Attributes, or attribute keys, to ignore - extended (bool): Verbose representation of pycsep objects + extra: additional instance attribute to include in the dictionary. + extended: Include explicit parameters Returns: A dictionary with serialized instance's attributes, which are floatCSEP readable """ - listwalk = [(i, j) for i, j in self.__dict__.items() if not i.startswith("_") and j] - listwalk.insert(6, ("catalog", self.catalog_repo._catpath)) - - dictwalk = {i: j for i, j in listwalk} - dictwalk["path"] = dictwalk.pop("registry").workdir + dict_walk = { + "name": self.name, + "config_file": self.config_file, + "path": self.registry.workdir, + "run_dir": self.registry.rundir, + "time_config": { + i: j + for i, j in self.time_config.items() + if (i not in ("timewindows",) or extended) + }, + "region_config": { + i: j + for i, j in self.region_config.items() + if (i not in ("magnitudes", "depths") or extended) + }, + "catalog": self.catalog_repo.cat_path, + "models": [i.as_dict() for i in self.models], + "tests": [i.as_dict() for i in self.tests], + } + dict_walk.update(extra) - return parse_nested_dicts(dictwalk, excluded=exclude, extended=extended) + return parse_nested_dicts(dict_walk) def to_yml(self, filename: str, **kwargs) -> None: """ diff --git a/floatcsep/model.py b/floatcsep/model.py index 67c3eee..60f4cd0 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -3,10 +3,9 @@ import os from abc import ABC, abstractmethod from datetime import datetime -from typing import List, Callable, Union, Mapping, Sequence +from typing import List, Callable, Union, Sequence import git -import numpy from csep.core.forecasts import GriddedForecast, CatalogForecast from floatcsep.accessors import from_zenodo, from_git @@ -133,10 +132,10 @@ def as_dict(self, excluded=("name", "repository", "workdir")): (i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j ] - dict_walk = {i: j for i, j in list_walk} + dict_walk = {i: j for i, j in list_walk if i not in excluded} dict_walk["path"] = dict_walk.pop("registry").path - return {self.name: parse_nested_dicts(dict_walk, excluded=excluded)} + return {self.name: parse_nested_dicts(dict_walk)} @classmethod def from_dict(cls, record: dict, **kwargs): diff --git a/floatcsep/repository.py b/floatcsep/repository.py index e92e63b..02070c1 100644 --- a/floatcsep/repository.py +++ b/floatcsep/repository.py @@ -176,7 +176,6 @@ class ResultsRepository: def __init__(self, registry: ExperimentRegistry): self.registry = registry - self.a = 1 def _load_result( self, @@ -236,6 +235,8 @@ def default(self, obj): class CatalogRepository: def __init__(self, registry: ExperimentRegistry): + self.cat_path = None + self._catalog = None self.registry = registry self.time_config = {} self.region_config = {} @@ -270,7 +271,7 @@ def __getattr__(self, item: str) -> object: def as_dict(self): return - def set_catalog( + def set_main_catalog( self, catalog: Union[str, Callable, CSEPCatalog], time_config: dict, region_config: dict ): """ @@ -291,11 +292,11 @@ def catalog(self) -> CSEPCatalog: Returns a CSEP catalog loaded from the given query function or a stored file if it exists. """ - cat_path = self.registry.abs(self._catpath) + cat_path = self.registry.abs(self.cat_path) if callable(self._catalog): - if isfile(self._catpath): - return CSEPCatalog.load_json(self._catpath) + if isfile(self.cat_path): + return CSEPCatalog.load_json(self.cat_path) bounds = { "start_time": min([item for sublist in self.timewindows for item in sublist]), "end_time": max([item for sublist in self.timewindows for item in sublist]), @@ -318,7 +319,7 @@ def catalog(self) -> CSEPCatalog: if self.region: catalog.filter_spatial(region=self.region, in_place=True) catalog.region = None - catalog.write_json(self._catpath) + catalog.write_json(self.cat_path) return catalog @@ -333,19 +334,19 @@ def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None: if cat is None: self._catalog = None - self._catpath = None + self.cat_path = None elif isfile(self.registry.abs(cat)): log.info(f"\tCatalog: '{cat}'") self._catalog = self.registry.rel(cat) - self._catpath = self.registry.rel(cat) + self.cat_path = self.registry.rel(cat) else: # catalog can be a function self._catalog = parse_csep_func(cat) - self._catpath = self.registry.abs("catalog.json") - if isfile(self._catpath): - log.info(f"\tCatalog: stored " f"'{self._catpath}' " f"from '{cat}'") + self.cat_path = self.registry.abs("catalog.json") + if isfile(self.cat_path): + log.info(f"\tCatalog: stored " f"'{self.cat_path}' " f"from '{cat}'") else: log.info(f"\tCatalog: '{cat}'") @@ -363,7 +364,7 @@ def get_test_cat(self, tstring: str = None) -> CSEPCatalog: else: start = self.start_date end = self.end_date - print(self.catalog) + sub_cat = self.catalog.filter( [ f"origin_time < {end.timestamp() * 1000}", diff --git a/floatcsep/utils.py b/floatcsep/utils.py index 1409e8b..63ebfab 100644 --- a/floatcsep/utils.py +++ b/floatcsep/utils.py @@ -407,11 +407,9 @@ def timewindows_td( # return timewindows -def parse_nested_dicts( - nested_dict: dict, excluded: Sequence = (), extended: bool = False -) -> dict: +def parse_nested_dicts(nested_dict: dict) -> dict: """ - Parses nested dictionaries to flatten them + Parses nested dictionaries to return appropriate parsing on each element """ def _get_value(x): @@ -435,11 +433,7 @@ def _get_value(x): def iter_attr(val): # recursive iter through nested dicts/lists if isinstance(val, Mapping): - return { - item: iter_attr(val_) - for item, val_ in val.items() - if ((item not in excluded) and val_) or extended - } + return {item: iter_attr(val_) for item, val_ in val.items()} elif isinstance(val, Sequence) and not isinstance(val, str): return [iter_attr(i) for i in val] else: diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index 0abdc79..c7f8c62 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -66,6 +66,9 @@ def test_to_dict(self): "name": "test", "path": os.getcwd(), "run_dir": "results", + "config_file": None, + "models": [], + "tests": [], "time_config": { "exp_class": "ti", "start_date": datetime(2020, 1, 1), @@ -109,7 +112,7 @@ def test_to_yml(self): self.assertEqualExperiment(exp_a, exp_b) file_ = tempfile.mkstemp()[1] - exp_a.to_yml(file_, extended=True) + exp_a.to_yml(file_) exp_c = Experiment.from_yml(file_) self.assertEqualExperiment(exp_a, exp_c) diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index 1f3a28b..7128bda 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -204,10 +204,10 @@ def test_set_catalog(self, mock_isfile): # Mock the registry's rel method to return the same path for simplicity self.mock_registry.rel.return_value = "catalog_path" - self.catalog_repo.set_catalog("catalog_path", {}, {}) + self.catalog_repo.set_main_catalog("catalog_path", {}, {}) # Check if _catpath is set correctly - self.assertEqual(self.catalog_repo._catpath, "catalog_path") + self.assertEqual(self.catalog_repo.cat_path, "catalog_path") # Check if _catalog is set correctly self.assertEqual(self.catalog_repo._catalog, "catalog_path")