11import datetime
2- import json
32import os
43from typing import Dict , Callable , Union , Sequence , List
54
6- import numpy
75from csep .core .catalogs import CSEPCatalog
86from csep .core .forecasts import GriddedForecast
9- from csep .models import EvaluationResult
107from matplotlib import pyplot
118
129from floatcsep .model import Model
1310from floatcsep .registry import ExperimentRegistry
14- from floatcsep .utils import parse_csep_func , timewindow2str
11+ from floatcsep .utils import parse_csep_func
1512
1613
1714class Evaluation :
@@ -76,6 +73,9 @@ def __init__(
7673 self .markdown = markdown
7774 self .type = Evaluation ._TYPES .get (self .func .__name__ )
7875
76+ self .results_repo = None
77+ self .catalog_repo = None
78+
7979 @property
8080 def type (self ):
8181 """
@@ -123,7 +123,6 @@ def parse_plots(self, plot_func, plot_args, plot_kwargs):
123123 def prepare_args (
124124 self ,
125125 timewindow : Union [str , list ],
126- catpath : Union [str , list ],
127126 model : Union [Model , Sequence [Model ]],
128127 ref_model : Union [Model , Sequence ] = None ,
129128 region = None ,
@@ -153,7 +152,7 @@ def prepare_args(
153152 # Prepare argument tuple
154153
155154 forecast = model .get_forecast (timewindow , region )
156- catalog = self .get_catalog (catpath , forecast )
155+ catalog = self .get_catalog (timewindow , forecast )
157156
158157 if isinstance (ref_model , Model ):
159158 # Args: (Fc, RFc, Cat)
@@ -169,29 +168,32 @@ def prepare_args(
169168
170169 return test_args
171170
172- @staticmethod
173171 def get_catalog (
174- catalog_path : Union [str , Sequence [str ]],
172+ self ,
173+ timewindow : Union [str , Sequence [str ]],
175174 forecast : Union [GriddedForecast , Sequence [GriddedForecast ]],
176175 ) -> Union [CSEPCatalog , List [CSEPCatalog ]]:
177176 """
178177 Reads the catalog(s) from the given path(s). References the catalog region to the
179178 forecast region.
180179
181180 Args:
182- catalog_path (str, list(str)): Path to the existing catalog
181+ timewindow (str): Time window of the testing catalog
183182 forecast (:class:`~csep.core.forecasts.GriddedForecast`): Forecast
184183 object, onto which the catalog will be confronted for testing.
185184
186185 Returns:
187186 """
188- if isinstance (catalog_path , str ):
189- eval_cat = CSEPCatalog .load_json (catalog_path )
187+
188+ if isinstance (timewindow , str ):
189+ # eval_cat = CSEPCatalog.load_json(catalog_path)
190+ eval_cat = self .catalog_repo .get_test_cat (timewindow )
190191 eval_cat .region = getattr (forecast , "region" )
192+
191193 else :
192- eval_cat = [CSEPCatalog . load_json (i ) for i in catalog_path ]
194+ eval_cat = [self . catalog_repo . get_test_cat (i ) for i in timewindow ]
193195 if (len (forecast ) != len (eval_cat )) or (not isinstance (forecast , Sequence )):
194- raise IndexError ("Amount of passed catalogs and forecats must " "be the same" )
196+ raise IndexError ("Amount of passed catalogs and forecasts must " "be the same" )
195197 for cat , fc in zip (eval_cat , forecast ):
196198 cat .region = getattr (fc , "region" , None )
197199
@@ -202,7 +204,6 @@ def compute(
202204 timewindow : Union [str , list ],
203205 catalog : str ,
204206 model : Model ,
205- path : str ,
206207 ref_model : Union [Model , Sequence [Model ]] = None ,
207208 region = None ,
208209 ) -> None :
@@ -216,65 +217,38 @@ def compute(
216217 catalog (str): Path to the filtered catalog
217218 model (Model, list[Model]): Model(s) to be evaluated
218219 ref_model: Model to be used as reference
219- path: Path to store the Evaluation result
220220 region: region to filter a catalog forecast.
221221
222222 Returns:
223223 """
224224 test_args = self .prepare_args (
225- timewindow , catpath = catalog , model = model , ref_model = ref_model , region = region
225+ timewindow , model = model , ref_model = ref_model , region = region
226226 )
227227
228228 evaluation_result = self .func (* test_args , ** self .func_kwargs )
229- self .write_result (evaluation_result , path )
230-
231- @staticmethod
232- def write_result (result : EvaluationResult , path : str ) -> None :
233- """Dumps a test result into a json file."""
234229
235- class NumpyEncoder (json .JSONEncoder ):
236- def default (self , obj ):
237- if isinstance (obj , numpy .integer ):
238- return int (obj )
239- if isinstance (obj , numpy .floating ):
240- return float (obj )
241- if isinstance (obj , numpy .ndarray ):
242- return obj .tolist ()
243- return json .JSONEncoder .default (self , obj )
244-
245- with open (path , "w" ) as _file :
246- json .dump (result .to_dict (), _file , indent = 4 , cls = NumpyEncoder )
230+ if self .type in ["sequential" , "sequential_comparative" ]:
231+ self .results_repo .write_result (evaluation_result , self , model , timewindow [- 1 ])
232+ else :
233+ self .results_repo .write_result (evaluation_result , self , model , timewindow )
247234
248235 def read_results (
249- self ,
250- window : Union [str , Sequence [datetime .datetime ]],
251- models : List [Model ],
252- tree : ExperimentRegistry ,
236+ self , window : Union [str , Sequence [datetime .datetime ]], models : List [Model ]
253237 ) -> List :
254238 """
255239 Reads an Evaluation result for a given time window and returns a list of the results for
256240 all tested models.
257241 """
258- test_results = []
259-
260- if not isinstance (window , str ):
261- wstr_ = timewindow2str (window )
262- else :
263- wstr_ = window
264242
265- for i in models :
266- eval_path = tree (wstr_ , "evaluations" , self , i .name )
267- with open (eval_path , "r" ) as file_ :
268- model_eval = EvaluationResult .from_dict (json .load (file_ ))
269- test_results .append (model_eval )
243+ test_results = self .results_repo .load_results (self , window , models )
270244
271245 return test_results
272246
273247 def plot_results (
274248 self ,
275249 timewindow : Union [str , List ],
276250 models : List [Model ],
277- tree : ExperimentRegistry ,
251+ registry : ExperimentRegistry ,
278252 dpi : int = 300 ,
279253 show : bool = False ,
280254 ) -> None :
@@ -284,7 +258,7 @@ def plot_results(
284258 Args:
285259 timewindow: string representing the desired timewindow to plot
286260 models: a list of :class:`floatcsep:models.Model`
287- tree : a :class:`floatcsep:models.PathTree` containing path of the results
261+ registry : a :class:`floatcsep:models.PathTree` containing path of the results
288262 dpi: Figure resolution with which to save
289263 show: show in runtime
290264 """
@@ -296,8 +270,8 @@ def plot_results(
296270
297271 try :
298272 for time_str in timewindow :
299- fig_path = tree (time_str , "figures" , self .name )
300- results = self .read_results (time_str , models , tree )
273+ fig_path = registry . get (time_str , "figures" , self .name )
274+ results = self .read_results (time_str , models )
301275 ax = func (results , plot_args = fargs , ** fkwargs )
302276 if "code" in fargs :
303277 exec (fargs ["code" ])
@@ -308,14 +282,14 @@ def plot_results(
308282 except AttributeError as msg :
309283 if self .type in ["consistency" , "comparative" ]:
310284 for time_str in timewindow :
311- results = self .read_results (time_str , models , tree )
285+ results = self .read_results (time_str , models )
312286 for result , model in zip (results , models ):
313287 fig_name = f"{ self .name } _{ model .name } "
314288
315- tree .paths [time_str ]["figures" ][fig_name ] = os .path .join (
289+ registry .paths [time_str ]["figures" ][fig_name ] = os .path .join (
316290 time_str , "figures" , fig_name
317291 )
318- fig_path = tree (time_str , "figures" , fig_name )
292+ fig_path = registry . get (time_str , "figures" , fig_name )
319293 ax = func (result , plot_args = fargs , ** fkwargs , show = False )
320294 if "code" in fargs :
321295 exec (fargs ["code" ])
@@ -324,8 +298,8 @@ def plot_results(
324298 pyplot .show ()
325299
326300 elif self .type in ["sequential" , "sequential_comparative" , "batch" ]:
327- fig_path = tree (timewindow [- 1 ], "figures" , self .name )
328- results = self .read_results (timewindow [- 1 ], models , tree )
301+ fig_path = registry . get (timewindow [- 1 ], "figures" , self .name )
302+ results = self .read_results (timewindow [- 1 ], models )
329303 ax = func (results , plot_args = fargs , ** fkwargs )
330304
331305 if "code" in fargs :
0 commit comments