Skip to content

Commit 4b6107b

Browse files
AA_151
1 parent 68fc77f commit 4b6107b

File tree

5 files changed

+117
-27
lines changed

5 files changed

+117
-27
lines changed

autoPyTorch/api/base_task.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,14 @@ def _load_best_individual_model(self) -> SingleBest:
486486

487487
return ensemble
488488

489-
def _do_dummy_prediction(self, num_run: int) -> None:
489+
def _do_dummy_prediction(self) -> None:
490490

491491
assert self._metric is not None
492492
assert self._logger is not None
493493

494+
# For dummy estimator, we always expect the num_run to be 1
495+
num_run = 1
496+
494497
self._logger.info("Starting to create dummy predictions.")
495498

496499
memory_limit = self._memory_limit
@@ -551,29 +554,20 @@ def _do_dummy_prediction(self, num_run: int) -> None:
551554
% (str(status), str(additional_info))
552555
)
553556

554-
def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_time_limit_secs: int
555-
) -> int:
557+
def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: int) -> None:
556558
"""
557559
Fits traditional machine learning algorithms to the provided dataset, while
558560
complying with time resource allocation.
559561
560562
This method currently only supports classification.
561563
562564
Args:
563-
num_run: (int)
564-
An identifier to indicate the current machine learning algorithm
565-
being processed
566565
time_left: (int)
567566
Hard limit on how many machine learning algorithms can be fit. Depending on how
568567
fast a traditional machine learning algorithm trains, it will allow multiple
569568
models to be fitted.
570569
func_eval_time_limit_secs: (int)
571570
Maximum training time each algorithm is allowed to take, during training
572-
573-
Returns:
574-
num_run: (int)
575-
The incremented identifier index. This depends on how many machine learning
576-
models were fitted.
577571
"""
578572

579573
# Mypy Checkings -- Traditional prediction is only called for search
@@ -588,8 +582,8 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
588582
available_classifiers = get_available_classifiers()
589583
dask_futures = []
590584

591-
total_number_classifiers = len(available_classifiers) + num_run
592-
for n_r, classifier in enumerate(available_classifiers, start=num_run):
585+
total_number_classifiers = len(available_classifiers)
586+
for n_r, classifier in enumerate(available_classifiers):
593587

594588
# Only launch a task if there is time
595589
start_time = time.time()
@@ -608,7 +602,7 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
608602
logger_port=self._logger_port,
609603
cost_for_crash=get_cost_of_crash(self._metric),
610604
abort_on_first_run_crash=False,
611-
initial_num_run=n_r,
605+
initial_num_run=self._backend.get_next_num_run(),
612606
stats=stats,
613607
memory_limit=memory_limit,
614608
disable_file_output=True if len(self._disable_file_output) > 0 else False,
@@ -622,9 +616,6 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
622616
)
623617
])
624618

625-
# Increment the launched job index
626-
num_run = n_r
627-
628619
# When managing time, we need to take into account the allocated time resources,
629620
# which are dependent on the number of cores. 'dask_futures' is a proxy to the number
630621
# of workers /n_jobs that we have, in that if there are 4 cores allocated, we can run at most
@@ -677,7 +668,7 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
677668
"Please consider increasing the run time to further improve performance.")
678669
break
679670

680-
return num_run
671+
return
681672

