
Commit 02e97a1

[FIX] Tests after rebase of reg_cocktails (#359)
* update requirements
* update requirements
* resolve remaining conflicts and fix flake and mypy
* Fix remaining tests and examples
* fix failing checks
* fix flake
1 parent 392f07a · commit 02e97a1

38 files changed · +290 -1024 lines changed

autoPyTorch/api/base_task.py

Lines changed: 39 additions & 36 deletions
@@ -954,18 +954,15 @@ def run_traditional_ml(
             learning algorithm runs over the time limit.
         """
         assert self._logger is not None  # for mypy compliancy
-        if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
-            self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
-        else:
-            traditional_task_name = 'runTraditional'
-            self._stopwatch.start_task(traditional_task_name)
-            elapsed_time = self._stopwatch.wall_elapsed(current_task_name)
-            time_for_traditional = int(runtime_limit - elapsed_time)
-            self._do_traditional_prediction(
-                func_eval_time_limit_secs=func_eval_time_limit_secs,
-                time_left=time_for_traditional,
-            )
-            self._stopwatch.stop_task(traditional_task_name)
+        traditional_task_name = 'runTraditional'
+        self._stopwatch.start_task(traditional_task_name)
+        elapsed_time = self._stopwatch.wall_elapsed(current_task_name)
+        time_for_traditional = int(runtime_limit - elapsed_time)
+        self._do_traditional_prediction(
+            func_eval_time_limit_secs=func_eval_time_limit_secs,
+            time_left=time_for_traditional,
+        )
+        self._stopwatch.stop_task(traditional_task_name)

     def _search(
         self,

@@ -1347,22 +1344,7 @@ def _search(
             self._logger.info("Starting Shutdown")

             if proc_ensemble is not None:
-                self._results_manager.ensemble_performance_history = list(proc_ensemble.history)
-
-                if len(proc_ensemble.futures) > 0:
-                    # Also add ensemble runs that did not finish within smac time
-                    # and add them into the ensemble history
-                    self._logger.info("Ensemble script still running, waiting for it to finish.")
-                    result = proc_ensemble.futures.pop().result()
-                    if result:
-                        ensemble_history, _, _, _ = result
-                        self._results_manager.ensemble_performance_history.extend(ensemble_history)
-                    self._logger.info("Ensemble script finished, continue shutdown.")
-
-                # save the ensemble performance history file
-                if len(self.ensemble_performance_history) > 0:
-                    pd.DataFrame(self.ensemble_performance_history).to_json(
-                        os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
+                self._collect_results_ensemble(proc_ensemble)

             if load_models:
                 self._logger.info("Loading models...")

@@ -1641,7 +1623,7 @@ def fit_pipeline(
                                            exclude=self.exclude_components,
                                            search_space_updates=self.search_space_updates)
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-        self._backend.replace_datamanager(dataset)
+        self._backend.save_datamanager(dataset)

         if self._logger is None:
             self._logger = self._get_logger(dataset.dataset_name)

@@ -1832,7 +1814,7 @@ def fit_ensemble(
         ensemble_fit_task_name = 'EnsembleFit'
         self._stopwatch.start_task(ensemble_fit_task_name)
         if enable_traditional_pipeline:
-            if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_for_task:
+            if func_eval_time_limit_secs > time_for_task:
                 self._logger.warning(
                     'Time limit for a single run is higher than total time '
                     'limit. Capping the limit for a single run to the total '

@@ -1873,12 +1855,8 @@ def fit_ensemble(
         )

         manager.build_ensemble(self._dask_client)
-        future = manager.futures.pop()
-        result = future.result()
-        if result is None:
-            raise ValueError("Errors occurred while building the ensemble - please"
-                             " check the log file and command line output for error messages.")
-        self.ensemble_performance_history, _, _, _ = result
+        if manager is not None:
+            self._collect_results_ensemble(manager)

         if load_models:
             self._load_models()

@@ -1956,6 +1934,31 @@ def _init_ensemble_builder(

         return proc_ensemble

+    def _collect_results_ensemble(
+        self,
+        manager: EnsembleBuilderManager
+    ) -> None:
+
+        if self._logger is None:
+            raise ValueError("logger should be initialized to fit ensemble")
+
+        self._results_manager.ensemble_performance_history = list(manager.history)
+
+        if len(manager.futures) > 0:
+            # Also add ensemble runs that did not finish within smac time
+            # and add them into the ensemble history
+            self._logger.info("Ensemble script still running, waiting for it to finish.")
+            result = manager.futures.pop().result()
+            if result:
+                ensemble_history, _, _, _ = result
+                self._results_manager.ensemble_performance_history.extend(ensemble_history)
+            self._logger.info("Ensemble script finished, continue shutdown.")
+
+        # save the ensemble performance history file
+        if len(self.ensemble_performance_history) > 0:
+            pd.DataFrame(self.ensemble_performance_history).to_json(
+                os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
+
     def predict(
         self,
         X_test: np.ndarray,
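The substantive change in this file is deduplication: `_search` (during shutdown) and `fit_ensemble` previously collected ensemble-builder results with near-identical inline code, and both call sites now delegate to the new `_collect_results_ensemble` helper. A minimal, self-contained sketch of the collect-once pattern follows; `ToyManager` and its plain-list pending queue are simplified stand-ins for `EnsembleBuilderManager` and its Dask futures, not autoPyTorch's real classes.

    from typing import Any, Dict, List


    class ToyManager:
        """Stand-in for EnsembleBuilderManager: finished history plus pending results."""

        def __init__(self) -> None:
            self.history: List[Dict[str, Any]] = [{"ensemble_optimization_score": 0.90}]
            # The real manager holds Dask futures; plain lists stand in here.
            self.pending: List[List[Dict[str, Any]]] = [[{"ensemble_optimization_score": 0.93}]]


    def collect_results_ensemble(manager: ToyManager) -> List[Dict[str, Any]]:
        # Start from whatever already finished ...
        history = list(manager.history)
        # ... then drain anything still pending, mirroring
        # manager.futures.pop().result() in the diff above.
        while manager.pending:
            result = manager.pending.pop()
            if result:
                history.extend(result)
        return history


    if __name__ == "__main__":
        print(collect_results_ensemble(ToyManager()))  # both entries, collected once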

autoPyTorch/api/tabular_classification.py

Lines changed: 6 additions & 4 deletions
@@ -18,6 +18,7 @@
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     HoldoutValTypes,
+    CrossValTypes,
     ResamplingStrategies,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset

@@ -449,6 +450,7 @@ def search(

         if self.dataset is None:
             raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
+
         return self._search(
             dataset=self.dataset,
             optimize_metric=optimize_metric,

@@ -488,23 +490,23 @@ def predict(
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")

-        X_test = self.input_validator.feature_validator.transform(X_test)
+        X_test = self.InputValidator.feature_validator.transform(X_test)
         predicted_probabilities = super().predict(X_test, batch_size=batch_size,
                                                   n_jobs=n_jobs)

-        if self.input_validator.target_validator.is_single_column_target():
+        if self.InputValidator.target_validator.is_single_column_target():
             predicted_indexes = np.argmax(predicted_probabilities, axis=1)
         else:
             predicted_indexes = (predicted_probabilities > 0.5).astype(int)

         # Allow to predict in the original domain -- that is, the user is not interested
         # in our encoded values
-        return self.input_validator.target_validator.inverse_transform(predicted_indexes)
+        return self.InputValidator.target_validator.inverse_transform(predicted_indexes)

     def predict_proba(self,
                       X_test: Union[np.ndarray, pd.DataFrame, List],
                       batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
-        if self.input_validator is None or not self.input_validator._is_fitted:
+        if self.InputValidator is None or not self.InputValidator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")
         X_test = self.input_validator.feature_validator.transform(X_test)
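The decoding step in `predict()` above is worth spelling out: a single-column target takes the argmax over the class-probability rows, while a multi-column target thresholds each column at 0.5 independently. A NumPy-only sketch of those two branches, with the validator round-trip omitted:

    import numpy as np

    predicted_probabilities = np.array([[0.2, 0.8],
                                        [0.7, 0.3]])

    # Single-column target: one class index per row via argmax.
    print(np.argmax(predicted_probabilities, axis=1))       # [1 0]

    # Multi-column target: an independent 0/1 decision per column.
    print((predicted_probabilities > 0.5).astype(int))      # [[0 1] [1 0]]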

autoPyTorch/api/tabular_regression.py

Lines changed: 5 additions & 3 deletions
@@ -18,6 +18,7 @@
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     HoldoutValTypes,
+    CrossValTypes,
     ResamplingStrategies,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset

@@ -449,6 +450,7 @@ def search(

         if self.dataset is None:
             raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
+
         return self._search(
             dataset=self.dataset,
             optimize_metric=optimize_metric,

@@ -474,14 +476,14 @@ def predict(
             batch_size: Optional[int] = None,
             n_jobs: int = 1
     ) -> np.ndarray:
-        if self.input_validator is None or not self.input_validator._is_fitted:
+        if self.InputValidator is None or not self.InputValidator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")

-        X_test = self.input_validator.feature_validator.transform(X_test)
+        X_test = self.InputValidator.feature_validator.transform(X_test)
         predicted_values = super().predict(X_test, batch_size=batch_size,
                                            n_jobs=n_jobs)

         # Allow to predict in the original domain -- that is, the user is not interested
         # in our encoded values
-        return self.input_validator.target_validator.inverse_transform(predicted_values)
+        return self.InputValidator.target_validator.inverse_transform(predicted_values)

autoPyTorch/data/base_target_validator.py

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ def fit(
                     np.shape(y_test)
                 ))
             if isinstance(y_train, pd.DataFrame):
-                y_train = cast(pd.DataFrame, y_train)
                 y_test = cast(pd.DataFrame, y_test)
                 if y_train.columns.tolist() != y_test.columns.tolist():
                     raise ValueError(

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 from logging import Logger
 from typing import Dict, List, Optional, Tuple, Union, cast

+
 import numpy as np

 import pandas as pd

@@ -275,7 +276,7 @@ def transform(
         if isinstance(X, np.ndarray):
             X = self.numpy_to_pandas(X)

-        if hasattr(X, "iloc") and not issparse(X):
+        if ispandas(X) and not issparse(X):
             X = cast(pd.DataFrame, X)

             # Check the data here so we catch problems on new test data
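The second hunk swaps the duck-typed `hasattr(X, "iloc")` check for an `ispandas(X)` helper. The diff does not show the helper itself; a stand-in consistent with the check it replaces would be the following (the real function lives somewhere in autoPyTorch's utilities and may simply keep the `hasattr` test):

    from typing import Any

    import pandas as pd


    def ispandas(X: Any) -> bool:
        # Name the intent once instead of duck-typing at every call site.
        return isinstance(X, pd.DataFrame)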

autoPyTorch/data/tabular_target_validator.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union, cast
+from typing import List, Optional, cast

 import numpy as np
 import numpy.ma as ma

autoPyTorch/evaluation/fit_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -10,13 +10,13 @@

 from smac.tae import StatusType

+from autoPyTorch.automl_common.common.utils.backend import Backend
 from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes
 from autoPyTorch.evaluation.abstract_evaluator import (
     AbstractEvaluator,
     fit_and_suppress_warnings
 )
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
-from autoPyTorch.utils.backend import Backend
 from autoPyTorch.utils.common import subsampler
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
autoPyTorch/optimizer/smbo.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def __init__(self,
                  resampling_strategy_args: Optional[Dict[str, Any]] = None,
                  include: Optional[Dict[str, Any]] = None,
                  exclude: Optional[Dict[str, Any]] = None,
-                 disable_file_output: List = [],
+                 disable_file_output: Union[bool, List[str]] = False,
                  smac_scenario_args: Optional[Dict[str, Any]] = None,
                  get_smac_object_callback: Optional[Callable] = None,
                  all_supported_metrics: bool = True,
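`disable_file_output` widens from an always-list default (`List = []`, a mutable default argument, itself a Python anti-pattern) to `Union[bool, List[str]] = False`: a bool toggles all file output at once, a list suppresses artifacts by name. A sketch of the contract that type suggests; the semantics and output names here are inferred from the annotation, not confirmed against autoPyTorch's evaluator:

    from typing import List, Union


    def should_write(output: str,
                     disable_file_output: Union[bool, List[str]] = False) -> bool:
        # A bool switches everything at once; a list disables entries by name.
        if isinstance(disable_file_output, bool):
            return not disable_file_output
        return output not in disable_file_output


    print(should_write("y_optimization", disable_file_output=["y_optimization"]))  # False
    print(should_write("model", disable_file_output=False))                        # True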

autoPyTorch/pipeline/components/setup/network_backbone/utils.py

Lines changed: 1 addition & 46 deletions
@@ -92,11 +92,7 @@ class ShakeDropFunction(Function):
     Github URL: https://github.com/owruby/shake-drop_pytorch/blob/master/models/shakedrop.py
     """
     @staticmethod
-<<<<<<< HEAD
     def forward(ctx: Any,
-=======
-    def forward(ctx: typing.Any,
->>>>>>> Bug fixes (#249)
                 x: torch.Tensor,
                 alpha: torch.Tensor,
                 beta: torch.Tensor,

@@ -123,31 +119,20 @@ def backward(ctx: Any,
 shake_drop = ShakeDropFunction.apply


-<<<<<<< HEAD
-def shake_get_alpha_beta(is_training: bool, is_cuda: bool
-                         ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    The methods used in this function have been introduced in 'ShakeShake Regularisation'
-    Currently, this function supports `shake-shake`.
-=======
 def shake_get_alpha_beta(
     is_training: bool,
     is_cuda: bool,
     method: str
-) -> typing.Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     The methods used in this function have been introduced in 'ShakeShake Regularisation'
     Each method name is available in the referred paper.
     Currently, this function supports `even-even`, `shake-even`, `shake-shake` and `M3`.
->>>>>>> Bug fixes (#249)

     Args:
         is_training (bool): Whether the computation for the training
         is_cuda (bool): Whether the tensor is on CUDA
-<<<<<<< HEAD
-=======
         method (str): The shake method either `even-even`, `shake-even`, `shake-shake` or `M3`
->>>>>>> Bug fixes (#249)

     Returns:
         alpha, beta (Tuple[float, float]):

@@ -159,14 +144,8 @@ def shake_get_alpha_beta(
     Author: Xavier Gastaldi
     URL: https://arxiv.org/abs/1705.07485

-<<<<<<< HEAD
-    Note:
-        The names have been taken from the paper as well.
-        Currently, this function supports `shake-shake`.
-=======
     The names have been taken from the paper as well.
     Currently, this function supports `even-even`, `shake-even`, `shake-shake` and `M3`.
->>>>>>> Bug fixes (#249)
     """
     if not is_training:
         result = (torch.FloatTensor([0.5]), torch.FloatTensor([0.5]))

@@ -196,27 +175,15 @@ def shake_get_alpha_beta(


 def shake_drop_get_bl(
-<<<<<<< HEAD
-    block_index: int,
-    min_prob_no_shake: float,
-    num_blocks: int,
-    is_training: bool,
-    is_cuda: bool
-=======
     block_index: int,
     min_prob_no_shake: float,
     num_blocks: int,
     is_training: bool,
     is_cuda: bool
->>>>>>> Bug fixes (#249)
 ) -> torch.Tensor:
     """
     The sampling of Bernoulli random variable
     based on Eq. (4) in the paper
-<<<<<<< HEAD
-
-=======
->>>>>>> Bug fixes (#249)
     Args:
         block_index (int): The index of the block from the input layer
         min_prob_no_shake (float): The initial shake probability

@@ -226,28 +193,16 @@ def shake_drop_get_bl(

     Returns:
         bl (torch.Tensor): a Bernoulli random variable in {0, 1}
-<<<<<<< HEAD
-
-=======
->>>>>>> Bug fixes (#249)
     Reference:
         ShakeDrop Regularization for Deep Residual Learning
         Yoshihiro Yamada et. al. (2020)
         paper: https://arxiv.org/pdf/1802.02375.pdf
         implementation: https://github.com/imenurok/ShakeDrop
     """
-<<<<<<< HEAD
-
-    pl = 1 - ((block_index + 1) / num_blocks) * (1 - min_prob_no_shake)
-
-    if is_training:
-        # Move to torch.rand(1) for reproducibility
-=======
     pl = 1 - ((block_index + 1) / num_blocks) * (1 - min_prob_no_shake)

     if is_training:
         # Move to torch.randn(1) for reproducibility
->>>>>>> Bug fixes (#249)
         bl = torch.as_tensor(1.0) if torch.rand(1) <= pl else torch.as_tensor(0.0)
     else:
         bl = torch.as_tensor(pl)
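These hunks remove the `<<<<<<< HEAD` / `=======` / `>>>>>>> Bug fixes (#249)` conflict markers left over from the rebase, keeping the multi-method side of `shake_get_alpha_beta` and the plain `Tuple`/`Any` annotations over the `typing.`-prefixed spellings. The surviving `shake_drop_get_bl` is small enough to check by hand; a simplified, runnable restatement (the `is_cuda` handling is dropped):

    import torch


    def shake_drop_get_bl(block_index: int, min_prob_no_shake: float,
                          num_blocks: int, is_training: bool) -> torch.Tensor:
        # Eq. (4): deeper blocks keep their residual branch with lower probability.
        pl = 1 - ((block_index + 1) / num_blocks) * (1 - min_prob_no_shake)
        if is_training:
            return torch.as_tensor(1.0) if torch.rand(1) <= pl else torch.as_tensor(0.0)
        return torch.as_tensor(pl)  # expected value at evaluation time


    for block_index in range(3):
        # With min_prob_no_shake=0.5 and num_blocks=3: pl = 0.833..., 0.667, 0.5
        print(block_index, shake_drop_get_bl(block_index, 0.5, 3, is_training=False))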
