Skip to content

Commit f352add

Browse files
authored
Merge pull request #16 from cseptesting/9-decouple-experiment-from-file-management
9 decouple experiment from file management
2 parents 334c461 + a194c47 commit f352add

File tree

14 files changed

+492
-314
lines changed

14 files changed

+492
-314
lines changed

floatcsep/cmd/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def reproduce(config, **kwargs):
5858
reproduced_exp.run()
5959

6060
original_config = reproduced_exp.original_config
61-
original_exp = Experiment.from_yml(original_config, rundir=reproduced_exp.original_rundir)
61+
original_exp = Experiment.from_yml(original_config, rundir=reproduced_exp.original_run_dir)
6262
original_exp.stage_models()
6363
original_exp.set_tasks()
6464

floatcsep/evaluation.py

Lines changed: 31 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
import datetime
2-
import json
32
import os
43
from typing import Dict, Callable, Union, Sequence, List
54

6-
import numpy
75
from csep.core.catalogs import CSEPCatalog
86
from csep.core.forecasts import GriddedForecast
9-
from csep.models import EvaluationResult
107
from matplotlib import pyplot
118

129
from floatcsep.model import Model
1310
from floatcsep.registry import ExperimentRegistry
14-
from floatcsep.utils import parse_csep_func, timewindow2str
11+
from floatcsep.utils import parse_csep_func
1512

1613