682673
def _search(
683674
self,
@@ -847,10 +838,9 @@ def _search(
847838
)
848839

849840
# ============> Run dummy predictions
850-
num_run = 1
851841
dummy_task_name = 'runDummy'
852842
self._stopwatch.start_task(dummy_task_name)
853-
self._do_dummy_prediction(num_run)
843+
self._do_dummy_prediction()
854844
self._stopwatch.stop_task(dummy_task_name)
855845

856846
# ============> Run traditional ml
@@ -866,8 +856,8 @@ def _search(
866856
time_for_traditional = int(
867857
self._time_for_task - elapsed_time - func_eval_time_limit_secs
868858
)
869-
num_run = self._do_traditional_prediction(
870-
num_run=num_run + 1, func_eval_time_limit_secs=func_eval_time_limit_secs,
859+
self._do_traditional_prediction(
860+
func_eval_time_limit_secs=func_eval_time_limit_secs,
871861
time_left=time_for_traditional,
872862
)
873863
self._stopwatch.stop_task(traditional_task_name)
@@ -943,7 +933,9 @@ def _search(
943933
pipeline_config={**self.pipeline_options, **budget_config},
944934
ensemble_callback=proc_ensemble,
945935
logger_port=self._logger_port,
946-
start_num_run=num_run,
936+
# We do not increase the num_run here, this is something
937+
# smac does internally
938+
start_num_run=self._backend.get_next_num_run(peek=True),
947939
search_space_updates=self.search_space_updates
948940
)
949941
try:
@@ -1048,7 +1040,7 @@ def refit(
10481040
'train_indices': dataset.splits[split_id][0],
10491041
'val_indices': dataset.splits[split_id][1],
10501042
'split_id': split_id,
1051-
'num_run': 0
1043+
'num_run': self._backend.get_next_num_run(),
10521044
})
10531045
X.update({**self.pipeline_options, **budget_config})
10541046
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
@@ -1125,7 +1117,7 @@ def fit(self,
11251117
'train_indices': dataset.splits[split_id][0],
11261118
'val_indices': dataset.splits[split_id][1],
11271119
'split_id': split_id,
1128-
'num_run': 0
1120+
'num_run': self._backend.get_next_num_run(),
11291121
})
11301122
X.update({**self.pipeline_options, **budget_config})
11311123

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ def __init__(self, config: Configuration,
134134
random_state: Optional[Union[int, np.random.RandomState]] = None,
135135
init_params: Optional[Dict] = None
136136
) -> None:
137-
self.configuration = config
137+
self.config = config
138+
self.init_params = init_params
139+
self.random_state = random_state
138140
if config == 1:
139141
super(DummyClassificationPipeline, self).__init__(strategy="uniform")
140142
else:
@@ -198,7 +200,9 @@ class DummyRegressionPipeline(DummyRegressor):
198200
def __init__(self, config: Configuration,
199201
random_state: Optional[Union[int, np.random.RandomState]] = None,
200202
init_params: Optional[Dict] = None) -> None:
201-
self.configuration = config
203+
self.config = config
204+
self.init_params = init_params
205+
self.random_state = random_state
202206
if config == 1:
203207
super(DummyRegressionPipeline, self).__init__(strategy='mean')
204208
else:

autoPyTorch/utils/backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def __init__(self, context: BackendContext):
169169
self._logger = None # type: Optional[PicklableClientLogger]
170170
self.context = context
171171

172+
# Track the number of configurations launched
173+
# num_run == 1 means a dummy estimator run
174+
self.active_num_run = 1
175+
172176
# Create the temporary directory if it does not yet exist
173177
try:
174178
os.makedirs(self.temporary_directory)
@@ -329,6 +333,25 @@ def get_runs_directory(self) -> str:
329333
def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str:
330334
return os.path.join(self.internals_directory, 'runs', '%d_%d_%s' % (seed, num_run, budget))
331335

336+
def get_next_num_run(self, peek: bool = False) -> int:
337+
338+
# If there are other num_runs, their name would be runs/<seed>_<num_run>_<budget>
339+
other_num_runs = [int(os.path.basename(run_dir).split('_')[1])
340+
for run_dir in glob.glob(os.path.join(self.internals_directory,
341+
'runs',
342+
'*'))]
343+
if len(other_num_runs) > 0:
344+
# We track the number of runs from two forefronts:
345+
# The physically available num_runs (which might be deleted or a crash could happen)
346+
# From a internally kept attribute. The later should be sufficient, but we
347+
# want to be robust against multiple backend copies on different workers
348+
self.active_num_run = max([self.active_num_run] + other_num_runs)
349+
350+
# We are interested in the next run id
351+
if not peek:
352+
self.active_num_run += 1
353+
return self.active_num_run
354+
332355
def get_model_filename(self, seed: int, idx: int, budget: float) -> str:
333356
return '%s.%s.%s.model' % (seed, idx, budget)
334357

test/test_api/test_api.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
import sklearn
1414
import sklearn.datasets
15+
from sklearn.base import clone
1516
from sklearn.ensemble import VotingClassifier, VotingRegressor
1617

18+
1719
import torch
1820

1921
from autoPyTorch.api.tabular_classification import TabularClassificationTask
@@ -23,6 +25,7 @@
2325
HoldoutValTypes,
2426
)
2527
from autoPyTorch.optimizer.smbo import AutoMLSMBO
28+
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
2629

