93 changes: 0 additions & 93 deletions pytorch_lightning/trainer/connectors/debugging_connector.py

This file was deleted.

75 changes: 71 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -33,7 +33,7 @@
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -53,7 +53,6 @@
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -450,7 +449,6 @@ def __init__(
)
self.logger_connector = LoggerConnector(self, log_gpu_memory)
self._callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
self.signal_connector = SignalConnector(self)
self.tuner = Tuner(self)
@@ -574,7 +572,7 @@ def __init__(
self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

# init debugging flags
self.debugging_connector.on_init_start(
self._init_debugging_flags(
limit_train_batches,
limit_val_batches,
limit_test_batches,
@@ -587,6 +585,65 @@
# Callback system
self.on_init_end()

def _init_debugging_flags(
self,
limit_train_batches,
limit_val_batches,
limit_test_batches,
limit_predict_batches,
val_check_interval,
overfit_batches,
fast_dev_run,
):
if not isinstance(fast_dev_run, (bool, int)):
raise MisconfigurationException(
f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0"
)

if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
raise MisconfigurationException(
f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
)

self.fast_dev_run = fast_dev_run
fast_dev_run = int(fast_dev_run)

# set fast_dev_run=True when it is 1, used while logging
if fast_dev_run == 1:
self.fast_dev_run = True

if fast_dev_run:
limit_train_batches = fast_dev_run
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
limit_predict_batches = fast_dev_run
self.fit_loop.max_steps = fast_dev_run
self.num_sanity_val_steps = 0
self.fit_loop.max_epochs = 1
val_check_interval = 1.0
self.check_val_every_n_epoch = 1
self.logger = DummyLogger() if self.logger is not None else None

rank_zero_info(
"Running in fast_dev_run mode: will run a full train,"
f" val, test and prediction loop using {fast_dev_run} batch(es)."
)

self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
self.determine_data_use_amount(self.overfit_batches)

def determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes."""
if overfit_batches > 0:
self.limit_train_batches = overfit_batches
self.limit_val_batches = overfit_batches
self.limit_test_batches = overfit_batches

def _setup_on_init(self, num_sanity_val_steps: int) -> None:
self._log_device_info()

@@ -2133,3 +2190,13 @@ def terminate_on_nan(self, val: bool) -> None:
f" Please set `Trainer(detect_anomaly={val})` instead."
)
self._terminate_on_nan = val


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
if batches > 1 and batches % 1.0 == 0:
return int(batches)
raise MisconfigurationException(
f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
)
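
A minimal usage sketch (illustrative only, not part of the PR) of how the inlined flag handling behaves after this change; the Trainer keyword arguments are the library's existing ones, and the assertions simply mirror the logic in _init_debugging_flags and _determine_batch_limits above:

from pytorch_lightning import Trainer

# fast_dev_run=True is coerced to 1 batch: every loop limit becomes 1,
# sanity checking is skipped, training runs for a single epoch, and any
# configured logger is swapped for DummyLogger.
trainer = Trainer(fast_dev_run=True)
assert trainer.limit_train_batches == 1
assert trainer.num_sanity_val_steps == 0
assert trainer.fit_loop.max_epochs == 1

# _determine_batch_limits keeps fractions in [0.0, 1.0] as-is and converts
# whole numbers greater than 1 to ints; anything else raises
# MisconfigurationException.
trainer = Trainer(limit_train_batches=0.25, limit_val_batches=5.0)
assert trainer.limit_train_batches == 0.25
assert trainer.limit_val_batches == 5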