2 changes: 1 addition & 1 deletion docs/examples/case_e.rst
@@ -86,7 +86,7 @@ Models
Post-Process
~~~~~~~~~~~~

Additional options for post-processing can set using the ``postproc_config`` option.
Additional options for post-processing can be set using the ``postprocess`` option.

.. literalinclude:: ../../examples/case_e/config.yml
   :language: yaml
2 changes: 1 addition & 1 deletion examples/case_e/config.yml
@@ -18,7 +18,7 @@ catalog: query_bsi
models: models.yml
test_config: tests.yml

postproc_config:
postprocess:
  plot_forecasts:
    region_border: True
    basemap: stock_img
5 changes: 4 additions & 1 deletion examples/case_g/config.yml
@@ -17,4 +17,7 @@ region_config:
force_rerun: True
catalog: catalog.csv
model_config: models.yml
test_config: tests.yml
test_config: tests.yml

postprocess:
  plot_custom: plot_script.py:main
59 changes: 59 additions & 0 deletions examples/case_g/plot_script.py
@@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import numpy


def main(experiment):
    """
    Example custom plot function
    """

    # Get all the time windows
    timewindows = experiment.timewindows

    # Get the pymock model
    model = experiment.get_model("pymock")

    # Initialize the data lists to plot
    window_mid_time = []
    event_counts = []
    rate_mean = []
    rate_2std = []

    for timewindow in timewindows:

        # Load the "Catalog_N-test" result for this time window and model
        n_test_result = experiment.results_repo.load_results(
            "Catalog_N-test", timewindow, model
        )

        # Append the results
        window_mid_time.append(timewindow[0] + (timewindow[1] - timewindow[0]) / 2)
        event_counts.append(n_test_result.observed_statistic)
        rate_mean.append(numpy.mean(n_test_result.test_distribution))
        rate_2std.append(2 * numpy.std(n_test_result.test_distribution))

    # Create the figure
    fig, ax = plt.subplots(1, 1)

    # Plot the observed number of events vs. time
    ax.plot(window_mid_time, event_counts, "bo", label="Observed catalog")

    # Plot the forecasted mean rate and its error (2 * standard_deviation)
    ax.errorbar(
        window_mid_time,
        rate_mean,
        yerr=rate_2std,
        fmt="o",
        label="PyMock forecast",
        color="red",
    )

    # Format and save the figure
    ax.set_xticks([tw[0] for tw in timewindows] + [timewindows[-1][1]])
    fig.autofmt_xdate()
    ax.set_xlabel("Time")
    ax.set_ylabel(r"Number of events $M\geq 3.5$")
    ax.legend()
    ax.grid()
    fig.savefig("results/forecast_events_rates.png")
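
Any function with the signature shown above can be wired in through ``plot_custom``, since floatCSEP passes the staged experiment to the hook. As a minimal sketch of the mechanics, assuming the ``script.py:function`` string is split on the colon and the module is loaded from the file path (the actual resolution lives in ``floatcsep.postprocess`` and may differ; ``load_plot_hook`` is a hypothetical name):

import importlib.util


def load_plot_hook(spec):
    # Hypothetical helper: split "plot_script.py:main" into a file path and a
    # function name, then import the function from that file.
    path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location("custom_plots", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)


# hook = load_plot_hook("plot_script.py:main")
# hook(experiment)  # receives the Experiment instance, like main() above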
24 changes: 16 additions & 8 deletions floatcsep/cmd/main.py
@@ -5,6 +5,7 @@
from floatcsep.experiment import Experiment
from floatcsep.logger import setup_logger, set_console_log_level
from floatcsep.utils import ExperimentComparison
from floatcsep.postprocess import plot_results, plot_forecasts, plot_catalogs, plot_custom

setup_logger()
log = logging.getLogger("floatLogger")
@@ -13,7 +14,7 @@
def stage(config, **_):

    log.info(f"floatCSEP v{__version__} | Stage")
    exp = Experiment.from_yml(config)
    exp = Experiment.from_yml(config_yml=config)
    exp.stage_models()

    log.info("Finalized")
@@ -23,12 +24,16 @@ def stage(config, **_):
def run(config, **kwargs):

    log.info(f"floatCSEP v{__version__} | Run")
    exp = Experiment.from_yml(config, **kwargs)
    exp = Experiment.from_yml(config_yml=config, **kwargs)
    exp.stage_models()
    exp.set_tasks()
    exp.run()
    exp.plot_results()
    exp.plot_forecasts()

    plot_catalogs(experiment=exp)
    plot_forecasts(experiment=exp)
    plot_results(experiment=exp)
    plot_custom(experiment=exp)

    exp.generate_report()
    exp.make_repr()

@@ -40,14 +45,17 @@ def plot(config, **kwargs):

    log.info(f"floatCSEP v{__version__} | Plot")

    exp = Experiment.from_yml(config, **kwargs)
    exp = Experiment.from_yml(config_yml=config, **kwargs)
    exp.stage_models()
    exp.set_tasks()
    exp.plot_results()
    exp.plot_forecasts()

    plot_catalogs(experiment=exp)
    plot_forecasts(experiment=exp)
    plot_results(experiment=exp)
    plot_custom(experiment=exp)

    exp.generate_report()

    log.info("Finalized\n")
    log.debug("")


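The net effect on the entry points: post-processing is no longer a set of ``Experiment`` methods but module-level functions that each receive the experiment. A minimal driver sketch using only the calls added in this pull request (``config.yml`` stands in for any experiment configuration):

from floatcsep.experiment import Experiment
from floatcsep.postprocess import plot_catalogs, plot_forecasts, plot_results, plot_custom

exp = Experiment.from_yml(config_yml="config.yml")
exp.stage_models()
exp.set_tasks()
exp.run()

# Each helper reads the experiment's `postprocess` configuration block.
plot_catalogs(experiment=exp)
plot_forecasts(experiment=exp)
plot_results(experiment=exp)
plot_custom(experiment=exp)

exp.generate_report()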
146 changes: 3 additions & 143 deletions floatcsep/experiment.py
@@ -7,10 +7,6 @@

import numpy
import yaml
from cartopy import crs as ccrs
from csep.core.catalogs import CSEPCatalog
from csep.utils.time_utils import decimal_year
from matplotlib import pyplot

from floatcsep import report
from floatcsep.evaluation import Evaluation
@@ -25,7 +21,6 @@
    Task,
    TaskGraph,
    timewindow2str,
    magnitude_vs_time,
    parse_nested_dicts,
)

@@ -81,7 +76,7 @@ class Experiment:
        test_config (str): Path to the evaluations' configuration file
        default_test_kwargs (dict): Default values for the testing
            (seed, number of simulations, etc.)
        postproc_config (dict): Contains the instruction for postprocessing
        postprocess (dict): Contains the instructions for postprocessing
            (e.g. plot forecasts, catalogs)
        **kwargs: see Note

@@ -116,7 +111,7 @@ def __init__(
        catalog: str = None,
        models: str = None,
        tests: str = None,
        postproc_config: str = None,
        postprocess: dict = None,
        default_test_kwargs: dict = None,
        rundir: str = "results",
        report_hook: dict = None,
@@ -174,7 +169,7 @@ def __init__(
        self.models = []
        self.tests = []

        self.postproc_config = postproc_config if postproc_config else {}
        self.postprocess = postprocess if postprocess else {}
        self.default_test_kwargs = default_test_kwargs

        self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config)
@@ -564,141 +559,6 @@ def read_results(self, test: Evaluation, window: str) -> List:

        return test.read_results(window, self.models)

    def plot_results(self) -> None:
        """Plots all evaluation results."""
        log.info("Plotting evaluations")
        timewindows = timewindow2str(self.timewindows)

        for test in self.tests:
            test.plot_results(timewindows, self.models, self.registry)

    def plot_catalog(self, dpi: int = 300, show: bool = False) -> None:
        """
        Plots the evaluation catalogs.

        Args:
            dpi: Figure resolution with which to save
            show: show in runtime
        """
        plot_args = {
            "basemap": "ESRI_terrain",
            "figsize": (12, 8),
            "markersize": 8,
            "markercolor": "black",
            "grid_fontsize": 16,
            "title": "",
            "legend": True,
        }
        plot_args.update(self.postproc_config.get("plot_catalog", {}))
        catalog = self.catalog_repo.get_test_cat()
        if catalog.get_number_of_events() != 0:
            ax = catalog.plot(plot_args=plot_args, show=show)
            ax.get_figure().tight_layout()
            ax.get_figure().savefig(self.registry.get_figure("main_catalog_map"), dpi=dpi)

            ax2 = magnitude_vs_time(catalog)
            ax2.get_figure().tight_layout()
            ax2.get_figure().savefig(self.registry.get_figure("main_catalog_time"), dpi=dpi)

        if self.postproc_config.get("all_time_windows"):
            timewindow = self.timewindows

            for tw in timewindow:
                catpath = self.registry.get_test_catalog(tw)
                catalog = CSEPCatalog.load_json(catpath)
                if catalog.get_number_of_events() != 0:
                    ax = catalog.plot(plot_args=plot_args, show=show)
                    ax.get_figure().tight_layout()
                    ax.get_figure().savefig(
                        self.registry.get_figure(tw, "catalog_map"), dpi=dpi
                    )

                    ax2 = magnitude_vs_time(catalog)
                    ax2.get_figure().tight_layout()
                    ax2.get_figure().savefig(
                        self.registry.get_figure(tw, "catalog_time"), dpi=dpi
                    )

    def plot_forecasts(self) -> None:
        """Plots and saves all the generated forecasts."""

        plot_fc_config = self.postproc_config.get("plot_forecasts")
        if plot_fc_config:
            log.info("Plotting forecasts")
            if plot_fc_config is True:
                plot_fc_config = {}
            try:
                proj_ = plot_fc_config.get("projection")
                if isinstance(proj_, dict):
                    proj_name = list(proj_.keys())[0]
                    proj_args = list(proj_.values())[0]
                else:
                    proj_name = proj_
                    proj_args = {}
                plot_fc_config["projection"] = getattr(ccrs, proj_name)(**proj_args)
            except (IndexError, KeyError, TypeError, AttributeError):
                plot_fc_config["projection"] = ccrs.PlateCarree(central_longitude=0.0)

            cat = plot_fc_config.get("catalog")
            cat_args = {}
            if cat:
                cat_args = {
                    "markersize": 7,
                    "markercolor": "black",
                    "title": "asd",
                    "grid": False,
                    "legend": False,
                    "basemap": None,
                    "region_border": False,
                }
                if self.region:
                    self.catalog.filter_spatial(self.region, in_place=True)
                if isinstance(cat, dict):
                    cat_args.update(cat)

            window = self.timewindows[-1]
            winstr = timewindow2str(window)

            for model in self.models:
                fig_path = self.registry.get_figure(winstr, "forecasts", model.name)
                start = decimal_year(window[0])
                end = decimal_year(window[1])
                time = f"{round(end - start, 3)} years"
                plot_args = {
                    "region_border": False,
                    "cmap": "magma",
                    "clabel": r"$\log_{10} N\left(M_w \in [{%.2f},"
                    r"\,{%.2f}]\right)$ per "
                    r"$0.1^\circ\times 0.1^\circ $ per %s"
                    % (min(self.magnitudes), max(self.magnitudes), time),
                }
                if not self.region or self.region.name == "global":
                    set_global = True
                else:
                    set_global = False
                plot_args.update(plot_fc_config)
                ax = model.get_forecast(winstr, self.region).plot(
                    set_global=set_global, plot_args=plot_args
                )

                if self.region:
                    bbox = self.region.get_bbox()
                    dh = self.region.dh
                    extent = [
                        bbox[0] - 3 * dh,
                        bbox[1] + 3 * dh,
                        bbox[2] - 3 * dh,
                        bbox[3] + 3 * dh,
                    ]
                else:
                    extent = None
                if cat:
                    self.catalog.plot(
                        ax=ax, set_global=set_global, extent=extent, plot_args=cat_args
                    )

                pyplot.savefig(fig_path, dpi=300, facecolor=(0, 0, 0, 0))

    def generate_report(self) -> None:
        """Creates a report summarizing the Experiment's results."""
