diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index 3effdf371..1a73d8625 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -159,7 +159,8 @@ def get_dataset(self, validator=InputValidator, resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args, - dataset_name=dataset_name + dataset_name=dataset_name, + seed=self.seed ) if not return_only: self.InputValidator = InputValidator diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index 939ed9391..e7fb919bd 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -151,7 +151,8 @@ def get_dataset(self, validator=InputValidator, resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args, - dataset_name=dataset_name + dataset_name=dataset_name, + seed=self.seed ) if not return_only: self.InputValidator = InputValidator