Commit cb8a216: Remove debugging_connector.py (#10113)
Parent: 2b24be2

2 files changed: +71 / -97 lines

pytorch_lightning/trainer/connectors/debugging_connector.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

pytorch_lightning/trainer/trainer.py

Lines changed: 71 additions & 4 deletions
@@ -33,7 +33,7 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -53,7 +53,6 @@
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
-from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
 from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -450,7 +449,6 @@ def __init__(
         )
         self.logger_connector = LoggerConnector(self, log_gpu_memory)
         self._callback_connector = CallbackConnector(self)
-        self.debugging_connector = DebuggingConnector(self)
         self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
@@ -574,7 +572,7 @@ def __init__(
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
-        self.debugging_connector.on_init_start(
+        self._init_debugging_flags(
             limit_train_batches,
             limit_val_batches,
             limit_test_batches,
@@ -587,6 +585,65 @@ def __init__(
         # Callback system
         self.on_init_end()
 
+    def _init_debugging_flags(
+        self,
+        limit_train_batches,
+        limit_val_batches,
+        limit_test_batches,
+        limit_predict_batches,
+        val_check_interval,
+        overfit_batches,
+        fast_dev_run,
+    ):
+        if not isinstance(fast_dev_run, (bool, int)):
+            raise MisconfigurationException(
+                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0"
+            )
+
+        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
+            raise MisconfigurationException(
+                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
+            )
+
+        self.fast_dev_run = fast_dev_run
+        fast_dev_run = int(fast_dev_run)
+
+        # set fast_dev_run=True when it is 1, used while logging
+        if fast_dev_run == 1:
+            self.fast_dev_run = True
+
+        if fast_dev_run:
+            limit_train_batches = fast_dev_run
+            limit_val_batches = fast_dev_run
+            limit_test_batches = fast_dev_run
+            limit_predict_batches = fast_dev_run
+            self.fit_loop.max_steps = fast_dev_run
+            self.num_sanity_val_steps = 0
+            self.fit_loop.max_epochs = 1
+            val_check_interval = 1.0
+            self.check_val_every_n_epoch = 1
+            self.logger = DummyLogger() if self.logger is not None else None
+
+            rank_zero_info(
+                "Running in fast_dev_run mode: will run a full train,"
+                f" val, test and prediction loop using {fast_dev_run} batch(es)."
+            )
+
+        self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
+        self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
+        self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
+        self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
+        self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+        self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
+        self.determine_data_use_amount(self.overfit_batches)
+
+    def determine_data_use_amount(self, overfit_batches: float) -> None:
+        """Use less data for debugging purposes."""
+        if overfit_batches > 0:
+            self.limit_train_batches = overfit_batches
+            self.limit_val_batches = overfit_batches
+            self.limit_test_batches = overfit_batches
+
     def _setup_on_init(self, num_sanity_val_steps: int) -> None:
         self._log_device_info()
 
@@ -2133,3 +2190,13 @@ def terminate_on_nan(self, val: bool) -> None:
             f" Please set `Trainer(detect_anomaly={val})` instead."
         )
         self._terminate_on_nan = val  # : 212
+
+
+def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
+    if 0 <= batches <= 1:
+        return batches
+    if batches > 1 and batches % 1.0 == 0:
+        return int(batches)
+    raise MisconfigurationException(
+        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
+    )
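Note for reviewers: inlining the connector does not change user-facing behaviour. `fast_dev_run` still caps every loop, disables the sanity check, forces a single epoch, and swaps the logger for a `DummyLogger`, only now via `Trainer._init_debugging_flags`. A minimal usage sketch, assuming a PyTorch Lightning build from around this commit (the values below are illustrative, not taken from the diff):

from pytorch_lightning import Trainer

# fast_dev_run=3 routes through Trainer._init_debugging_flags above:
# train/val/test/predict are each limited to 3 batches, sanity
# validation is skipped, and max_epochs is forced to 1.
trainer = Trainer(fast_dev_run=3)
assert trainer.fast_dev_run == 3
assert trainer.limit_train_batches == 3
assert trainer.num_sanity_val_steps == 0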
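The new module-level `_determine_batch_limits` helper preserves the old connector's validation semantics: values in [0, 1] are kept as fractions of the dataloader, integer-valued counts above 1 are cast to int, and anything else raises `MisconfigurationException`. A standalone sketch of the same rule (the function name here is hypothetical, not part of the library):

from typing import Union

def normalize_batch_limit(batches: Union[int, float], name: str) -> Union[int, float]:
    """Mirrors the _determine_batch_limits rule added in this commit."""
    if 0 <= batches <= 1:
        # treated as a fraction of the dataloader, e.g. limit_train_batches=0.25
        return batches
    if batches > 1 and batches % 1.0 == 0:
        # treated as an absolute number of batches, e.g. limit_val_batches=10 or 10.0
        return int(batches)
    raise ValueError(f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.")

assert normalize_batch_limit(0.25, "limit_train_batches") == 0.25
assert normalize_batch_limit(10.0, "limit_val_batches") == 10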
