diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py
deleted file mode 100644
index 52fc0c9a80615..0000000000000
--- a/pytorch_lightning/trainer/connectors/debugging_connector.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Union
-
-from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.utilities import rank_zero_info
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-
-class DebuggingConnector:
-    def __init__(self, trainer):
-        self.trainer = trainer
-
-    def on_init_start(
-        self,
-        limit_train_batches,
-        limit_val_batches,
-        limit_test_batches,
-        limit_predict_batches,
-        val_check_interval,
-        overfit_batches,
-        fast_dev_run,
-    ):
-        if not isinstance(fast_dev_run, (bool, int)):
-            raise MisconfigurationException(
-                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0"
-            )
-
-        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
-            raise MisconfigurationException(
-                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
-            )
-
-        self.trainer.fast_dev_run = fast_dev_run
-        fast_dev_run = int(fast_dev_run)
-
-        # set fast_dev_run=True when it is 1, used while logging
-        if fast_dev_run == 1:
-            self.trainer.fast_dev_run = True
-
-        if fast_dev_run:
-            limit_train_batches = fast_dev_run
-            limit_val_batches = fast_dev_run
-            limit_test_batches = fast_dev_run
-            limit_predict_batches = fast_dev_run
-            self.trainer.fit_loop.max_steps = fast_dev_run
-            self.trainer.num_sanity_val_steps = 0
-            self.trainer.fit_loop.max_epochs = 1
-            val_check_interval = 1.0
-            self.trainer.check_val_every_n_epoch = 1
-            self.trainer.logger = DummyLogger() if self.trainer.logger is not None else None
-
-            rank_zero_info(
-                "Running in fast_dev_run mode: will run a full train,"
-                f" val, test and prediction loop using {fast_dev_run} batch(es)."
-            )
-
-        self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
-        self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
-        self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
-        self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
-        self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
-        self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
-        self.determine_data_use_amount(self.trainer.overfit_batches)
-
-    def determine_data_use_amount(self, overfit_batches: float) -> None:
-        """Use less data for debugging purposes."""
-        if overfit_batches > 0:
-            self.trainer.limit_train_batches = overfit_batches
-            self.trainer.limit_val_batches = overfit_batches
-            self.trainer.limit_test_batches = overfit_batches
-
-
-def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
-    if 0 <= batches <= 1:
-        return batches
-    if batches > 1 and batches % 1.0 == 0:
-        return int(batches)
-    raise MisconfigurationException(
-        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
-    )
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 74522424c5326..944739e1c3681 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -33,7 +33,7 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -53,7 +53,6 @@
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
-from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
 from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -450,7 +449,6 @@ def __init__(
         )
         self.logger_connector = LoggerConnector(self, log_gpu_memory)
         self._callback_connector = CallbackConnector(self)
-        self.debugging_connector = DebuggingConnector(self)
         self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
@@ -574,7 +572,7 @@ def __init__(
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
-        self.debugging_connector.on_init_start(
+        self._init_debugging_flags(
             limit_train_batches,
             limit_val_batches,
             limit_test_batches,
@@ -587,6 +585,65 @@ def __init__(
         # Callback system
         self.on_init_end()
 
+    def _init_debugging_flags(
+        self,
+        limit_train_batches,
+        limit_val_batches,
+        limit_test_batches,
+        limit_predict_batches,
+        val_check_interval,
+        overfit_batches,
+        fast_dev_run,
+    ):
+        if not isinstance(fast_dev_run, (bool, int)):
+            raise MisconfigurationException(
+                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0"
+            )
+
+        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
+            raise MisconfigurationException(
+                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
+            )
+
+        self.fast_dev_run = fast_dev_run
+        fast_dev_run = int(fast_dev_run)
+
+        # set fast_dev_run=True when it is 1, used while logging
+        if fast_dev_run == 1:
+            self.fast_dev_run = True
+
+        if fast_dev_run:
+            limit_train_batches = fast_dev_run
+            limit_val_batches = fast_dev_run
+            limit_test_batches = fast_dev_run
+            limit_predict_batches = fast_dev_run
+            self.fit_loop.max_steps = fast_dev_run
+            self.num_sanity_val_steps = 0
+            self.fit_loop.max_epochs = 1
+            val_check_interval = 1.0
+            self.check_val_every_n_epoch = 1
+            self.logger = DummyLogger() if self.logger is not None else None
+
+            rank_zero_info(
+                "Running in fast_dev_run mode: will run a full train,"
+                f" val, test and prediction loop using {fast_dev_run} batch(es)."
+            )
+
+        self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
+        self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
+        self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
+        self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
+        self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+        self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
+        self.determine_data_use_amount(self.overfit_batches)
+
+    def determine_data_use_amount(self, overfit_batches: float) -> None:
+        """Use less data for debugging purposes."""
+        if overfit_batches > 0:
+            self.limit_train_batches = overfit_batches
+            self.limit_val_batches = overfit_batches
+            self.limit_test_batches = overfit_batches
+
     def _setup_on_init(self, num_sanity_val_steps: int) -> None:
         self._log_device_info()
 
@@ -2133,3 +2190,13 @@ def terminate_on_nan(self, val: bool) -> None:
             f" Please set `Trainer(detect_anomaly={val})` instead."
         )
         self._terminate_on_nan = val  # : 212
+
+
+def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
+    if 0 <= batches <= 1:
+        return batches
+    if batches > 1 and batches % 1.0 == 0:
+        return int(batches)
+    raise MisconfigurationException(
+        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
+    )
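
A minimal usage sketch of the behaviour this change consolidates into `Trainer._init_debugging_flags` (illustrative only; the `Trainer` flags are real, but the example itself is not part of the diff):

    from pytorch_lightning import Trainer

    # fast_dev_run=True: run a single train/val/test/predict batch, cap training at
    # one epoch, skip sanity validation, and replace any configured logger with a
    # DummyLogger, exactly as in the relocated logic above.
    debug_trainer = Trainer(fast_dev_run=True)

    # Values in [0.0, 1.0] are kept as fractions of the available batches;
    # whole numbers greater than 1 become absolute batch counts.
    partial_trainer = Trainer(limit_train_batches=0.25, limit_val_batches=5)

    # Anything else (e.g. limit_train_batches=1.5) is rejected with a
    # MisconfigurationException by the `_determine_batch_limits` helper shown above.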