Skip to content

Commit c75546a

Browse files
authored
Merge pull request #23 from cseptesting/22-remake-decouple-catalog-forecast-plots-from-experiment
22: Postprocess refactoring
2 parents d32c2c3 + 9f7a9cc commit c75546a

File tree

11 files changed

+343
-173
lines changed

11 files changed

+343
-173
lines changed

docs/examples/case_e.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Models
8686
Post-Process
8787
~~~~~~~~~~~~
8888

89-
Additional options for post-processing can set using the ``postproc_config`` option.
89+
Additional options for post-processing can set using the ``postprocess`` option.
9090

9191
.. literalinclude:: ../../examples/case_e/config.yml
9292
:language: yaml

examples/case_e/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ catalog: query_bsi
1818
models: models.yml
1919
test_config: tests.yml
2020

21-
postproc_config:
21+
postprocess:
2222
plot_forecasts:
2323
region_border: True
2424
basemap: stock_img

examples/case_g/config.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ region_config:
1717
force_rerun: True
1818
catalog: catalog.csv
1919
model_config: models.yml
20-
test_config: tests.yml
20+
test_config: tests.yml
21+
22+
postprocess:
23+
plot_custom: plot_script.py:main

examples/case_g/plot_script.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import matplotlib.pyplot as plt
2+
import numpy
3+
from matplotlib import pyplot
4+
5+
6+
def main(experiment):
7+
"""
8+
Example custom plot function
9+
"""
10+
11+
# Get all the timewindows
12+
timewindows = experiment.timewindows
13+
14+
# Get the pymock model
15+
model = experiment.get_model("pymock")
16+
17+
# Initialize the data lists to plot
18+
window_mid_time = []
19+
event_counts = []
20+
rate_mean = []
21+
rate_2std = []
22+
23+
for timewindow in timewindows:
24+
25+
# Get for a given timewindow and the model
26+
n_test_result = experiment.results_repo.load_results(
27+
"Catalog_N-test", timewindow, model
28+
)
29+
30+
# Append the results
31+
window_mid_time.append(timewindow[0] + (timewindow[1] - timewindow[0]) / 2)
32+
event_counts.append(n_test_result.observed_statistic)
33+
rate_mean.append(numpy.mean(n_test_result.test_distribution))
34+
rate_2std.append(2 * numpy.std(n_test_result.test_distribution))
35+
36+
# Create the figure
37+
fig, ax = plt.subplots(1, 1)
38+
39+
# Plot the observed number of events vs. time
40+
ax.plot(window_mid_time, event_counts, "bo", label="Observed catalog")
41+
42+
# Plot the forecasted mean rate and its error (2 * standard_deviation)
43+
ax.errorbar(
44+
window_mid_time,
45+
rate_mean,
46+
yerr=rate_2std,
47+
fmt="o",
48+
label="PyMock forecast",
49+
color="red",
50+
)
51+
52+
# Format and save figure
53+
ax.set_xticks([tw[0] for tw in timewindows] + [timewindows[-1][1]])
54+
fig.autofmt_xdate()
55+
ax.set_xlabel("Time")
56+
ax.set_ylabel(r"Number of events $M\geq 3.5$")
57+
pyplot.legend()
58+
pyplot.grid()
59+
pyplot.savefig("results/forecast_events_rates.png")

floatcsep/cmd/main.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from floatcsep.experiment import Experiment
66
from floatcsep.logger import setup_logger, set_console_log_level
77
from floatcsep.utils import ExperimentComparison
8+
from floatcsep.postprocess import plot_results, plot_forecasts, plot_catalogs, plot_custom
89

910
setup_logger()
1011
log = logging.getLogger("floatLogger")
@@ -13,7 +14,7 @@
1314
def stage(config, **_):
1415

1516
log.info(f"floatCSEP v{__version__} | Stage")
16-
exp = Experiment.from_yml(config)
17+
exp = Experiment.from_yml(config_yml=config)
1718
exp.stage_models()
1819

1920
log.info("Finalized")
@@ -23,12 +24,16 @@ def stage(config, **_):
2324
def run(config, **kwargs):
2425

