From 9f7a9cc39b6cf93d701f81d33fef1b1bbde7f3f9 Mon Sep 17 00:00:00 2001 From: pciturri Date: Thu, 15 Aug 2024 12:09:44 +0200 Subject: [PATCH] ft: Added script hooks for custom postprocess plots. refac: Removed plot_catalogs/results/forecasts methods from the Experiment class and located them as functions in postprocess.py, now taking an experiment instance as argument. Simplified the functions for readability. Changed `postproc_config` argument to `postprocess`. fix: patches from test_data --- docs/examples/case_e.rst | 2 +- examples/case_e/config.yml | 2 +- examples/case_g/config.yml | 5 +- examples/case_g/plot_script.py | 59 ++++++++ floatcsep/cmd/main.py | 24 ++-- floatcsep/experiment.py | 146 +------------------- floatcsep/postprocess.py | 237 +++++++++++++++++++++++++++++++++ floatcsep/registry.py | 10 +- floatcsep/report.py | 1 - floatcsep/repository.py | 18 +-- tests/qa/test_data.py | 12 +- 11 files changed, 343 insertions(+), 173 deletions(-) create mode 100644 examples/case_g/plot_script.py create mode 100644 floatcsep/postprocess.py diff --git a/docs/examples/case_e.rst b/docs/examples/case_e.rst index 58abceb..92d3093 100644 --- a/docs/examples/case_e.rst +++ b/docs/examples/case_e.rst @@ -86,7 +86,7 @@ Models Post-Process ~~~~~~~~~~~~ - Additional options for post-processing can set using the ``postproc_config`` option. + Additional options for post-processing can be set using the ``postprocess`` option. ..
literalinclude:: ../../examples/case_e/config.yml :language: yaml diff --git a/examples/case_e/config.yml b/examples/case_e/config.yml index 020686a..1f235e5 100644 --- a/examples/case_e/config.yml +++ b/examples/case_e/config.yml @@ -18,7 +18,7 @@ catalog: query_bsi models: models.yml test_config: tests.yml -postproc_config: +postprocess: plot_forecasts: region_border: True basemap: stock_img diff --git a/examples/case_g/config.yml b/examples/case_g/config.yml index 5a56a93..b36f4f1 100644 --- a/examples/case_g/config.yml +++ b/examples/case_g/config.yml @@ -17,4 +17,7 @@ region_config: force_rerun: True catalog: catalog.csv model_config: models.yml -test_config: tests.yml \ No newline at end of file +test_config: tests.yml + +postprocess: + plot_custom: plot_script.py:main \ No newline at end of file diff --git a/examples/case_g/plot_script.py b/examples/case_g/plot_script.py new file mode 100644 index 0000000..6f9d62f --- /dev/null +++ b/examples/case_g/plot_script.py @@ -0,0 +1,59 @@ +import matplotlib.pyplot as plt +import numpy +from matplotlib import pyplot + + +def main(experiment): + """ + Example custom plot function + """ + + # Get all the timewindows + timewindows = experiment.timewindows + + # Get the pymock model + model = experiment.get_model("pymock") + + # Initialize the data lists to plot + window_mid_time = [] + event_counts = [] + rate_mean = [] + rate_2std = [] + + for timewindow in timewindows: + + # Get for a given timewindow and the model + n_test_result = experiment.results_repo.load_results( + "Catalog_N-test", timewindow, model + ) + + # Append the results + window_mid_time.append(timewindow[0] + (timewindow[1] - timewindow[0]) / 2) + event_counts.append(n_test_result.observed_statistic) + rate_mean.append(numpy.mean(n_test_result.test_distribution)) + rate_2std.append(2 * numpy.std(n_test_result.test_distribution)) + + # Create the figure + fig, ax = plt.subplots(1, 1) + + # Plot the observed number of events vs. 
time + ax.plot(window_mid_time, event_counts, "bo", label="Observed catalog") + + # Plot the forecasted mean rate and its error (2 * standard_deviation) + ax.errorbar( + window_mid_time, + rate_mean, + yerr=rate_2std, + fmt="o", + label="PyMock forecast", + color="red", + ) + + # Format and save figure + ax.set_xticks([tw[0] for tw in timewindows] + [timewindows[-1][1]]) + fig.autofmt_xdate() + ax.set_xlabel("Time") + ax.set_ylabel(r"Number of events $M\geq 3.5$") + pyplot.legend() + pyplot.grid() + pyplot.savefig("results/forecast_events_rates.png") diff --git a/floatcsep/cmd/main.py b/floatcsep/cmd/main.py index 46f995d..efeac17 100644 --- a/floatcsep/cmd/main.py +++ b/floatcsep/cmd/main.py @@ -5,6 +5,7 @@ from floatcsep.experiment import Experiment from floatcsep.logger import setup_logger, set_console_log_level from floatcsep.utils import ExperimentComparison +from floatcsep.postprocess import plot_results, plot_forecasts, plot_catalogs, plot_custom setup_logger() log = logging.getLogger("floatLogger") @@ -13,7 +14,7 @@ def stage(config, **_): log.info(f"floatCSEP v{__version__} | Stage") - exp = Experiment.from_yml(config) + exp = Experiment.from_yml(config_yml=config) exp.stage_models() log.info("Finalized") @@ -23,12 +24,16 @@ def stage(config, **_): def run(config, **kwargs): log.info(f"floatCSEP v{__version__} | Run") - exp = Experiment.from_yml(config, **kwargs) + exp = Experiment.from_yml(config_yml=config, **kwargs) exp.stage_models() exp.set_tasks() exp.run() - exp.plot_results() - exp.plot_forecasts() + + plot_catalogs(experiment=exp) + plot_forecasts(experiment=exp) + plot_results(experiment=exp) + plot_custom(experiment=exp) + exp.generate_report() exp.make_repr() @@ -40,14 +45,17 @@ def plot(config, **kwargs): log.info(f"floatCSEP v{__version__} | Plot") - exp = Experiment.from_yml(config, **kwargs) + exp = Experiment.from_yml(config_yml=config, **kwargs) exp.stage_models() exp.set_tasks() - exp.plot_results() - exp.plot_forecasts() + + 
plot_catalogs(experiment=exp) + plot_forecasts(experiment=exp) + plot_results(experiment=exp) + plot_custom(experiment=exp) + exp.generate_report() - log.info("Finalized\n") log.debug("") diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index 19b14b7..b4b8efa 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -7,10 +7,6 @@ import numpy import yaml -from cartopy import crs as ccrs -from csep.core.catalogs import CSEPCatalog -from csep.utils.time_utils import decimal_year -from matplotlib import pyplot from floatcsep import report from floatcsep.evaluation import Evaluation @@ -25,7 +21,6 @@ Task, TaskGraph, timewindow2str, - magnitude_vs_time, parse_nested_dicts, ) @@ -81,7 +76,7 @@ class Experiment: test_config (str): Path to the evaluations' configuration file default_test_kwargs (dict): Default values for the testing (seed, number of simulations, etc.) - postproc_config (dict): Contains the instruction for postprocessing + postprocess (dict): Contains the instruction for postprocessing (e.g. 
plot forecasts, catalogs) **kwargs: see Note @@ -116,7 +111,7 @@ def __init__( catalog: str = None, models: str = None, tests: str = None, - postproc_config: str = None, + postprocess: str = None, default_test_kwargs: dict = None, rundir: str = "results", report_hook: dict = None, @@ -174,7 +169,7 @@ def __init__( self.models = [] self.tests = [] - self.postproc_config = postproc_config if postproc_config else {} + self.postprocess = postprocess if postprocess else {} self.default_test_kwargs = default_test_kwargs self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config) @@ -564,141 +559,6 @@ def read_results(self, test: Evaluation, window: str) -> List: return test.read_results(window, self.models) - def plot_results(self) -> None: - """Plots all evaluation results.""" - log.info("Plotting evaluations") - timewindows = timewindow2str(self.timewindows) - - for test in self.tests: - test.plot_results(timewindows, self.models, self.registry) - - def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: - """ - Plots the evaluation catalogs. 
- - Args: - dpi: Figure resolution with which to save - show: show in runtime - """ - plot_args = { - "basemap": "ESRI_terrain", - "figsize": (12, 8), - "markersize": 8, - "markercolor": "black", - "grid_fontsize": 16, - "title": "", - "legend": True, - } - plot_args.update(self.postproc_config.get("plot_catalog", {})) - catalog = self.catalog_repo.get_test_cat() - if catalog.get_number_of_events() != 0: - ax = catalog.plot(plot_args=plot_args, show=show) - ax.get_figure().tight_layout() - ax.get_figure().savefig(self.registry.get_figure("main_catalog_map"), dpi=dpi) - - ax2 = magnitude_vs_time(catalog) - ax2.get_figure().tight_layout() - ax2.get_figure().savefig(self.registry.get_figure("main_catalog_time"), dpi=dpi) - - if self.postproc_config.get("all_time_windows"): - timewindow = self.timewindows - - for tw in timewindow: - catpath = self.registry.get_test_catalog(tw) - catalog = CSEPCatalog.load_json(catpath) - if catalog.get_number_of_events() != 0: - ax = catalog.plot(plot_args=plot_args, show=show) - ax.get_figure().tight_layout() - ax.get_figure().savefig( - self.registry.get_figure(tw, "catalog_map"), dpi=dpi - ) - - ax2 = magnitude_vs_time(catalog) - ax2.get_figure().tight_layout() - ax2.get_figure().savefig( - self.registry.get_figure(tw, "catalog_time"), dpi=dpi - ) - - def plot_forecasts(self) -> None: - """Plots and saves all the generated forecasts.""" - - plot_fc_config = self.postproc_config.get("plot_forecasts") - if plot_fc_config: - log.info("Plotting forecasts") - if plot_fc_config is True: - plot_fc_config = {} - try: - proj_ = plot_fc_config.get("projection") - if isinstance(proj_, dict): - proj_name = list(proj_.keys())[0] - proj_args = list(proj_.values())[0] - else: - proj_name = proj_ - proj_args = {} - plot_fc_config["projection"] = getattr(ccrs, proj_name)(**proj_args) - except (IndexError, KeyError, TypeError, AttributeError): - plot_fc_config["projection"] = ccrs.PlateCarree(central_longitude=0.0) - - cat = 
plot_fc_config.get("catalog") - cat_args = {} - if cat: - cat_args = { - "markersize": 7, - "markercolor": "black", - "title": "asd", - "grid": False, - "legend": False, - "basemap": None, - "region_border": False, - } - if self.region: - self.catalog.filter_spatial(self.region, in_place=True) - if isinstance(cat, dict): - cat_args.update(cat) - - window = self.timewindows[-1] - winstr = timewindow2str(window) - - for model in self.models: - fig_path = self.registry.get_figure(winstr, "forecasts", model.name) - start = decimal_year(window[0]) - end = decimal_year(window[1]) - time = f"{round(end - start, 3)} years" - plot_args = { - "region_border": False, - "cmap": "magma", - "clabel": r"$\log_{10} N\left(M_w \in [{%.2f}," - r"\,{%.2f}]\right)$ per " - r"$0.1^\circ\times 0.1^\circ $ per %s" - % (min(self.magnitudes), max(self.magnitudes), time), - } - if not self.region or self.region.name == "global": - set_global = True - else: - set_global = False - plot_args.update(plot_fc_config) - ax = model.get_forecast(winstr, self.region).plot( - set_global=set_global, plot_args=plot_args - ) - - if self.region: - bbox = self.region.get_bbox() - dh = self.region.dh - extent = [ - bbox[0] - 3 * dh, - bbox[1] + 3 * dh, - bbox[2] - 3 * dh, - bbox[3] + 3 * dh, - ] - else: - extent = None - if cat: - self.catalog.plot( - ax=ax, set_global=set_global, extent=extent, plot_args=cat_args - ) - - pyplot.savefig(fig_path, dpi=300, facecolor=(0, 0, 0, 0)) - def generate_report(self) -> None: """Creates a report summarizing the Experiment's results.""" diff --git a/floatcsep/postprocess.py b/floatcsep/postprocess.py new file mode 100644 index 0000000..a07fbab --- /dev/null +++ b/floatcsep/postprocess.py @@ -0,0 +1,237 @@ +import importlib.util +import logging +import os +from datetime import datetime +from typing import TYPE_CHECKING, Union + +from cartopy import crs as ccrs +from matplotlib import pyplot + +from floatcsep.utils import ( + timewindow2str, + magnitude_vs_time, +) + +if 
TYPE_CHECKING: + from floatcsep.experiment import Experiment + +log = logging.getLogger("floatLogger") + + +def plot_results(experiment: "Experiment") -> None: + """Plots all evaluation results.""" + log.info("Plotting evaluation results") + timewindows = timewindow2str(experiment.timewindows) + + for test in experiment.tests: + test.plot_results(timewindows, experiment.models, experiment.registry) + + +def plot_forecasts(experiment: "Experiment") -> None: + """Plots and saves all the generated forecasts.""" + + # Parsing plot configuration file + plot_forecast_config: dict = parse_plot_config( + experiment.postprocess.get("plot_forecasts", {}) + ) + if not isinstance(plot_forecast_config, dict): + return + + ##################################### + # Default forecast plotting function. + ##################################### + log.info("Plotting forecasts") + + # Get the time windows to be plotted. Defaults to only the last time window. + time_windows: list[list[datetime]] = ( + timewindow2str(experiment.timewindows) + if plot_forecast_config.get("all_time_windows") + else [timewindow2str(experiment.timewindows[-1])] + ) + + # Get the projection of the plots + plot_forecast_config["projection"]: ccrs.Projection = parse_projection( + plot_forecast_config.get("projection") + ) + + for model in experiment.models: + for window in time_windows: + ax = model.get_forecast(window, experiment.region).plot( + plot_args=plot_forecast_config + ) + + # If catalog option is passed, catalog is plotted on top of the forecast + if plot_forecast_config.get("catalog"): + cat_args = plot_forecast_config.get("catalog", {}) + experiment.catalog_repo.get_test_cat(window).plot( + ax=ax, + extent=ax.get_extent(), + plot_args=cat_args.update( + { + "basemap": plot_forecast_config.get("basemap", None), + "title": ax.get_title(), + } + ), + ) + fig_path = experiment.registry.get_figure(window, "forecasts", model.name) + pyplot.savefig(fig_path, dpi=plot_forecast_config.get("dpi", 300)) + + 
+def plot_catalogs(experiment: "Experiment") -> None: + + # Parsing plot configuration file + plot_catalog_config: dict = parse_plot_config( + experiment.postprocess.get("plot_catalog", {}) + ) + if not isinstance(plot_catalog_config, dict): + return + + #################################### + # Default catalog plotting function. + #################################### + log.info("Plotting catalogs") + + # Get the projection of the plots + plot_catalog_config["projection"]: ccrs.Projection = parse_projection( + plot_catalog_config.get("projection") + ) + # Get the start and end dates of the experiment (as a string) + experiment_timewindow = timewindow2str([experiment.start_date, experiment.end_date]) + + # Get the catalog for the entire duration of the experiment + main_catalog = experiment.catalog_repo.get_test_cat(experiment_timewindow) + + # Skip plotting if no events + if main_catalog.get_number_of_events() == 0: + log.debug(f"Catalog has zero events in {experiment_timewindow}") + return + + # Plot catalog map + ax = main_catalog.plot(plot_args=plot_catalog_config) + cat_map_path = experiment.registry.get_figure("main_catalog_map") + ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300)) + + # Plot catalog time series vs. magnitude + ax = magnitude_vs_time(main_catalog) + cat_time_path = experiment.registry.get_figure("main_catalog_time") + ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300)) + + # If selected, plot the test catalogs for each of the time windows + if plot_catalog_config.get("all_time_windows"): + for tw in experiment.timewindows: + test_catalog = experiment.catalog_repo.get_test_cat(timewindow2str(tw)) + + if test_catalog.get_number_of_events() != 0: + log.debug(f"Catalog has zero events in {tw}. 
Skip plotting") + continue + + ax = test_catalog.plot(plot_args=plot_catalog_config) + cat_map_path = experiment.registry.get_figure(tw, "catalog_map") + ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300)) + + ax = magnitude_vs_time(test_catalog) + cat_time_path = experiment.registry.get_figure(tw, "catalog_time") + ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300)) + + +def plot_custom(experiment: "Experiment"): + + plot_config = parse_plot_config(experiment.postprocess.get("plot_custom", None)) + if plot_config is None: + return + script_path, func_name = plot_config + + log.info(f"Plotting from script {script_path} and function {func_name}") + script_abs_path = experiment.registry.abs(script_path) + allowed_directory = os.path.dirname(experiment.registry.abs(experiment.config_file)) + + if not os.path.isfile(script_path) or ( + os.path.dirname(script_abs_path) != os.path.realpath(allowed_directory) + ): + + log.error(f"Script {script_path} is not in the configuration file directory.") + log.info( + "\t Skipping plotting. Script can be reallocated and re-run the plotting only" + " by typing 'floatcsep run {config}'" + ) + return + + module_name = os.path.splitext(os.path.basename(script_abs_path))[0] + spec = importlib.util.spec_from_file_location(module_name, script_abs_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Execute the script securely + try: + func = getattr(module, func_name) + + except AttributeError: + log.error(f"Function {func_name} not found in {script_path}") + log.info( + "\t Skipping plotting. Plot script can be modified and re-run the plotting only" + " by typing 'floatcsep run {config}'" + ) + return + + try: + func(experiment) + except Exception as e: + log.error(f"Error executing {func_name} from {script_path}: {e}") + log.info( + "\t Skipping plotting. 
Plot script can be modified and re-run the plotting only" + " by typing 'floatcsep run {config}'" + ) + return + + +def parse_plot_config(plot_config: Union[dict, str, bool]): + + if plot_config is True: + return {} + + elif plot_config in (None, False): + return + + elif isinstance(plot_config, dict): + return plot_config + + elif isinstance(plot_config, str): + # Parse the script path and function name + try: + script_path, func_name = plot_config.split(".py:") + script_path += ".py" + return script_path, func_name + except ValueError: + log.error( + f"Invalid format for custom plot function: {plot_config}. " + "Try {script_name}.py:{func}" + ) + log.info( + "\t Skipping plotting. The script can be modified and re-run the plotting only " + "by typing 'floatcsep run {config}'" + ) + return + + else: + log.error("Plot configuration not understood. Skipping plotting") + return + + +def parse_projection(proj_config: Union[dict, str, bool]): + """Retrieve projection configuration. + e.g., as defined in the config file: + projection: + Mercator: + central_longitude: 0.0 + """ + if proj_config is None: + return ccrs.PlateCarree(central_longitude=0.0) + + if isinstance(proj_config, dict): + proj_name, proj_args = next(iter(proj_config.items())) + else: + proj_name, proj_args = proj_config, {} + + if not isinstance(proj_name, str): + return ccrs.PlateCarree(central_longitude=0.0) + + return getattr(ccrs, proj_name, ccrs.PlateCarree)(**proj_args) diff --git a/floatcsep/registry.py b/floatcsep/registry.py index cd90119..42838b2 100644 --- a/floatcsep/registry.py +++ b/floatcsep/registry.py @@ -14,7 +14,7 @@ log = logging.getLogger("floatLogger") -class BaseFileRegistry(ABC): +class FileRegistry(ABC): def __init__(self, workdir: str) -> None: self.workdir = workdir @@ -79,7 +79,7 @@ def file_exists(self, *args: Sequence[str]): return exists(file_abspath) -class ForecastRegistry(BaseFileRegistry): +class ForecastRegistry(FileRegistry): def __init__( self, workdir: str, @@ 
-214,7 +214,7 @@ def log_tree(self) -> None: log.debug(f" Time Window: {timewindow}") -class ExperimentRegistry(BaseFileRegistry): +class ExperimentRegistry(FileRegistry): def __init__(self, workdir: str, run_dir: str = "results") -> None: super().__init__(workdir) self.run_dir = run_dir @@ -349,8 +349,8 @@ def build_tree( **{ win: { **{test: join(win, "figures", f"{test}") for test in tests}, - "catalog": join(win, "figures", "catalog_map"), - "magnitude_time": join(win, "figures", "catalog_time"), + "catalog_map": join(win, "figures", "catalog_map"), + "catalog_time": join(win, "figures", "catalog_time"), "forecasts": { model: join(win, "figures", f"forecast_{model}") for model in models }, diff --git a/floatcsep/report.py b/floatcsep/report.py index 7872e12..51aa9a1 100644 --- a/floatcsep/report.py +++ b/floatcsep/report.py @@ -40,7 +40,6 @@ def generate_report(experiment, timewindow=-1): # Generate catalog plot if experiment.catalog_repo.catalog is not None: - experiment.plot_catalog() report.add_figure( f"Input catalog", [ diff --git a/floatcsep/repository.py b/floatcsep/repository.py index 9c1765b..ba7779a 100644 --- a/floatcsep/repository.py +++ b/floatcsep/repository.py @@ -199,19 +199,21 @@ def load_results( self, test, window: Union[str, Sequence[datetime.datetime]], - models: List, - ) -> List: + models: Union[list["Model"], "Model"], + ) -> Union[List, EvaluationResult]: """ Reads an Evaluation result for a given time window and returns a list of the results for all tested models. 
""" - test_results = [] - for model in models: - model_eval = self._load_result(test, window, model) - test_results.append(model_eval) - - return test_results + if isinstance(models, list): + test_results = [] + for model in models: + model_eval = self._load_result(test, window, model) + test_results.append(model_eval) + return test_results + else: + return self._load_result(test, window, models) def write_result(self, result: EvaluationResult, test, model, window) -> None: diff --git a/tests/qa/test_data.py b/tests/qa/test_data.py index 0e984ab..872b4be 100644 --- a/tests/qa/test_data.py +++ b/tests/qa/test_data.py @@ -1,5 +1,6 @@ from floatcsep.cmd import main from floatcsep.experiment import Experiment + import unittest from unittest.mock import patch import os @@ -34,8 +35,9 @@ def get_eval_dist(self): @patch.object(Experiment, "generate_report") -@patch.object(Experiment, "plot_forecasts") -@patch.object(Experiment, "plot_catalog") +@patch("floatcsep.cmd.main.plot_forecasts") +@patch("floatcsep.cmd.main.plot_catalogs") +@patch("floatcsep.cmd.main.plot_custom") class RunExamples(DataTest): def test_case_a(self, *args): @@ -74,9 +76,9 @@ def test_case_g(self, *args): self.assertEqual(1, 1) -@patch.object(Experiment, "generate_report") -@patch.object(Experiment, "plot_forecasts") -@patch.object(Experiment, "plot_catalog") +@patch("floatcsep.cmd.main.plot_forecasts") +@patch("floatcsep.cmd.main.plot_catalogs") +@patch("floatcsep.cmd.main.plot_custom") class ReproduceExamples(DataTest): def test_case_c(self, *args):