|
7 | 7 |
|
8 | 8 | import numpy |
9 | 9 | import yaml |
10 | | -from cartopy import crs as ccrs |
11 | | -from csep.core.catalogs import CSEPCatalog |
12 | | -from csep.utils.time_utils import decimal_year |
13 | | -from matplotlib import pyplot |
14 | 10 |
|
15 | 11 | from floatcsep import report |
16 | 12 | from floatcsep.evaluation import Evaluation |
|
25 | 21 | Task, |
26 | 22 | TaskGraph, |
27 | 23 | timewindow2str, |
28 | | - magnitude_vs_time, |
29 | 24 | parse_nested_dicts, |
30 | 25 | ) |
31 | 26 |
|
@@ -81,7 +76,7 @@ class Experiment: |
81 | 76 | test_config (str): Path to the evaluations' configuration file |
82 | 77 | default_test_kwargs (dict): Default values for the testing |
83 | 78 | (seed, number of simulations, etc.) |
84 | | - postproc_config (dict): Contains the instruction for postprocessing |
| 79 | + postprocess (dict): Contains the instruction for postprocessing |
85 | 80 | (e.g. plot forecasts, catalogs) |
86 | 81 | **kwargs: see Note |
87 | 82 |
|
@@ -116,7 +111,7 @@ def __init__( |
116 | 111 | catalog: str = None, |
117 | 112 | models: str = None, |
118 | 113 | tests: str = None, |
119 | | - postproc_config: str = None, |
| 114 | + postprocess: str = None, |
120 | 115 | default_test_kwargs: dict = None, |
121 | 116 | rundir: str = "results", |
122 | 117 | report_hook: dict = None, |
@@ -174,7 +169,7 @@ def __init__( |
174 | 169 | self.models = [] |
175 | 170 | self.tests = [] |
176 | 171 |
|
177 | | - self.postproc_config = postproc_config if postproc_config else {} |
| 172 | + self.postprocess = postprocess if postprocess else {} |
178 | 173 | self.default_test_kwargs = default_test_kwargs |
179 | 174 |
|
180 | 175 | self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config) |
@@ -564,141 +559,6 @@ def read_results(self, test: Evaluation, window: str) -> List: |
564 | 559 |
|
565 | 560 | return test.read_results(window, self.models) |
566 | 561 |
|
567 | | - def plot_results(self) -> None: |
568 | | - """Plots all evaluation results.""" |
569 | | - log.info("Plotting evaluations") |
570 | | - timewindows = timewindow2str(self.timewindows) |
571 | | - |
572 | | - for test in self.tests: |
573 | | - test.plot_results(timewindows, self.models, self.registry) |
574 | | - |
575 | | - def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: |
576 | | - """ |
577 | | - Plots the evaluation catalogs. |
578 | | -
|
579 | | - Args: |
580 | | - dpi: Figure resolution with which to save |
581 | | - show: show in runtime |
582 | | - """ |
583 | | - plot_args = { |
584 | | - "basemap": "ESRI_terrain", |
585 | | - "figsize": (12, 8), |
586 | | - "markersize": 8, |
587 | | - "markercolor": "black", |
588 | | - "grid_fontsize": 16, |
589 | | - "title": "", |
590 | | - "legend": True, |
591 | | - } |
592 | | - plot_args.update(self.postproc_config.get("plot_catalog", {})) |
593 | | - catalog = self.catalog_repo.get_test_cat() |
594 | | - if catalog.get_number_of_events() != 0: |
595 | | - ax = catalog.plot(plot_args=plot_args, show=show) |
596 | | - ax.get_figure().tight_layout() |
597 | | - ax.get_figure().savefig(self.registry.get_figure("main_catalog_map"), dpi=dpi) |
598 | | - |
599 | | - ax2 = magnitude_vs_time(catalog) |
600 | | - ax2.get_figure().tight_layout() |
601 | | - ax2.get_figure().savefig(self.registry.get_figure("main_catalog_time"), dpi=dpi) |
602 | | - |
603 | | - if self.postproc_config.get("all_time_windows"): |
604 | | - timewindow = self.timewindows |
605 | | - |
606 | | - for tw in timewindow: |
607 | | - catpath = self.registry.get_test_catalog(tw) |
608 | | - catalog = CSEPCatalog.load_json(catpath) |
609 | | - if catalog.get_number_of_events() != 0: |
610 | | - ax = catalog.plot(plot_args=plot_args, show=show) |
611 | | - ax.get_figure().tight_layout() |
612 | | - ax.get_figure().savefig( |
613 | | - self.registry.get_figure(tw, "catalog_map"), dpi=dpi |
614 | | - ) |
615 | | - |
616 | | - ax2 = magnitude_vs_time(catalog) |
617 | | - ax2.get_figure().tight_layout() |
618 | | - ax2.get_figure().savefig( |
619 | | - self.registry.get_figure(tw, "catalog_time"), dpi=dpi |
620 | | - ) |
621 | | - |
622 | | - def plot_forecasts(self) -> None: |
623 | | - """Plots and saves all the generated forecasts.""" |
624 | | - |
625 | | - plot_fc_config = self.postproc_config.get("plot_forecasts") |
626 | | - if plot_fc_config: |
627 | | - log.info("Plotting forecasts") |
628 | | - if plot_fc_config is True: |
629 | | - plot_fc_config = {} |
630 | | - try: |
631 | | - proj_ = plot_fc_config.get("projection") |
632 | | - if isinstance(proj_, dict): |
633 | | - proj_name = list(proj_.keys())[0] |
634 | | - proj_args = list(proj_.values())[0] |
635 | | - else: |
636 | | - proj_name = proj_ |
637 | | - proj_args = {} |
638 | | - plot_fc_config["projection"] = getattr(ccrs, proj_name)(**proj_args) |
639 | | - except (IndexError, KeyError, TypeError, AttributeError): |
640 | | - plot_fc_config["projection"] = ccrs.PlateCarree(central_longitude=0.0) |
641 | | - |
642 | | - cat = plot_fc_config.get("catalog") |
643 | | - cat_args = {} |
644 | | - if cat: |
645 | | - cat_args = { |
646 | | - "markersize": 7, |
647 | | - "markercolor": "black", |
648 | | - "title": "asd", |
649 | | - "grid": False, |
650 | | - "legend": False, |
651 | | - "basemap": None, |
652 | | - "region_border": False, |
653 | | - } |
654 | | - if self.region: |
655 | | - self.catalog.filter_spatial(self.region, in_place=True) |
656 | | - if isinstance(cat, dict): |
657 | | - cat_args.update(cat) |
658 | | - |
659 | | - window = self.timewindows[-1] |
660 | | - winstr = timewindow2str(window) |
661 | | - |
662 | | - for model in self.models: |
663 | | - fig_path = self.registry.get_figure(winstr, "forecasts", model.name) |
664 | | - start = decimal_year(window[0]) |
665 | | - end = decimal_year(window[1]) |
666 | | - time = f"{round(end - start, 3)} years" |
667 | | - plot_args = { |
668 | | - "region_border": False, |
669 | | - "cmap": "magma", |
670 | | - "clabel": r"$\log_{10} N\left(M_w \in [{%.2f}," |
671 | | - r"\,{%.2f}]\right)$ per " |
672 | | - r"$0.1^\circ\times 0.1^\circ $ per %s" |
673 | | - % (min(self.magnitudes), max(self.magnitudes), time), |
674 | | - } |
675 | | - if not self.region or self.region.name == "global": |
676 | | - set_global = True |
677 | | - else: |
678 | | - set_global = False |
679 | | - plot_args.update(plot_fc_config) |
680 | | - ax = model.get_forecast(winstr, self.region).plot( |
681 | | - set_global=set_global, plot_args=plot_args |
682 | | - ) |
683 | | - |
684 | | - if self.region: |
685 | | - bbox = self.region.get_bbox() |
686 | | - dh = self.region.dh |
687 | | - extent = [ |
688 | | - bbox[0] - 3 * dh, |
689 | | - bbox[1] + 3 * dh, |
690 | | - bbox[2] - 3 * dh, |
691 | | - bbox[3] + 3 * dh, |
692 | | - ] |
693 | | - else: |
694 | | - extent = None |
695 | | - if cat: |
696 | | - self.catalog.plot( |
697 | | - ax=ax, set_global=set_global, extent=extent, plot_args=cat_args |
698 | | - ) |
699 | | - |
700 | | - pyplot.savefig(fig_path, dpi=300, facecolor=(0, 0, 0, 0)) |
701 | | - |
702 | 562 | def generate_report(self) -> None: |
703 | 563 | """Creates a report summarizing the Experiment's results.""" |
704 | 564 |
|
|
0 commit comments