-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Fix moo things #1501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Fix moo things #1501
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
c07856e
Push
eddiebergman b43ab70
`fit_ensemble` now has priority for kwargs to take
eddiebergman 4609c81
Change ordering of prefernce for ensemble params
eddiebergman c611617
Add TODO note for metrics
eddiebergman 39c1d7e
Add `metrics` arg to `fit_ensemble`
eddiebergman d8a01f1
Add test for pareto front sizes
eddiebergman 90e1482
Remove uneeded file
eddiebergman 9f82169
Re-added tests to `test_pareto_front`
eddiebergman dd284e4
Add descriptions to test files
eddiebergman 9a243c3
Add test to ensure argument priority
eddiebergman 13bcd49
Add test to make sure X_data only loaded when required
eddiebergman 864414d
Remove part of test required for performance history
eddiebergman 55c792b
Default to `self._metrics` if `metrics` not available
eddiebergman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Dict, Iterable, Sequence, Type, cast | ||
from typing import Any, Iterable, Mapping, Sequence, Type, cast | ||
|
||
import logging.handlers | ||
import multiprocessing | ||
|
@@ -46,7 +46,7 @@ def __init__( | |
task_type: int, | ||
metrics: Sequence[Scorer], | ||
ensemble_class: Type[AbstractEnsemble] = EnsembleSelection, | ||
ensemble_kwargs: Dict[str, Any] | None = None, | ||
ensemble_kwargs: Mapping[str, Any] | None = None, | ||
ensemble_nbest: int | float = 50, | ||
max_models_on_disc: int | float | None = 100, | ||
seed: int = 1, | ||
|
@@ -71,9 +71,11 @@ def __init__( | |
metrics: Sequence[Scorer] | ||
Metrics to optimize the ensemble for. These must be non-duplicated. | ||
|
||
ensemble_class | ||
ensemble_class: Type[AbstractEnsemble] | ||
Implementation of the ensemble algorithm. | ||
|
||
ensemble_kwargs | ||
ensemble_kwargs: Mapping[str, Any] | None | ||
Arguments passed to the constructor of the ensemble algorithm. | ||
|
||
ensemble_nbest: int | float = 50 | ||
|
||
|
@@ -169,6 +171,8 @@ def __init__( | |
self.validation_performance_ = np.inf | ||
|
||
# Data we may need | ||
# TODO: The test data is needlessly loaded but automl_common has no concept of | ||
# these and is perhaps too rigid | ||
datamanager: XYDataManager = self.backend.load_datamanager() | ||
self._X_test: SUPPORTED_FEAT_TYPES | None = datamanager.data.get("X_test", None) | ||
self._y_test: np.ndarray | None = datamanager.data.get("Y_test", None) | ||
|
@@ -442,6 +446,17 @@ def main( | |
self.logger.debug("Found no runs") | ||
raise RuntimeError("Found no runs") | ||
|
||
# We load in `X_data` if we need it | ||
if any(m._needs_X for m in self.metrics): | ||
ensemble_X_data = self.X_data("ensemble") | ||
|
||
if ensemble_X_data is None: | ||
msg = "No `X_data` for 'ensemble' which was required by metrics" | ||
self.logger.debug(msg) | ||
raise RuntimeError(msg) | ||
else: | ||
ensemble_X_data = None | ||
|
||
# Calculate the loss for those that require it | ||
requires_update = self.requires_loss_update(runs) | ||
if self.read_at_most is not None: | ||
|
@@ -450,9 +465,7 @@ def main( | |
for run in requires_update: | ||
run.record_modified_times() # So we don't count as modified next time | ||
run.losses = { | ||
metric.name: self.loss( | ||
run, metric=metric, X_data=self.X_data("ensemble") | ||
) | ||
metric.name: self.loss(run, metric=metric, X_data=ensemble_X_data) | ||
for metric in self.metrics | ||
} | ||
|
||
|
@@ -549,15 +562,14 @@ def main( | |
return self.ensemble_history, self.ensemble_nbest | ||
|
||
targets = cast(np.ndarray, self.targets("ensemble")) # Sure they exist | ||
X_data = self.X_data("ensemble") | ||
|
||
ensemble = self.fit_ensemble( | ||
candidates=candidates, | ||
X_data=X_data, | ||
targets=targets, | ||
runs=runs, | ||
ensemble_class=self.ensemble_class, | ||
ensemble_kwargs=self.ensemble_kwargs, | ||
X_data=ensemble_X_data, | ||
task=self.task_type, | ||
metrics=self.metrics, | ||
precision=self.precision, | ||
|
@@ -587,7 +599,15 @@ def main( | |
|
||
run_preds = [r.predictions(kind, precision=self.precision) for r in models] | ||
pred = ensemble.predict(run_preds) | ||
X_data = self.X_data(kind) | ||
|
||
if any(m._needs_X for m in self.metrics): | ||
X_data = self.X_data(kind) | ||
if X_data is None: | ||
msg = f"No `X` data for '{kind}' which was required by metrics" | ||
self.logger.debug(msg) | ||
raise RuntimeError(msg) | ||
else: | ||
X_data = None | ||
|
||
scores = calculate_scores( | ||
solution=pred_targets, | ||
|
@@ -597,10 +617,19 @@ def main( | |
X_data=X_data, | ||
scoring_functions=None, | ||
) | ||
|
||
# TODO only one metric in history | ||
# | ||
# We should probably return for all metrics but this makes | ||
# automl::performance_history a lot more complicated, will | ||
# tackle in a future PR | ||
first_metric = self.metrics[0] | ||
performance_stamp[f"ensemble_{score_name}_score"] = scores[ | ||
self.metrics[0].name | ||
first_metric.name | ||
] | ||
self.ensemble_history.append(performance_stamp) | ||
|
||
# Add the performance stamp to the history | ||
self.ensemble_history.append(performance_stamp) | ||
Comment on lines
+630
to
+632
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BUGFIX: Marking this to match PR description, used to be inside the for loop but should have been outside |
||
|
||
# Lastly, delete any runs that need to be deleted. We save this as the last step | ||
# so that we have an ensemble saved that is up to date. If we do not do so, | ||
|
@@ -805,13 +834,13 @@ def candidate_selection( | |
|
||
def fit_ensemble( | ||
self, | ||
candidates: list[Run], | ||
X_data: SUPPORTED_FEAT_TYPES, | ||
targets: np.ndarray, | ||
candidates: Sequence[Run], | ||
runs: Sequence[Run], | ||
*, | ||
runs: list[Run], | ||
targets: np.ndarray | None = None, | ||
ensemble_class: Type[AbstractEnsemble] = EnsembleSelection, | ||
ensemble_kwargs: Dict[str, Any] | None = None, | ||
ensemble_kwargs: Mapping[str, Any] | None = None, | ||
X_data: SUPPORTED_FEAT_TYPES | None = None, | ||
task: int | None = None, | ||
metrics: Sequence[Scorer] | None = None, | ||
precision: int | None = None, | ||
|
@@ -825,24 +854,24 @@ def fit_ensemble( | |
|
||
Parameters | ||
---------- | ||
candidates: list[Run] | ||
candidates: Sequence[Run] | ||
List of runs to build an ensemble from | ||
|
||
X_data: SUPPORTED_FEAT_TYPES | ||
The base level data. | ||
runs: Sequence[Run] | ||
List of all runs (also pruned ones and dummy runs) | ||
|
||
targets: np.ndarray | ||
targets: np.ndarray | None = None | ||
The targets to build the ensemble with | ||
|
||
runs: list[Run] | ||
List of all runs (also pruned ones and dummy runs) | ||
|
||
ensemble_class: AbstractEnsemble | ||
ensemble_class: Type[AbstractEnsemble] | ||
Implementation of the ensemble algorithm. | ||
|
||
ensemble_kwargs: Dict[str, Any] | ||
ensemble_kwargs: Mapping[str, Any] | None | ||
Arguments passed to the constructor of the ensemble algorithm. | ||
|
||
X_data: SUPPORTED_FEAT_TYPES | None = None | ||
The base level data. | ||
|
||
task: int | None = None | ||
The kind of task performed | ||
|
||
|
@@ -859,24 +888,42 @@ def fit_ensemble( | |
------- | ||
AbstractEnsemble | ||
""" | ||
task = task if task is not None else self.task_type | ||
# Validate we have targets if None specified | ||
if targets is None: | ||
targets = self.targets("ensemble") | ||
if targets is None: | ||
path = self.backend._get_targets_ensemble_filename() | ||
raise ValueError(f"`fit_ensemble` could not find any targets at {path}") | ||
|
||
ensemble_class = ( | ||
ensemble_class if ensemble_class is not None else self.ensemble_class | ||
) | ||
ensemble_kwargs = ( | ||
ensemble_kwargs if ensemble_kwargs is not None else self.ensemble_kwargs | ||
) | ||
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs is not None else {} | ||
metrics = metrics if metrics is not None else self.metrics | ||
rs = random_state if random_state is not None else self.random_state | ||
|
||
ensemble = ensemble_class( | ||
task_type=task, | ||
metrics=metrics, | ||
random_state=rs, | ||
backend=self.backend, | ||
**ensemble_kwargs, | ||
) # type: AbstractEnsemble | ||
# Create the ensemble_kwargs, favouring in order: | ||
# 1) function kwargs, 2) function params 3) init_kwargs 4) init_params | ||
|
||
# Collect func params in dict if they're not None | ||
params = { | ||
k: v | ||
for k, v in [ | ||
("task_type", task), | ||
("metrics", metrics), | ||
("random_state", random_state), | ||
] | ||
if v is not None | ||
} | ||
|
||
kwargs = { | ||
"backend": self.backend, | ||
"task_type": self.task_type, | ||
"metrics": self.metrics, | ||
"random_state": self.random_state, | ||
**(self.ensemble_kwargs or {}), | ||
**params, | ||
**(ensemble_kwargs or {}), | ||
} | ||
|
||
ensemble = ensemble_class(**kwargs) # type: AbstractEnsemble | ||
|
||
self.logger.debug(f"Fitting ensemble on {len(candidates)} models") | ||
start_time = time.time() | ||
|
@@ -995,7 +1042,8 @@ def loss( | |
self, | ||
run: Run, | ||
metric: Scorer, | ||
X_data: SUPPORTED_FEAT_TYPES, | ||
*, | ||
X_data: SUPPORTED_FEAT_TYPES | None = None, | ||
kind: str = "ensemble", | ||
) -> float: | ||
"""Calculate the loss for a run | ||
|
@@ -1008,6 +1056,9 @@ def loss( | |
metric: Scorer | ||
The metric to calculate the loss of | ||
|
||
X_data: SUPPORTED_FEAT_TYPES | None = None | ||
Any X_data required to be passed to the metric | ||
|
||
kind: str = "ensemble" | ||
The kind of targets to use for the run | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
from autosklearn.metrics import accuracy, make_scorer | ||
|
||
|
||
def _accuracy_requiring_X_data( | ||
y_true: np.ndarray, | ||
y_pred: np.ndarray, | ||
X_data: Any, | ||
) -> float: | ||
"""Dummy metric that needs X Data""" | ||
if X_data is None: | ||
raise ValueError() | ||
return accuracy(y_true, y_pred) | ||
|
||
|
||
acc_with_X_data = make_scorer( | ||
name="acc_with_X_data", | ||
score_func=_accuracy_requiring_X_data, | ||
needs_X=True, | ||
optimum=1, | ||
worst_possible_result=0, | ||
greater_is_better=True, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +0,0 @@ | ||
# -*- encoding: utf-8 -*- | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.