2730

2831
# Fixtures
@@ -394,3 +397,48 @@ def test_tabular_input_support(openml_id, backend):
394397
enable_traditional_pipeline=False,
395398
load_models=False,
396399
)
400+
401+
402+
@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
403+
def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
404+
backend = fit_dictionary_tabular['backend']
405+
estimator = TabularClassificationTask(
406+
backend=backend,
407+
resampling_strategy=HoldoutValTypes.holdout_validation,
408+
ensemble_size=0,
409+
)
410+
411+
# Setup pre-requisites normally set by search()
412+
estimator._create_dask_client()
413+
estimator._metric = accuracy
414+
estimator._logger = estimator._get_logger('test')
415+
estimator._memory_limit = 5000
416+
estimator._time_for_task = 60
417+
estimator._disable_file_output = []
418+
estimator._all_supported_metrics = False
419+
420+
estimator._do_dummy_prediction()
421+
422+
# Ensure that the dummy predictions are not in the current working
423+
# directory, but in the temporary directory.
424+
assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
425+
assert os.path.exists(os.path.join(
426+
backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
427+
'predictions_ensemble_1_1_1.0.npy')
428+
)
429+
430+
model_path = os.path.join(backend.temporary_directory,
431+
'.autoPyTorch',
432+
'runs', '1_1_1.0',
433+
'1.1.1.0.model')
434+
435+
# Make sure the dummy model complies with scikit learn
436+
# get/set params
437+
assert os.path.exists(model_path)
438+
with open(model_path, 'rb') as model_handler:
439+
clone(pickle.load(model_handler))
440+
441+
estimator._close_dask_client()
442+
estimator._clean_logger()
443+
444+
del estimator

test/test_utils/test_backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# -*- encoding: utf-8 -*-
22
import builtins
3+
import logging.handlers
34
import unittest
45
import unittest.mock
56

7+
import numpy as np
8+
69
import pytest
710

811
from autoPyTorch.utils.backend import Backend
@@ -81,3 +84,23 @@ def test_loads_models_by_identifiers(exists_mock, openMock, pickleLoadMock, back
8184

8285
assert isinstance(actual_dict, dict)
8386
assert expected_dict == actual_dict
87+
88+
89+
def test_get_next_num_run(backend):
90+
# Asking for a num_run increases the tracked num_run
91+
assert backend.get_next_num_run() == 2
92+
assert backend.get_next_num_run() == 3
93+
# Then test that we are robust against new files being generated
94+
backend.setup_logger('Test', logging.handlers.DEFAULT_TCP_LOGGING_PORT)
95+
backend.save_numrun_to_dir(
96+
seed=0,
97+
idx=12,
98+
budget=0.0,
99+
model=dict(),
100+
cv_model=None,
101+
ensemble_predictions=np.zeros(10),
102+
valid_predictions=None,
103+
test_predictions=None,
104+
)
105+
assert backend.get_next_num_run() == 13
106+
assert backend.get_next_num_run(peek=True) == 13

0 commit comments

Comments
 (0)