42 changes: 17 additions & 25 deletions autoPyTorch/api/base_task.py
@@ -486,11 +486,14 @@ def _load_best_individual_model(self) -> SingleBest:

return ensemble

def _do_dummy_prediction(self, num_run: int) -> None:
def _do_dummy_prediction(self) -> None:

assert self._metric is not None
assert self._logger is not None

# For the dummy estimator, we always expect num_run to be 1
num_run = 1

self._logger.info("Starting to create dummy predictions.")

memory_limit = self._memory_limit
@@ -551,29 +554,20 @@ def _do_dummy_prediction(self, num_run: int) -> None:
% (str(status), str(additional_info))
)
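A minimal sketch (not autoPyTorch code) of the baseline this dummy step produces: DummyClassificationPipeline wraps scikit-learn's DummyClassifier with strategy="uniform", yielding a trivial reference score that every real model must beat. The data below is synthetic, for illustration only.

import numpy as np
from sklearn.dummy import DummyClassifier

# Hypothetical toy data; the real call runs on the task's dataset
X = np.random.rand(20, 3)
y = np.random.randint(0, 2, size=20)
baseline = DummyClassifier(strategy="uniform").fit(X, y)
print(baseline.predict(X))  # uniformly random class labels to beat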

def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_time_limit_secs: int
) -> int:
def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: int) -> None:
"""
Fits traditional machine learning algorithms to the provided dataset, while
complying with time resource allocation.

This method currently only supports classification.

Args:
num_run: (int)
An identifier to indicate the current machine learning algorithm
being processed
time_left: (int)
Hard time limit within which machine learning algorithms can be fit. Depending on how
fast a traditional machine learning algorithm trains, this budget can allow multiple
models to be fitted.
func_eval_time_limit_secs: (int)
Maximum training time each algorithm is allowed to take

Returns:
num_run: (int)
The incremented identifier index. This depends on how many machine learning
models were fitted.
"""

# Mypy checks -- traditional prediction is only called during search
@@ -592,8 +586,8 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
available_classifiers = get_available_classifiers()
dask_futures = []

total_number_classifiers = len(available_classifiers) + num_run
for n_r, classifier in enumerate(available_classifiers, start=num_run):
total_number_classifiers = len(available_classifiers)
for n_r, classifier in enumerate(available_classifiers):

# Only launch a task if there is time
start_time = time.time()
@@ -612,7 +606,7 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
logger_port=self._logger_port,
cost_for_crash=get_cost_of_crash(self._metric),
abort_on_first_run_crash=False,
initial_num_run=n_r,
initial_num_run=self._backend.get_next_num_run(),
stats=stats,
memory_limit=memory_limit,
disable_file_output=True if len(self._disable_file_output) > 0 else False,
@@ -626,9 +620,6 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
)
])

# Increment the launched job index
num_run = n_r

# When managing time, we need to take into account the allocated time resources,
# which depend on the number of cores. 'dask_futures' is a proxy for the number
# of workers/n_jobs that we have, in that if there are 4 cores allocated, we can run at most
@@ -691,7 +682,7 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
self.run_history.update(run_history, DataOrigin.EXTERNAL_SAME_INSTANCES)
run_history.save_json(os.path.join(self._backend.internals_directory, 'traditional_run_history.json'),
save_external=True)
return num_run
return
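A minimal sketch, under stated assumptions, of the worker-bounded submission pattern the comments in this method describe: keep at most as many in-flight dask futures as there are workers, and wait for one to finish before launching the next. The four-worker client and the squaring task are hypothetical stand-ins.

from dask.distributed import Client, wait

client = Client(n_workers=4, threads_per_worker=1)
futures = []
for task_id in range(10):
    futures.append(client.submit(lambda x: x * x, task_id))
    if len(futures) >= 4:  # as many in-flight tasks as workers
        # Block until at least one future completes, then recycle its slot
        done, not_done = wait(futures, return_when='FIRST_COMPLETED')
        futures = list(not_done)
client.close()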

def _search(
self,
@@ -861,10 +852,9 @@ def _search(
)

# ============> Run dummy predictions
num_run = 1
dummy_task_name = 'runDummy'
self._stopwatch.start_task(dummy_task_name)
self._do_dummy_prediction(num_run)
self._do_dummy_prediction()
self._stopwatch.stop_task(dummy_task_name)

# ============> Run traditional ml
@@ -880,8 +870,8 @@ def _search(
time_for_traditional = int(
self._time_for_task - elapsed_time - func_eval_time_limit_secs
)
num_run = self._do_traditional_prediction(
num_run=num_run + 1, func_eval_time_limit_secs=func_eval_time_limit_secs,
self._do_traditional_prediction(
func_eval_time_limit_secs=func_eval_time_limit_secs,
time_left=time_for_traditional,
)
self._stopwatch.stop_task(traditional_task_name)
@@ -957,7 +947,9 @@ def _search(
pipeline_config={**self.pipeline_options, **budget_config},
ensemble_callback=proc_ensemble,
logger_port=self._logger_port,
start_num_run=num_run,
# We do not increment num_run here; this is something
# SMAC does internally
start_num_run=self._backend.get_next_num_run(peek=True),
search_space_updates=self.search_space_updates
)
try:
@@ -1063,7 +1055,7 @@ def refit(
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'num_run': 0
'num_run': self._backend.get_next_num_run(),
})
X.update({**self.pipeline_options, **budget_config})
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
@@ -1140,7 +1132,7 @@ def fit(self,
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'num_run': 0
'num_run': self._backend.get_next_num_run(),
})
X.update({**self.pipeline_options, **budget_config})

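To summarize the identifier flow after this change, here is a toy, self-contained sketch of the bookkeeping now centralized in the Backend (TinyBackend is hypothetical; the real implementation lives in autoPyTorch/utils/backend.py and additionally scans the runs directory on disk):

class TinyBackend:
    def __init__(self) -> None:
        self.active_num_run = 1  # num_run == 1 is reserved for the dummy run

    def get_next_num_run(self, peek: bool = False) -> int:
        if not peek:
            self.active_num_run += 1
        return self.active_num_run

backend = TinyBackend()
dummy_num_run = 1                                             # _do_dummy_prediction
traditional = [backend.get_next_num_run() for _ in range(3)]  # [2, 3, 4]
smac_start = backend.get_next_num_run(peek=True)              # 4; SMAC increments internally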
8 changes: 6 additions & 2 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -144,7 +144,9 @@ def __init__(self, config: Configuration,
random_state: Optional[Union[int, np.random.RandomState]] = None,
init_params: Optional[Dict] = None
) -> None:
self.configuration = config
self.config = config
self.init_params = init_params
self.random_state = random_state
if config == 1:
super(DummyClassificationPipeline, self).__init__(strategy="uniform")
else:
@@ -208,7 +210,9 @@ class DummyRegressionPipeline(DummyRegressor):
def __init__(self, config: Configuration,
random_state: Optional[Union[int, np.random.RandomState]] = None,
init_params: Optional[Dict] = None) -> None:
self.configuration = config
self.config = config
self.init_params = init_params
self.random_state = random_state
if config == 1:
super(DummyRegressionPipeline, self).__init__(strategy='mean')
else:
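A minimal sketch of why the attribute rename above matters: sklearn.base.clone rebuilds an estimator from get_params(), which introspects __init__ and reads self.<param> for each constructor argument, so storing config under self.configuration breaks that lookup while self.config keeps clone() working. The two classes below are illustrative only.

from sklearn.base import BaseEstimator, clone

class CloneFriendly(BaseEstimator):
    def __init__(self, config=1, random_state=None, init_params=None):
        self.config = config            # same name as the __init__ parameter
        self.random_state = random_state
        self.init_params = init_params

clone(CloneFriendly(config=1))          # works

class CloneHostile(BaseEstimator):
    def __init__(self, config=1):
        self.configuration = config     # mismatched name

try:
    clone(CloneHostile(config=1))       # recent scikit-learn raises here
except AttributeError as err:
    print(f"clone failed: {err}")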
45 changes: 45 additions & 0 deletions autoPyTorch/utils/backend.py
@@ -169,6 +169,10 @@ def __init__(self, context: BackendContext):
self._logger = None # type: Optional[PicklableClientLogger]
self.context = context

# Track the number of configurations launched
# num_run == 1 corresponds to the dummy estimator run
self.active_num_run = 1

# Create the temporary directory if it does not yet exist
try:
os.makedirs(self.temporary_directory)
@@ -329,6 +333,47 @@ def get_runs_directory(self) -> str:
def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str:
return os.path.join(self.internals_directory, 'runs', '%d_%d_%s' % (seed, num_run, budget))

def get_next_num_run(self, peek: bool = False) -> int:
"""
Every pipeline that is fitted by the estimator is stored with an
identifier called num_run. A dummy classifier will always have a num_run
equal to 1, and all other new configurations that are explored will
have a sequentially increasing identifier.

This method returns the next num_run a configuration should take.

Parameters
----------
peek: bool
By default, the next num_run is returned, i.e. self.active_num_run + 1.
If this parameter is True, the value of the current
num_run is returned instead, i.e. self.active_num_run.
In other words, peek allows retrieving the current maximum identifier
of a configuration without incrementing it.

Returns
-------
num_run: int
A unique identifier for a configuration
"""

# If there are other num_runs, their directories are named runs/<seed>_<num_run>_<budget>
other_num_runs = [int(os.path.basename(run_dir).split('_')[1])
for run_dir in glob.glob(os.path.join(self.internals_directory,
'runs',
'*'))]
if len(other_num_runs) > 0:
# We track the number of runs from two sources:
# the physically available num_runs (which might be deleted, or a crash could happen)
# and an internally kept attribute. The latter should be sufficient, but we
# want to be robust against multiple backend copies on different workers
self.active_num_run = max([self.active_num_run] + other_num_runs)

# We are interested in the next run id
if not peek:
self.active_num_run += 1
return self.active_num_run

def get_model_filename(self, seed: int, idx: int, budget: float) -> str:
return '%s.%s.%s.model' % (seed, idx, budget)

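A usage sketch, assuming a fresh Backend instance `backend` (it mirrors the test added in test/test_utils/test_backend.py below):

backend.get_next_num_run()            # 2: increments past the dummy run's 1
backend.get_next_num_run(peek=True)   # 2: reads the current id, no increment
# If a directory such as runs/0_12_0.0 already exists on disk, the counter
# fast-forwards so freshly issued identifiers never collide with stored runs:
backend.get_next_num_run()            # 13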
47 changes: 47 additions & 0 deletions test/test_api/test_api.py
@@ -12,6 +12,7 @@

import sklearn
import sklearn.datasets
from sklearn.base import clone
from sklearn.ensemble import VotingClassifier, VotingRegressor

from smac.runhistory.runhistory import RunHistory
@@ -25,6 +26,7 @@
HoldoutValTypes,
)
from autoPyTorch.optimizer.smbo import AutoMLSMBO
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy


# Fixtures
@@ -402,3 +404,48 @@ def test_tabular_input_support(openml_id, backend):
enable_traditional_pipeline=False,
load_models=False,
)


@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
backend = fit_dictionary_tabular['backend']
estimator = TabularClassificationTask(
backend=backend,
resampling_strategy=HoldoutValTypes.holdout_validation,
ensemble_size=0,
)

# Set up prerequisites normally prepared by search()
estimator._create_dask_client()
estimator._metric = accuracy
estimator._logger = estimator._get_logger('test')
estimator._memory_limit = 5000
estimator._time_for_task = 60
estimator._disable_file_output = []
estimator._all_supported_metrics = False

estimator._do_dummy_prediction()

# Ensure that the dummy predictions are not in the current working
# directory, but in the temporary directory.
assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
assert os.path.exists(os.path.join(
backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
'predictions_ensemble_1_1_1.0.npy')
)

model_path = os.path.join(backend.temporary_directory,
'.autoPyTorch',
'runs', '1_1_1.0',
'1.1.1.0.model')

# Make sure the dummy model complies with scikit-learn's
# get/set params interface
assert os.path.exists(model_path)
with open(model_path, 'rb') as model_handler:
clone(pickle.load(model_handler))

estimator._close_dask_client()
estimator._clean_logger()

del estimator
4 changes: 2 additions & 2 deletions test/test_pipeline/test_tabular_classification.py
@@ -439,5 +439,5 @@ def test_constant_pipeline_iris(fit_dictionary_tabular):
val_score = run_summary.performance_tracker['val_metrics'][epoch_where_best]['balanced_accuracy']
train_score = run_summary.performance_tracker['train_metrics'][epoch_where_best]['balanced_accuracy']

assert val_score >= 0.9, run_summary.performance_tracker['val_metrics']
assert train_score >= 0.9, run_summary.performance_tracker['train_metrics']
assert val_score >= 0.8, run_summary.performance_tracker['val_metrics']
assert train_score >= 0.8, run_summary.performance_tracker['train_metrics']
23 changes: 23 additions & 0 deletions test/test_utils/test_backend.py
@@ -1,8 +1,11 @@
# -*- encoding: utf-8 -*-
import builtins
import logging.handlers
import unittest
import unittest.mock

import numpy as np

import pytest

from autoPyTorch.utils.backend import Backend
@@ -81,3 +84,23 @@ def test_loads_models_by_identifiers(exists_mock, openMock, pickleLoadMock, back

assert isinstance(actual_dict, dict)
assert expected_dict == actual_dict


def test_get_next_num_run(backend):
# Asking for a num_run increases the tracked num_run
assert backend.get_next_num_run() == 2
assert backend.get_next_num_run() == 3
# Then test that we are robust against new files being generated
backend.setup_logger('Test', logging.handlers.DEFAULT_TCP_LOGGING_PORT)
backend.save_numrun_to_dir(
seed=0,
idx=12,
budget=0.0,
model=dict(),
cv_model=None,
ensemble_predictions=np.zeros(10),
valid_predictions=None,
test_predictions=None,
)
assert backend.get_next_num_run() == 13
assert backend.get_next_num_run(peek=True) == 13