2526
log.info(f"floatCSEP v{__version__} | Run")
26-
exp = Experiment.from_yml(config, **kwargs)
27+
exp = Experiment.from_yml(config_yml=config, **kwargs)
2728
exp.stage_models()
2829
exp.set_tasks()
2930
exp.run()
30-
exp.plot_results()
31-
exp.plot_forecasts()
31+
32+
plot_catalogs(experiment=exp)
33+
plot_forecasts(experiment=exp)
34+
plot_results(experiment=exp)
35+
plot_custom(experiment=exp)
36+
3237
exp.generate_report()
3338
exp.make_repr()
3439

@@ -40,14 +45,17 @@ def plot(config, **kwargs):
4045

4146
log.info(f"floatCSEP v{__version__} | Plot")
4247

43-
exp = Experiment.from_yml(config, **kwargs)
48+
exp = Experiment.from_yml(config_yml=config, **kwargs)
4449
exp.stage_models()
4550
exp.set_tasks()
46-
exp.plot_results()
47-
exp.plot_forecasts()
51+
52+
plot_catalogs(experiment=exp)
53+
plot_forecasts(experiment=exp)
54+
plot_results(experiment=exp)
55+
plot_custom(experiment=exp)
56+
4857
exp.generate_report()
4958

50-
log.info("Finalized\n")
5159
log.debug("")
5260

5361

floatcsep/experiment.py

Lines changed: 3 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77

88
import numpy
99
import yaml
10-
from cartopy import crs as ccrs
11-
from csep.core.catalogs import CSEPCatalog
12-
from csep.utils.time_utils import decimal_year
13-
from matplotlib import pyplot
1410

1511
from floatcsep import report
1612
from floatcsep.evaluation import Evaluation
@@ -25,7 +21,6 @@
2521
Task,
2622
TaskGraph,
2723
timewindow2str,
28-
magnitude_vs_time,
2924
parse_nested_dicts,
3025
)
3126