1714
class Evaluation:
@@ -76,6 +73,9 @@ def __init__(
7673
self.markdown = markdown
7774
self.type = Evaluation._TYPES.get(self.func.__name__)
7875

76+
self.results_repo = None
77+
self.catalog_repo = None
78+
7979
@property
8080
def type(self):
8181
"""
@@ -123,7 +123,6 @@ def parse_plots(self, plot_func, plot_args, plot_kwargs):
123123
def prepare_args(
124124
self,
125125
timewindow: Union[str, list],
126-
catpath: Union[str, list],
127126
model: Union[Model, Sequence[Model]],
128127
ref_model: Union[Model, Sequence] = None,
129128
region=None,
@@ -153,7 +152,7 @@ def prepare_args(
153152
# Prepare argument tuple
154153

155154
forecast = model.get_forecast(timewindow, region)
156-
catalog = self.get_catalog(catpath, forecast)
155+
catalog = self.get_catalog(timewindow, forecast)
157156

158157
if isinstance(ref_model, Model):
159158
# Args: (Fc, RFc, Cat)
@@ -169,29 +168,32 @@ def prepare_args(
169168

170169
return test_args
171170

172-
@staticmethod
173171
def get_catalog(
174-
catalog_path: Union[str, Sequence[str]],
172+
self,
173+
timewindow: Union[str, Sequence[str]],
175174
forecast: Union[GriddedForecast, Sequence[GriddedForecast]],
176175
) -> Union[CSEPCatalog, List[CSEPCatalog]]:
177176
"""
178177
Reads the catalog(s) from the given path(s). References the catalog region to the
179178
forecast region.
180179
181180
Args:
182-
catalog_path (str, list(str)): Path to the existing catalog
181+
timewindow (str): Time window of the testing catalog
183182
forecast (:class:`~csep.core.forecasts.GriddedForecast`): Forecast
184183
object, onto which the catalog will be confronted for testing.
185184
186185
Returns:
187186
"""
188-
if isinstance(catalog_path, str):
189-
eval_cat = CSEPCatalog.load_json(catalog_path)
187+
188+
if isinstance(timewindow, str):
189+
# eval_cat = CSEPCatalog.load_json(catalog_path)
190+
eval_cat = self.catalog_repo.get_test_cat(timewindow)
190191
eval_cat.region = getattr(forecast, "region")
192+
191193
else:
192-
eval_cat = [CSEPCatalog.load_json(i) for i in catalog_path]
194+
eval_cat = [self.catalog_repo.get_test_cat(i) for i in timewindow]
193195
if (len(forecast) != len(eval_cat)) or (not isinstance(forecast, Sequence)):
194-
raise IndexError("Amount of passed catalogs and forecats must " "be the same")
196+
raise IndexError("Amount of passed catalogs and forecasts must " "be the same")
195197
for cat, fc in zip(eval_cat, forecast):
196198
cat.region = getattr(fc, "region", None)
197199

@@ -202,7 +204,6 @@ def compute(
202204
timewindow: Union[str, list],
203205
catalog: str,
204206
model: Model,
205-
path: str,
206207
ref_model: Union[Model, Sequence[Model]] = None,
207208
region=None,
208209
) -> None:
@@ -216,65 +217,38 @@ def compute(
216217
catalog (str): Path to the filtered catalog
217218
model (Model, list[Model]): Model(s) to be evaluated
218219
ref_model: Model to be used as reference
219-
path: Path to store the Evaluation result
220220
region: region to filter a catalog forecast.
221221
222222
Returns:
223223
"""
224224
test_args = self.prepare_args(
225-
timewindow, catpath=catalog, model=model, ref_model=ref_model, region=region
225+
timewindow, model=model, ref_model=ref_model, region=region
226226
)
227227

228228
evaluation_result = self.func(*test_args, **self.func_kwargs)
229-
self.write_result(evaluation_result, path)
230-
231-
@staticmethod
232-
def write_result(result: EvaluationResult, path: str) -> None:
233-
"""Dumps a test result into a json file."""
234229

235-
class NumpyEncoder(json.JSONEncoder):
236-
def default(self, obj):
237-
if isinstance(obj, numpy.integer):
238-
return int(obj)
239-
if isinstance(obj, numpy.floating):
240-
return float(obj)
241-
if isinstance(obj, numpy.ndarray):
242-
return obj.tolist()
243-
return json.JSONEncoder.default(self, obj)
244-
245-
with open(path, "w") as _file:
246-
json.dump(result.to_dict(), _file, indent=4, cls=NumpyEncoder)
230+
if self.type in ["sequential", "sequential_comparative"]:
231+
self.results_repo.write_result(evaluation_result, self, model, timewindow[-1])
232+
else:
233+
self.results_repo.write_result(evaluation_result, self, model, timewindow)
247234

248235
def read_results(
249-
self,
250-
window: Union[str, Sequence[datetime.datetime]],
251-
models: List[Model],
252-
tree: ExperimentRegistry,
236+
self, window: Union[str, Sequence[datetime.datetime]], models: List[Model]
253237
) -> List:
254238
"""
255239
Reads an Evaluation result for a given time window and returns a list of the results for
256240
all tested models.
257241
"""
258-
test_results = []
259-
260-
if not isinstance(window, str):
261-
wstr_ = timewindow2str(window)
262-
else:
263-
wstr_ = window
264242

265-
for i in models:
266-
eval_path = tree(wstr_, "evaluations", self, i.name)
267-
with open(eval_path, "r") as file_:
268-
model_eval = EvaluationResult.from_dict(json.load(file_))
269-
test_results.append(model_eval)
243+
test_results = self.results_repo.load_results(self, window, models)
270244

271245
return test_results
272246

273247
def plot_results(
274248
self,
275249
timewindow: Union[str, List],
276250
models: List[Model],
277-
tree: ExperimentRegistry,
251+
registry: ExperimentRegistry,
278252
dpi: int = 300,
279253
show: bool = False,
280254
) -> None:
@@ -284,7 +258,7 @@ def plot_results(
284258
Args:
285259
timewindow: string representing the desired timewindow to plot
286260
models: a list of :class:`floatcsep:models.Model`
287-
tree: a :class:`floatcsep:models.PathTree` containing path of the results
261+
registry: a :class:`floatcsep:models.PathTree` containing path of the results
288262
dpi: Figure resolution with which to save
289263
show: show in runtime
290264
"""
@@ -296,8 +270,8 @@ def plot_results(
296270

297271
try:
298272
for time_str in timewindow:
299-
fig_path = tree(time_str, "figures", self.name)
300-
results = self.read_results(time_str, models, tree)
273+
fig_path = registry.get(time_str, "figures", self.name)
274+
results = self.read_results(time_str, models)
301275
ax = func(results, plot_args=fargs, **fkwargs)
302276
if "code" in fargs:
303277
exec(fargs["code"])
@@ -308,14 +282,14 @@ def plot_results(
308282
except AttributeError as msg:
309283
if self.type in ["consistency", "comparative"]:
310284
for time_str in timewindow:
311-
results = self.read_results(time_str, models, tree)
285+
results = self.read_results(time_str, models)
312286
for result, model in zip(results, models):
313287
fig_name = f"{self.name}_{model.name}"
314288

315-
tree.paths[time_str]["figures"][fig_name] = os.path.join(
289+
registry.paths[time_str]["figures"][fig_name] = os.path.join(
316290
time_str, "figures", fig_name
317291
)
318-
fig_path = tree(time_str, "figures", fig_name)
292+
fig_path = registry.get(time_str, "figures", fig_name)
319293
ax = func(result, plot_args=fargs, **fkwargs, show=False)
320294
if "code" in fargs:
321295
exec(fargs["code"])
@@ -324,8 +298,8 @@ def plot_results(
324298
pyplot.show()
325299

326300
elif self.type in ["sequential", "sequential_comparative", "batch"]:
327-
fig_path = tree(timewindow[-1], "figures", self.name)
328-
results = self.read_results(timewindow[-1], models, tree)
301+
fig_path = registry.get(timewindow[-1], "figures", self.name)
302+
results = self.read_results(timewindow[-1], models)
329303
ax = func(results, plot_args=fargs, **fkwargs)
330304

331305
if "code" in fargs:

0 commit comments

Comments
 (0)