Skip to content

Commit 13ec457

Browse files
committed
[fix] Fix flake8 issues and increase coverage
1 parent f55c4ac commit 13ec457

File tree

2 files changed

+4
-78
lines changed

2 files changed

+4
-78
lines changed

test/test_api/test_api.py

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pathlib
44
import pickle
55
import unittest
6-
from test.test_api.utils import dummy_do_dummy_prediction, dummy_eval_function, dummy_traditional_classification
6+
from test.test_api.utils import dummy_do_dummy_prediction, dummy_eval_function
77

88
import ConfigSpace as CS
99
from ConfigSpace.configuration_space import Configuration
@@ -25,12 +25,10 @@
2525

2626
from autoPyTorch.api.tabular_classification import TabularClassificationTask
2727
from autoPyTorch.api.tabular_regression import TabularRegressionTask
28-
from autoPyTorch.data.tabular_validator import TabularInputValidator
2928
from autoPyTorch.datasets.resampling_strategy import (
3029
CrossValTypes,
3130
HoldoutValTypes,
3231
)
33-
from autoPyTorch.datasets.tabular_dataset import TabularDataset
3432
from autoPyTorch.optimizer.smbo import AutoMLSMBO
3533
from autoPyTorch.pipeline.base_pipeline import BasePipeline
3634
from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import _traditional_learners
@@ -575,76 +573,6 @@ def test_portfolio_selection_failure(openml_id, backend, n_samples):
575573
)
576574

577575

578-
"""
579-
@pytest.mark.parametrize('dataset_name', ('iris',))
580-
@pytest.mark.parametrize('include_traditional', (True, False))
581-
def test_get_incumbent_results(dataset_name, backend, include_traditional):
582-
# TODO: Remove this function completely if possible
583-
# Get the data and check that contents of data-manager make sense
584-
X, y = sklearn.datasets.fetch_openml(
585-
name=dataset_name,
586-
return_X_y=True, as_frame=True
587-
)
588-
589-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
590-
X, y, random_state=1)
591-
592-
# Search for a good configuration
593-
estimator = TabularClassificationTask(
594-
backend=backend,
595-
resampling_strategy=HoldoutValTypes.holdout_validation,
596-
)
597-
598-
InputValidator = TabularInputValidator(
599-
is_classification=True,
600-
)
601-
602-
# Fit a input validator to check the provided data
603-
# Also, an encoder is fit to both train and test data,
604-
# to prevent unseen categories during inference
605-
InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
606-
607-
dataset = TabularDataset(
608-
X=X_train, Y=y_train,
609-
X_test=X_test, Y_test=y_test,
610-
validator=InputValidator,
611-
resampling_strategy=estimator.resampling_strategy,
612-
resampling_strategy_args=estimator.resampling_strategy_args,
613-
)
614-
615-
pipeline_run_history = RunHistory()
616-
pipeline_run_history.load_json(os.path.join(os.path.dirname(__file__), '.tmp_api/runhistory.json'),
617-
estimator.get_search_space(dataset))
618-
619-
estimator._do_dummy_prediction = unittest.mock.MagicMock()
620-
621-
with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
622-
with unittest.mock.patch.object(TabularClassificationTask, '_do_traditional_prediction',
623-
new=dummy_traditional_classification):
624-
AutoMLSMBOMock.return_value = (pipeline_run_history, {}, 'epochs')
625-
estimator.search(
626-
X_train=X_train, y_train=y_train,
627-
X_test=X_test, y_test=y_test,
628-
optimize_metric='accuracy',
629-
total_walltime_limit=150,
630-
func_eval_time_limit_secs=50,
631-
enable_traditional_pipeline=True,
632-
load_models=False,
633-
)
634-
config, results = estimator.get_incumbent_results(include_traditional=include_traditional)
635-
assert isinstance(config, Configuration)
636-
assert isinstance(results, dict)
637-
638-
run_history_data = estimator.run_history.data
639-
costs = [run_value.cost for run_key, run_value in run_history_data.items() if run_value.additional_info is not None
640-
and (run_value.additional_info['configuration_origin'] != 'traditional' or include_traditional)]
641-
assert results['opt_loss']['accuracy'] == min(costs)
642-
643-
if not include_traditional:
644-
assert results['configuration_origin'] != 'traditional'
645-
"""
646-
647-
648576
# TODO: Make faster when https://github.com/automl/Auto-PyTorch/pull/223 is incorporated
649577
@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
650578
def test_do_traditional_pipeline(fit_dictionary_tabular):

test/test_api/test_results_manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
from unittest.mock import MagicMock
55

66
import ConfigSpace.hyperparameters as CSH
7+
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
78

89
import numpy as np
910

1011
import pytest
1112

12-
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
13-
1413
from smac.runhistory.runhistory import RunHistory, StatusType
1514

1615
from autoPyTorch.api.base_task import BaseTask
17-
from autoPyTorch.api.results_manager import cost2metric, ResultsManager
18-
from autoPyTorch.api.results_manager import STATUS2MSG
16+
from autoPyTorch.api.results_manager import ResultsManager, STATUS2MSG, cost2metric
1917
from autoPyTorch.metrics import accuracy, balanced_accuracy, log_loss
2018

2119

@@ -96,7 +94,7 @@ def _check_metric_dict(metric_dict, status_types):
9694

9795
def test_search_results_sprint_statistics():
9896
api = BaseTask()
99-
for method in ['get_search_results', 'sprint_statistics']:
97+
for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']:
10098
try:
10199
getattr(api, method)()
102100
except RuntimeError:

0 commit comments

Comments
 (0)