@@ -81,7 +76,7 @@ class Experiment:
8176
test_config (str): Path to the evaluations' configuration file
8277
default_test_kwargs (dict): Default values for the testing
8378
(seed, number of simulations, etc.)
84-
postproc_config (dict): Contains the instruction for postprocessing
79+
postprocess (dict): Contains the instruction for postprocessing
8580
(e.g. plot forecasts, catalogs)
8681
**kwargs: see Note
8782
@@ -116,7 +111,7 @@ def __init__(
116111
catalog: str = None,
117112
models: str = None,
118113
tests: str = None,
119-
postproc_config: str = None,
114+
postprocess: str = None,
120115
default_test_kwargs: dict = None,
121116
rundir: str = "results",
122117
report_hook: dict = None,
@@ -174,7 +169,7 @@ def __init__(
174169
self.models = []
175170
self.tests = []
176171

177-
self.postproc_config = postproc_config if postproc_config else {}
172+
self.postprocess = postprocess if postprocess else {}
178173
self.default_test_kwargs = default_test_kwargs
179174

180175
self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config)
@@ -564,141 +559,6 @@ def read_results(self, test: Evaluation, window: str) -> List:
564559

565560
return test.read_results(window, self.models)
566561

567-
def plot_results(self) -> None:
568-
"""Plots all evaluation results."""
569-
log.info("Plotting evaluations")
570-
timewindows = timewindow2str(self.timewindows)
571-
572-
for test in self.tests:
573-
test.plot_results(timewindows, self.models, self.registry)
574-
575-
def plot_catalog(self, dpi: int = 300, show: bool = False) -> None:
576-
"""
577-
Plots the evaluation catalogs.
578-
579-
Args:
580-
dpi: Figure resolution with which to save
581-
show: show in runtime
582-
"""
583-
plot_args = {
584-
"basemap": "ESRI_terrain",
585-
"figsize": (12, 8),
586-
"markersize": 8,
587-
"markercolor": "black",
588-
"grid_fontsize": 16,
589-
"title": "",
590-
"legend": True,
591-
}
592-
plot_args.update(self.postproc_config.get("plot_catalog", {}))
593-
catalog = self.catalog_repo.get_test_cat()
594-
if catalog.get_number_of_events() != 0:
595-
ax = catalog.plot(plot_args=plot_args, show=show)
596-
ax.get_figure().tight_layout()
597-
ax.get_figure().savefig(self.registry.get_figure("main_catalog_map"), dpi=dpi)
598-
599-
ax2 = magnitude_vs_time(catalog)
600-
ax2.get_figure().tight_layout()
601-
ax2.get_figure().savefig(self.registry.get_figure("main_catalog_time"), dpi=dpi)
602-
603-
if self.postproc_config.get("all_time_windows"):
604-
timewindow = self.timewindows
605-
606-
for tw in timewindow:
607-
catpath = self.registry.get_test_catalog(tw)
608-
catalog = CSEPCatalog.load_json(catpath)
609-
if catalog.get_number_of_events() != 0:
610-
ax = catalog.plot(plot_args=plot_args, show=show)
611-
ax.get_figure().tight_layout()
612-
ax.get_figure().savefig(
613-
self.registry.get_figure(tw, "catalog_map"), dpi=dpi
614-
)
615-
616-
ax2 = magnitude_vs_time(catalog)
617-
ax2.get_figure().tight_layout()
618-
ax2.get_figure().savefig(
619-
self.registry.get_figure(tw, "catalog_time"), dpi=dpi
620-
)
621-
622-
def plot_forecasts(self) -> None:
623-
"""Plots and saves all the generated forecasts."""
624-
625-
plot_fc_config = self.postproc_config.get("plot_forecasts")
626-
if plot_fc_config:
627-
log.info("Plotting forecasts")
628-
if plot_fc_config is True:
629-
plot_fc_config = {}
630-
try:
631-
proj_ = plot_fc_config.get("projection")
632-
if isinstance(proj_, dict):
633-
proj_name = list(proj_.keys())[0]
634-
proj_args = list(proj_.values())[0]
635-
else:
636-
proj_name = proj_
637-
proj_args = {}
638-
plot_fc_config["projection"] = getattr(ccrs, proj_name)(**proj_args)
639-
except (IndexError, KeyError, TypeError, AttributeError):
640-
plot_fc_config["projection"] = ccrs.PlateCarree(central_longitude=0.0)
641-
642-
cat = plot_fc_config.get("catalog")
643-
cat_args = {}
644-
if cat:
645-
cat_args = {
646-
"markersize": 7,
647-
"markercolor": "black",
648-
"title": "asd",
649-
"grid": False,
650-
"legend": False,
651-
"basemap": None,
652-
"region_border": False,
653-
}
654-
if self.region:
655-
self.catalog.filter_spatial(self.region, in_place=True)
656-
if isinstance(cat, dict):
657-
cat_args.update(cat)
658-
659-
window = self.timewindows[-1]
660-
winstr = timewindow2str(window)
661-
662-
for model in self.models:
663-
fig_path = self.registry.get_figure(winstr, "forecasts", model.name)
664-
start = decimal_year(window[0])
665-
end = decimal_year(window[1])
666-
time = f"{round(end - start, 3)} years"
667-
plot_args = {
668-
"region_border": False,
669-
"cmap": "magma",
670-
"clabel": r"$\log_{10} N\left(M_w \in [{%.2f},"
671-
r"\,{%.2f}]\right)$ per "
672-
r"$0.1^\circ\times 0.1^\circ $ per %s"
673-
% (min(self.magnitudes), max(self.magnitudes), time),
674-
}
675-
if not self.region or self.region.name == "global":
676-
set_global = True
677-
else:
678-
set_global = False
679-
plot_args.update(plot_fc_config)
680-
ax = model.get_forecast(winstr, self.region).plot(
681-
set_global=set_global, plot_args=plot_args
682-
)
683-
684-
if self.region:
685-
bbox = self.region.get_bbox()
686-
dh = self.region.dh
687-
extent = [
688-
bbox[0] - 3 * dh,
689-
bbox[1] + 3 * dh,
690-
bbox[2] - 3 * dh,
691-
bbox[3] + 3 * dh,
692-
]
693-
else:
694-
extent = None
695-
if cat:
696-
self.catalog.plot(
697-
ax=ax, set_global=set_global, extent=extent, plot_args=cat_args
698-
)
699-
700-
pyplot.savefig(fig_path, dpi=300, facecolor=(0, 0, 0, 0))
701-
702562
def generate_report(self) -> None:
703563
"""Creates a report summarizing the Experiment's results."""
704564

0 commit comments

Comments
 (0)