@@ -33,7 +33,7 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -53,7 +53,6 @@
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
-from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
 from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -450,7 +449,6 @@ def __init__(
         )
         self.logger_connector = LoggerConnector(self, log_gpu_memory)
         self._callback_connector = CallbackConnector(self)
-        self.debugging_connector = DebuggingConnector(self)
         self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
@@ -574,7 +572,7 @@ def __init__(
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
-        self.debugging_connector.on_init_start(
+        self._init_debugging_flags(
             limit_train_batches,
             limit_val_batches,
             limit_test_batches,
@@ -587,6 +585,65 @@ def __init__(
         # Callback system
         self.on_init_end()
 
+    def _init_debugging_flags(
+        self,
+        limit_train_batches,
+        limit_val_batches,
+        limit_test_batches,
+        limit_predict_batches,
+        val_check_interval,
+        overfit_batches,
+        fast_dev_run,
+    ):
+        if not isinstance(fast_dev_run, (bool, int)):
+            raise MisconfigurationException(
+                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0"
+            )
+
+        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
+            raise MisconfigurationException(
+                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
+            )
+
+        self.fast_dev_run = fast_dev_run
+        fast_dev_run = int(fast_dev_run)
+
+        # set fast_dev_run=True when it is 1, used while logging
+        if fast_dev_run == 1:
+            self.fast_dev_run = True
+
+        if fast_dev_run:
+            limit_train_batches = fast_dev_run
+            limit_val_batches = fast_dev_run
+            limit_test_batches = fast_dev_run
+            limit_predict_batches = fast_dev_run
+            self.fit_loop.max_steps = fast_dev_run
+            self.num_sanity_val_steps = 0
+            self.fit_loop.max_epochs = 1
+            val_check_interval = 1.0
+            self.check_val_every_n_epoch = 1
+            self.logger = DummyLogger() if self.logger is not None else None
+
+            rank_zero_info(
+                "Running in fast_dev_run mode: will run a full train,"
+                f" val, test and prediction loop using {fast_dev_run} batch(es)."
+            )
+
+        self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
+        self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
+        self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
+        self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
+        self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+        self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
+        self.determine_data_use_amount(self.overfit_batches)
+
+    def determine_data_use_amount(self, overfit_batches: float) -> None:
+        """Use less data for debugging purposes."""
+        if overfit_batches > 0:
+            self.limit_train_batches = overfit_batches
+            self.limit_val_batches = overfit_batches
+            self.limit_test_batches = overfit_batches
+
     def _setup_on_init(self, num_sanity_val_steps: int) -> None:
         self._log_device_info()
 
@@ -2133,3 +2190,13 @@ def terminate_on_nan(self, val: bool) -> None:
             f" Please set `Trainer(detect_anomaly={val})` instead."
         )
         self._terminate_on_nan = val
+
+
+def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
+    if 0 <= batches <= 1:
+        return batches
+    if batches > 1 and batches % 1.0 == 0:
+        return int(batches)
+    raise MisconfigurationException(
+        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
+    )
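
For reference, the new module-level `_determine_batch_limits` helper is what gives `limit_*_batches`, `overfit_batches`, and `val_check_interval` their dual meaning: a float in [0.0, 1.0] is kept as a fraction of the available batches, while a whole number above 1 becomes an exact batch count. A minimal sketch of the resulting user-facing behavior — the keyword arguments are the standard `Trainer` flags, and the expected values follow from the code added above (illustrative, not a test from this PR):

```python
from pytorch_lightning import Trainer

trainer = Trainer(limit_train_batches=0.25)  # run on 25% of the training batches
assert trainer.limit_train_batches == 0.25

trainer = Trainer(limit_train_batches=100)   # run on exactly 100 training batches
assert trainer.limit_train_batches == 100

# overfit_batches > 0 overrides the train/val/test limits
# via determine_data_use_amount()
trainer = Trainer(overfit_batches=10)
assert trainer.limit_train_batches == trainer.limit_val_batches == 10

# anything else (e.g. 1.5) raises MisconfigurationException
```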
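Likewise, `fast_dev_run` now funnels through `Trainer._init_debugging_flags` instead of the removed `DebuggingConnector`. A short sketch of what that means in practice — the assertions mirror the logic in `_init_debugging_flags` above and should be read as illustration, not as a test shipped with this change:

```python
from pytorch_lightning import Trainer

# fast_dev_run=True behaves like 1; a positive int n caps every stage at n batches
trainer = Trainer(fast_dev_run=7)

assert trainer.fast_dev_run == 7
assert trainer.limit_train_batches == 7   # train/val/test/predict all capped at 7
assert trainer.num_sanity_val_steps == 0  # sanity validation is skipped
assert trainer.fit_loop.max_epochs == 1   # a single epoch, with max_steps == 7

# the configured logger is swapped for a no-op DummyLogger, and invalid values
# such as fast_dev_run=-1 or fast_dev_run="yes" raise MisconfigurationException
```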