From b490fe79af33a9442480f00b5421554cfd7c4ccb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 7 Jan 2021 01:22:27 +0530 Subject: [PATCH 01/11] ref and fix call for on_pretrained_routine --- pytorch_lightning/accelerators/accelerator.py | 1 + .../accelerators/cpu_accelerator.py | 8 ++- .../accelerators/ddp2_accelerator.py | 6 +-- .../accelerators/ddp_accelerator.py | 6 +-- .../accelerators/ddp_cpu_spawn_accelerator.py | 6 +-- .../accelerators/ddp_hpc_accelerator.py | 6 +-- .../accelerators/ddp_spawn_accelerator.py | 6 +-- .../accelerators/dp_accelerator.py | 5 +- .../accelerators/gpu_accelerator.py | 6 +-- .../accelerators/horovod_accelerator.py | 8 +-- .../accelerators/tpu_accelerator.py | 6 +-- .../connectors/checkpoint_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 49 ++++++++++++++++++- pytorch_lightning/trainer/training_loop.py | 45 ++--------------- tests/callbacks/test_callbacks.py | 4 +- tests/models/test_hooks.py | 6 +-- 16 files changed, 87 insertions(+), 85 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 77f30219ba8c0..2b0240bd20fff 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -66,6 +66,7 @@ def train_or_test(self): if self.trainer.testing: results = self.trainer.run_test() else: + self.trainer.train_loop.setup_training() results = self.trainer.train() return results diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 25302cabbc70f..edad8a5bfa4c7 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional, Union, Callable +from typing import Any, Callable, Optional, Union import torch @@ -53,10 +53,8 @@ def setup(self, model): self.trainer.model = model def train(self): - model = self.trainer.model - - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(self.trainer.model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 2e864029f8767..d6fbdd972c255 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -26,7 +26,7 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available if HYDRA_AVAILABLE: @@ -200,8 +200,8 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index da9eb2d3ea937..ee9fd644cfa50 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -30,7 +30,7 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE from pytorch_lightning.utilities.distributed import ( all_gather_ddp_if_available, find_free_network_port, @@ -299,9 +299,9 @@ def ddp_train(self, process_idx, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up training routine + # set up trainer self.barrier('ddp_setup') - self.trainer.train_loop.setup_training(model) + self.trainer.setup_trainer(model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 91a6dee484f30..f4a5d4990b24a 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -26,7 +26,7 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE from pytorch_lightning.utilities.distributed import ( all_gather_ddp_if_available, find_free_network_port, @@ -160,8 +160,8 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + 
self.trainer.setup_trainer(model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index b257884e34aef..51365110276a6 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -26,7 +26,7 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available if HYDRA_AVAILABLE: @@ -191,8 +191,8 @@ def ddp_train(self, process_idx, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a49e17fc0b31d..1818b00b79f73 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -27,7 +27,7 @@ from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import ( @@ -175,8 +175,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 834a920b505d9..7527cd45dc114 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -104,9 +104,8 @@ def __init_nvidia_apex(self, model): return model def train(self): - model = self.trainer.model - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(self.trainer.model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index 1310777e0d890..265bbdd821fba 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -59,10 +59,8 @@ def setup(self, model): self.trainer.model = model def train(self): - model = self.trainer.model - - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(self.trainer.model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py 
b/pytorch_lightning/accelerators/horovod_accelerator.py index 5895025673b9a..fd6da290e009e 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import ExitStack -from typing import Any, Optional, Union, Callable +from typing import Any, Callable, Optional, Union import torch from torch.optim.lr_scheduler import _LRScheduler @@ -20,7 +20,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType +from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only if HOROVOD_AVAILABLE: @@ -106,8 +106,8 @@ def train(self): # Synchronization will be performed explicitly following backward() stack.enter_context(optimizer.skip_synchronize()) - # set up training routine - self.trainer.train_loop.setup_training(self.trainer.model) + # set up trainer + self.trainer.setup_trainer(self.trainer.model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 9d1eec5594d82..a88563b2a3745 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -26,11 +26,11 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import ( - TPU_AVAILABLE, move_data_to_device, rank_zero_info, rank_zero_only, rank_zero_warn, + TPU_AVAILABLE, ) from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -134,8 +134,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # setup TPU training self.__setup_tpu_training(model, trainer) - # set up training routine - self.trainer.train_loop.setup_training(model) + # set up trainer + self.trainer.setup_trainer(model) # train or test results = self.train_or_test() diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fc9c70ba46d2e..a01b7645caafa 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -13,8 +13,8 @@ # limitations under the License. import os -from pathlib import Path import re +from pathlib import Path from typing import Optional, Union import torch @@ -64,7 +64,7 @@ def restore_weights(self, model: LightningModule) -> None: rank_zero_info(f'restored hpc model from: {checkpoint_path}') # 2. 
Attempt to restore states from `resume_from_checkpoint` file - elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: + elif self.trainer.resume_from_checkpoint is not None: self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) # wait for all to catch up diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 06717c6333829..ebcb33f35a793 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,9 +15,9 @@ """Trainer to automate the training.""" import os +import warnings from pathlib import Path from typing import Dict, Iterable, List, Optional, Union -import warnings import torch from torch.utils.data import DataLoader @@ -57,7 +57,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -411,6 +411,51 @@ def __init__( # Callback system self.on_init_end() + def setup_trainer(self, model: LightningModule): + """ + Sanity check a few things before starting actual training or testing. + + Args: + model: The model to run sanity test on. + """ + # -------------------------- + # Setup?? + # -------------------------- + ref_model = model + if self.data_parallel: + ref_model = model.module + + # set the ranks and devices + self.accelerator_backend.dist.rank = self.global_rank + self.accelerator_backend.dist.device = ref_model.device + + # give model convenience properties + ref_model.trainer = self + + # set local properties on the model + self.model_connector.copy_trainer_model_properties(ref_model) + + # init amp. Must be done here instead of __init__ to allow ddp to work + if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu: + self.scaler = self.precision_connector.backend.scaler + + # log hyper-parameters + if self.logger is not None: + # save exp to get started (this is where the first experiment logs are written) + self.logger.log_hyperparams(ref_model.hparams_initial) + self.logger.log_graph(ref_model) + self.logger.save() + + # wait for all to join if on distributed + self.accelerator_backend.barrier("setup_trainer") + + # register auto-resubmit when on SLURM + self.slurm_connector.register_slurm_signal_handlers() + + # track model now. + # if cluster resets state, the model will update with the saved weights + self.model = model + def fit( self, model: LightningModule, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0d99b071d4567..b376a0ddc3828 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -125,46 +125,15 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # check that model is configured correctly self.trainer.config_validator.verify_loop_configurations(model) - def setup_training(self, model: LightningModule): - """Sanity check a few things before starting actual training. - - Args: - model: The model to run sanity test on. 
+ def setup_training(self): """ - # -------------------------- - # Setup?? - # -------------------------- + Sanity check a few things before starting actual training. + """ + model = self.trainer.model ref_model = model if self.trainer.data_parallel: ref_model = model.module - # set the ranks and devices - self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank - self.trainer.accelerator_backend.dist.device = ref_model.device - - # give model convenience properties - ref_model.trainer = self.trainer - - # set local properties on the model - self.trainer.model_connector.copy_trainer_model_properties(ref_model) - - # init amp. Must be done here instead of __init__ to allow ddp to work - if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: - self.trainer.scaler = self.trainer.precision_connector.backend.scaler - - # log hyper-parameters - if self.trainer.logger is not None: - # save exp to get started (this is where the first experiment logs are written) - self.trainer.logger.log_hyperparams(ref_model.hparams_initial) - self.trainer.logger.log_graph(ref_model) - self.trainer.logger.save() - - # wait for all to join if on distributed - self.trainer.accelerator_backend.barrier("setup_training") - - # register auto-resubmit when on SLURM - self.trainer.slurm_connector.register_slurm_signal_handlers() - # -------------------------- # Pre-train # -------------------------- @@ -174,13 +143,9 @@ def setup_training(self, model: LightningModule): ref_model.on_pretrain_routine_start() # print model summary - if self.trainer.is_global_zero and not self.trainer.testing: + if self.trainer.is_global_zero: ref_model.summarize(mode=self.trainer.weights_summary) - # track model now. - # if cluster resets state, the model will update with the saved weights - self.trainer.model = model - # restore training state and model weights before hpc is called self.trainer.checkpoint_connector.restore_weights(model) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 070bb4e9f6989..53464aff880c5 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock -from unittest.mock import ANY, MagicMock, call +from unittest.mock import ANY, call, MagicMock from pytorch_lightning import Trainer from tests.base import BoringModel @@ -111,8 +111,6 @@ def test_trainer_callback_system(torch_save): call.on_init_end(trainer), call.setup(trainer, model, 'test'), call.on_fit_start(trainer, model), - call.on_pretrain_routine_start(trainer, model), - call.on_pretrain_routine_end(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index f3af5b745a380..c7de4fc74ba2a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
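To keep the control flow of this first patch easy to follow, here is a minimal, self-contained sketch (stand-in classes only, not the real Lightning objects) of the call order the hunks above produce: each accelerator's train() runs Trainer.setup_trainer(model) for environment and logging setup, and train_or_test() then runs TrainLoop.setup_training() for the pre-train routine before training starts.

class TrainLoop:
    def __init__(self, trainer):
        self.trainer = trainer

    def setup_training(self):
        # pre-train hooks, weight summary, checkpoint restore (per the training_loop.py hunk)
        print("train_loop.setup_training")


class Trainer:
    def __init__(self):
        self.testing = False
        self.model = "model"
        self.train_loop = TrainLoop(self)

    def setup_trainer(self, model):
        # ranks/devices, logger hparams, AMP scaler, SLURM handlers (per the trainer.py hunk)
        print("trainer.setup_trainer")

    def run_test(self):
        return "test results"

    def train(self):
        return "train results"


class Accelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def train(self):
        # accelerators now call setup_trainer before handing off to train_or_test
        self.trainer.setup_trainer(self.trainer.model)
        return self.train_or_test()

    def train_or_test(self):
        if self.trainer.testing:
            return self.trainer.run_test()
        self.trainer.train_loop.setup_training()
        return self.trainer.train()


print(Accelerator(Trainer()).train())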
import inspect +from unittest.mock import MagicMock import pytest import torch -from unittest.mock import MagicMock from pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator -from tests.base import EvalModelTemplate, BoringModel +from tests.base import BoringModel, EvalModelTemplate @pytest.mark.parametrize('max_steps', [1, 2, 3]) @@ -348,8 +348,6 @@ def on_test_model_train(self): expected = [ 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', 'on_test_model_eval', 'on_test_epoch_start', 'on_test_batch_start', From d6b9f2730f522a48b9ad74adc7bb7b053e07c418 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 7 Jan 2021 02:43:20 +0530 Subject: [PATCH 02/11] avoid failing tests --- tests/core/test_datamodules.py | 18 +++++++++--------- tests/models/test_torchscript.py | 6 +++--- .../optimization/test_manual_optimization.py | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index d286bbf3a9de6..64dc25101eae6 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -13,21 +13,21 @@ # limitations under the License. import pickle from argparse import ArgumentParser -from unittest.mock import MagicMock from typing import Optional +from unittest.mock import MagicMock import pytest import torch from torch.utils.data import DataLoader, random_split -from pytorch_lightning import LightningDataModule, Trainer, seed_everything +from pytorch_lightning import LightningDataModule, seed_everything, Trainer +from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.model_utils import is_overridden from tests.base import EvalModelTemplate -from tests.base.datasets import TrialMNIST from tests.base.datamodules import TrialMNISTDataModule +from tests.base.datasets import TrialMNIST from tests.base.develop_utils import reset_seed -from pytorch_lightning.utilities.model_utils import is_overridden -from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator -from pytorch_lightning.callbacks import ModelCheckpoint def test_can_prepare_data(tmpdir): @@ -170,14 +170,14 @@ def test_data_hooks_called_with_stage_kwarg(tmpdir): def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() parser = TrialMNISTDataModule.add_argparse_args(parser) - args = parser.parse_args(['--data_dir', './my_data']) - assert args.data_dir == './my_data' + args = parser.parse_args(['--data_dir', str(tmpdir)]) + assert args.data_dir == str(tmpdir) def test_dm_init_from_argparse_args(tmpdir): parser = ArgumentParser() parser = TrialMNISTDataModule.add_argparse_args(parser) - args = parser.parse_args(['--data_dir', './my_data']) + args = parser.parse_args(['--data_dir', str(tmpdir)]) dm = TrialMNISTDataModule.from_argparse_args(args) dm.prepare_data() dm.setup() diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 3c43b201f52e4..75e1ec7724967 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -18,7 +18,7 @@ from tests.base import BoringModel from tests.base.datamodules import TrialMNISTDataModule -from tests.base.models import ParityModuleRNN, BasicGAN +from tests.base.models import BasicGAN, ParityModuleRNN @pytest.mark.parametrize("modelclass", [ @@ -116,10 +116,10 @@ def test_torchscript_retain_training_state(): ParityModuleRNN, BasicGAN, ]) -def 
test_torchscript_properties(modelclass): +def test_torchscript_properties(tmpdir, modelclass): """ Test that scripted LightningModule has unnecessary methods removed. """ model = modelclass() - model.datamodule = TrialMNISTDataModule() + model.datamodule = TrialMNISTDataModule(tmpdir) script = model.to_torchscript() assert not hasattr(script, "datamodule") assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 33d14e852b285..f80e233a646ec 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -21,7 +21,7 @@ import torch.distributed as torch_distrib import torch.nn.functional as F -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities import APEX_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel @@ -658,11 +658,11 @@ def automatic_optimization(self) -> bool: assert model.called +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_step_with_optimizer_closure(tmpdir): """ Tests that `step` works with optimizer_closure """ - os.environ['PL_DEV_DEBUG'] = '1' class TestModel(BoringModel): @@ -739,11 +739,11 @@ def automatic_optimization(self) -> bool: assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean() +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir): """ Tests that `step` works with optimizer_closure and accumulated_grad """ - os.environ['PL_DEV_DEBUG'] = '1' class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -802,12 +802,12 @@ def automatic_optimization(self) -> bool: assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2 +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @patch("torch.optim.SGD.step") def test_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir): """ Tests that `step` works with optimizer_closure and extra arguments """ - os.environ['PL_DEV_DEBUG'] = '1' class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -859,13 +859,13 @@ def automatic_optimization(self) -> bool: step_mock.assert_has_calls(expected_calls) +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @patch("torch.optim.Adam.step") @patch("torch.optim.SGD.step") def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir): """ Tests that `step` works with optimizer_closure and different accumulated_gradient frequency """ - os.environ['PL_DEV_DEBUG'] = '1' class TestModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx): @@ -939,6 +939,7 @@ def automatic_optimization(self) -> bool: mock_adam_step.assert_has_calls(expected_calls) +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @patch("torch.optim.Adam.step") @patch("torch.optim.SGD.step") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -947,7 +948,6 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_ste """ Tests that `step` works with optimizer_closure and different accumulated_gradient frequency """ - os.environ['PL_DEV_DEBUG'] = '1' class TestModel(BoringModel): From 
f2cbff0fdf494408f922d4ede0716628579830ce Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 00:13:56 +0530 Subject: [PATCH 03/11] unnecessary_call --- pytorch_lightning/trainer/evaluation_loop.py | 3 --- pytorch_lightning/trainer/trainer.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 4b70917c8c43d..cd04dea7e5546 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -123,9 +123,6 @@ def is_using_eval_results(self): return using_eval_result def setup(self, model, max_batches, dataloaders): - # copy properties for forward overrides - self.trainer.model_connector.copy_trainer_model_properties(model) - # bookkeeping self.outputs = [] self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ebcb33f35a793..c1df93e5a4e0e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -429,9 +429,6 @@ def setup_trainer(self, model: LightningModule): self.accelerator_backend.dist.rank = self.global_rank self.accelerator_backend.dist.device = ref_model.device - # give model convenience properties - ref_model.trainer = self - # set local properties on the model self.model_connector.copy_trainer_model_properties(ref_model) From 9753d2f73a355a229bf3a385cc20001bdd22b8b3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 00:15:17 +0530 Subject: [PATCH 04/11] unnecessary call in accelerators --- pytorch_lightning/accelerators/ddp2_accelerator.py | 3 --- pytorch_lightning/accelerators/ddp_accelerator.py | 3 --- pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 3 --- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 3 --- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 3 --- 5 files changed, 15 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index d6fbdd972c255..9a158a3869161 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -186,9 +186,6 @@ def ddp_train(self, process_idx, mp_queue, model): self.ddp_plugin.on_after_setup_optimizers(self.trainer) - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - # 16-bit model = self.trainer.precision_connector.connect(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index ee9fd644cfa50..1781dd5678e9d 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -285,9 +285,6 @@ def ddp_train(self, process_idx, model): # allow for lr schedulers as well self.setup_optimizers(model) - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - # 16-bit model = self.trainer.precision_connector.connect(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index f4a5d4990b24a..20cf833a87a9d 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -146,9 +146,6 @@ def ddp_train(self, process_idx, mp_queue, model): 
self.ddp_plugin.on_after_setup_optimizers(self.trainer) - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - # 16-bit model = self.trainer.precision_connector.connect(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 51365110276a6..65f4d13ef1e5c 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -177,9 +177,6 @@ def ddp_train(self, process_idx, model): self.ddp_plugin.on_after_setup_optimizers(self.trainer) - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - # 16-bit model = self.trainer.precision_connector.connect(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 1818b00b79f73..6e4089b311ee1 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -161,9 +161,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 self.ddp_plugin.on_after_setup_optimizers(self.trainer) - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - # 16-bit model = self.trainer.precision_connector.connect(model) From 6e0d63ac4d62e5a0c89eaf82e1645b6c7ba8f463 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 01:15:23 +0530 Subject: [PATCH 05/11] tmpdir --- tests/trainer/logging_tests/test_eval_loop_logging_1_0.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index da08ffe710e75..53636bed66f56 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -292,7 +292,7 @@ def validation_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, - callbacks=[ModelCheckpoint(dirpath='val_loss')], + callbacks=[ModelCheckpoint(dirpath=tmpdir)], ) trainer.fit(model) From 64a3e7c4b36a2a707f9a83a043e84f43aa47ad70 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 01:16:22 +0530 Subject: [PATCH 06/11] rm test_mode --- .../logger_connector/logger_connector.py | 14 ++--- pytorch_lightning/trainer/evaluation_loop.py | 51 +++++++++---------- pytorch_lightning/trainer/trainer.py | 18 +++---- pytorch_lightning/trainer/training_loop.py | 2 +- pytorch_lightning/utilities/debugging.py | 6 +-- .../test_eval_loop_dict_return.py | 19 +++---- 6 files changed, 52 insertions(+), 58 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 54bf2f9a90cea..e3ff3c1732c02 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
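Patch 06 ("rm test_mode"), whose logger-connector and evaluation-loop hunks follow, stops threading an explicit test_mode argument through every call and instead reads the trainer.testing flag the Trainer already owns. A toy illustration of the same pattern (stand-in classes and hypothetical method bodies, not the real loop):

class Trainer:
    def __init__(self, testing=False):
        self.testing = testing


class EvaluationLoop:
    def __init__(self, trainer):
        self.trainer = trainer

    # before the patch this took an explicit flag:
    #     def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): ...
    def evaluation_step(self, batch, batch_idx):
        hook = "test_step" if self.trainer.testing else "validation_step"
        return hook


loop = EvaluationLoop(Trainer(testing=True))
assert loop.evaluation_step(batch=None, batch_idx=0) == "test_step"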
-from copy import deepcopy import os +from copy import deepcopy from pprint import pprint from typing import Iterable, Union @@ -211,9 +211,9 @@ def add_progress_bar_metrics(self, metrics): self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode): + def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result): self._track_callback_metrics(deprecated_eval_results, using_eval_result) - self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode) + self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results) def evaluation_epoch_end(self, testing): # reset dataloader idx @@ -242,7 +242,7 @@ def prepare_eval_loop_results(self): for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): self.add_to_eval_loop_results(dl_idx, has_been_initialized) - def get_evaluate_epoch_results(self, test_mode): + def get_evaluate_epoch_results(self): if not self.trainer.running_sanity_check: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() @@ -252,7 +252,7 @@ def get_evaluate_epoch_results(self, test_mode): self.prepare_eval_loop_results() # log results of test - if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: + if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): print(f'DATALOADER:{result_idx} TEST RESULTS') @@ -333,7 +333,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric if len(dataloader_result_metrics) > 0: self.eval_loop_results.append(dataloader_result_metrics) - def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode): + def __process_eval_epoch_end_results_and_log_legacy(self, eval_results): if self.trainer.running_sanity_check: return @@ -353,7 +353,7 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod callback_metrics = result.callback_metrics # in testing we don't need the callback metrics - if test_mode: + if self.trainer.testing: callback_metrics = {} else: _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index cd04dea7e5546..8a5732d94a451 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -24,7 +24,6 @@ class EvaluationLoop(object): def __init__(self, trainer): self.trainer = trainer - self.testing = False self.outputs = [] self.step_metrics = [] self.predictions = None @@ -52,7 +51,7 @@ def get_evaluation_dataloaders(self, max_batches): model = self.trainer.get_model() # select dataloaders - if self.testing: + if self.trainer.testing: self.trainer.reset_test_dataloader(model) dataloaders = self.trainer.test_dataloaders @@ -85,34 +84,34 @@ def should_skip_evaluation(self, dataloaders, max_batches): return False def on_evaluation_start(self, *args, **kwargs): - if self.testing: + if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: self.trainer.call_hook('on_validation_start', *args, **kwargs) def on_evaluation_model_eval(self, *args, **kwargs): model_ref = self.trainer.get_model() - if self.testing: + if self.trainer.testing: model_ref.on_test_model_eval() else: model_ref.on_validation_model_eval() def 
on_evaluation_model_train(self, *args, **kwargs): model_ref = self.trainer.get_model() - if self.testing: + if self.trainer.testing: model_ref.on_test_model_train() else: model_ref.on_validation_model_train() def on_evaluation_end(self, *args, **kwargs): - if self.testing: + if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) else: self.trainer.call_hook('on_validation_end', *args, **kwargs) def reload_evaluation_dataloaders(self): model = self.trainer.get_model() - if self.testing: + if self.trainer.testing: self.trainer.reset_test_dataloader(model) else: self.trainer.reset_val_dataloader(model) @@ -135,17 +134,17 @@ def setup(self, model, max_batches, dataloaders): self.num_dataloaders = self._get_num_dataloaders(dataloaders) def on_evaluation_epoch_start(self, *args, **kwargs): - if self.testing: + if self.trainer.testing: self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - def build_args(self, test_mode, batch, batch_idx, dataloader_idx): + def build_args(self, batch, batch_idx, dataloader_idx): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] - multiple_val_loaders = (not test_mode and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1) - multiple_test_loaders = (test_mode and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) + multiple_val_loaders = (not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1) + multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) if multiple_test_loaders or multiple_val_loaders: args.append(dataloader_idx) @@ -160,14 +159,14 @@ def _get_num_dataloaders(self, dataloaders): length = len(dataloaders[0]) return length - def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): + def evaluation_step(self, batch, batch_idx, dataloader_idx): # configure args - args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) + args = self.build_args(batch, batch_idx, dataloader_idx) model_ref = self.trainer.get_model() model_ref._results = Result() # run actual test step - if self.testing: + if self.trainer.testing: model_ref._current_fx_name = "test_step" output = self.trainer.accelerator_backend.test_step(args) else: @@ -189,7 +188,7 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): return output def evaluation_step_end(self, *args, **kwargs): - if self.testing: + if self.trainer.testing: output = self.trainer.call_hook('test_step_end', *args, **kwargs) else: output = self.trainer.call_hook('validation_step_end', *args, **kwargs) @@ -197,7 +196,7 @@ def evaluation_step_end(self, *args, **kwargs): def evaluation_epoch_end(self): # unset dataloder_idx in model - self.trainer.logger_connector.evaluation_epoch_end(self.testing) + self.trainer.logger_connector.evaluation_epoch_end(self.trainer.testing) using_eval_result = self.is_using_eval_results() @@ -213,7 +212,7 @@ def evaluation_epoch_end(self): def log_epoch_metrics_on_evaluation_end(self): # get the final loop results - eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing) + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): @@ -227,7 +226,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): user_reduced = False - if self.testing: + if 
self.trainer.testing: if is_overridden('test_epoch_end', model=model): if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) @@ -247,7 +246,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): self.trainer.logger_connector.cache_logged_metrics() # depre warning if eval_results is not None and user_reduced: - step = 'testing_epoch_end' if self.testing else 'validation_epoch_end' + step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end' self.warning_cache.warn( f'The {step} should not return anything as of 9.1.' ' To log, use self.log(...) or self.write(...) directly in the LightningModule' @@ -260,7 +259,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): eval_results = [eval_results] # track depreceated metrics - self.trainer.logger_connector.track_metrics_deprecated(eval_results, using_eval_result, self.testing) + self.trainer.logger_connector.track_metrics_deprecated(eval_results, using_eval_result) return eval_results @@ -297,15 +296,15 @@ def __auto_reduce_result_objs(self, outputs): def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): # set dataloader_idx to model and track batch_size self.trainer.logger_connector.on_evaluation_batch_start( - self.testing, batch, dataloader_idx, self.num_dataloaders) + self.trainer.testing, batch, dataloader_idx, self.num_dataloaders) - if self.testing: + if self.trainer.testing: self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) else: self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx): - if self.testing: + if self.trainer.testing: self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) else: self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) @@ -316,16 +315,16 @@ def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx): def store_predictions(self, output, batch_idx, dataloader_idx): # Add step predictions to prediction collection to write later if output is not None: - do_write_predictions = isinstance(output, Result) and self.testing + do_write_predictions = isinstance(output, Result) and self.trainer.testing if do_write_predictions: self.predictions.add(output.pop('predictions', None)) # track debug metrics - self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output) + self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook - if self.testing: + if self.trainer.testing: self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) else: self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c1df93e5a4e0e..1668b87bf7719 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -487,10 +487,6 @@ def fit( # hook self.data_connector.prepare_data(model) - # bookkeeping - # we reuse fit in .test() but change its behavior using this flag - self.testing = os.environ.get('PL_TESTING_MODE', self.testing) - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -595,13 +591,13 @@ def train(self): # hook self.train_loop.on_train_end() - def run_evaluation(self, test_mode: bool = False, max_batches=None): + def 
run_evaluation(self, max_batches=None): # used to know if we are logging for val, test + reset cached results - self.logger_connector.set_stage(test_mode, reset=True) + self.logger_connector.set_stage(self.testing, reset=True) # bookkeeping - self.evaluation_loop.testing = test_mode + self.evaluation_loop.testing = self.testing # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) @@ -647,7 +643,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # lightning module methods with self.profiler.profile("evaluation_step_and_end"): - output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) + output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) output = self.evaluation_loop.evaluation_step_end(output) # hook + store predictions @@ -700,7 +696,7 @@ def run_test(self): # only load test dataloader for testing # self.reset_test_dataloader(ref_model) with self.profiler.profile("run_test_evaluation"): - eval_loop_results, _ = self.run_evaluation(test_mode=True) + eval_loop_results, _ = self.run_evaluation() if len(eval_loop_results) == 0: return 1 @@ -731,7 +727,7 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_start() # run eval step - _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches) + _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches) # allow no returns from eval if eval_results is not None and len(eval_results) > 0: @@ -835,11 +831,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # run tests self.tested_ckpt_path = ckpt_path self.testing = True - os.environ['PL_TESTING_MODE'] = '1' self.model = model results = self.fit(model) self.testing = False - del os.environ['PL_TESTING_MODE'] # teardown if self.is_function_implemented('teardown'): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b376a0ddc3828..d93e8a978d75c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -563,7 +563,7 @@ def run_training_epoch(self): # ----------------------------------------- should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: - self.trainer.run_evaluation(test_mode=False) + self.trainer.run_evaluation() # reset stage to train self.trainer.logger_connector.set_stage("train") diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 9264e2a49810d..c9fac5cc04a45 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -16,7 +16,7 @@ import time from collections import Counter from functools import wraps -from typing import Callable, Any, Optional +from typing import Any, Callable, Optional def enabled_only(fn: Callable): @@ -133,7 +133,7 @@ def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, self.saved_lr_scheduler_updates.append(loss_dict) @enabled_only - def track_eval_loss_history(self, test_mode, batch_idx, dataloader_idx, output): + def track_eval_loss_history(self, batch_idx, dataloader_idx, output): loss_dict = { 'sanity_check': self.trainer.running_sanity_check, 'dataloader_idx': dataloader_idx, @@ -142,7 +142,7 @@ def track_eval_loss_history(self, test_mode, batch_idx, dataloader_idx, output): 'output': output } - if test_mode: + if self.trainer.testing: self.saved_test_losses.append(loss_dict) 
else: self.saved_val_losses.append(loss_dict) diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py index 9e2023d27d928..3a9a87f84e5d9 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py @@ -15,8 +15,9 @@ Tests to ensure that the training loop works with a dict """ import os -from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import LightningModule from tests.base.deterministic_model import DeterministicModel @@ -43,7 +44,7 @@ def backward(self, loss, optimizer, optimizer_idx): # out are the results of the full loop # eval_results are output of _evaluate - out, eval_results = trainer.run_evaluation(test_mode=False) + out, eval_results = trainer.run_evaluation() assert len(out) == 1 assert len(eval_results) == 0 @@ -74,7 +75,7 @@ def test_validation_step_scalar_return(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - out, eval_results = trainer.run_evaluation(test_mode=False) + out, eval_results = trainer.run_evaluation() assert len(out) == 1 assert len(eval_results) == 2 assert eval_results[0] == 171 and eval_results[1] == 171 @@ -106,7 +107,7 @@ def test_validation_step_arbitrary_dict_return(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) + callback_metrics, eval_results = trainer.run_evaluation() assert len(callback_metrics) == 1 assert len(eval_results) == 2 assert eval_results[0]['some'] == 171 @@ -144,7 +145,7 @@ def test_validation_step_dict_return(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) + callback_metrics, eval_results = trainer.run_evaluation() assert len(callback_metrics) == 1 assert len(callback_metrics[0]) == 5 assert len(eval_results) == 2 @@ -186,7 +187,7 @@ def test_val_step_step_end_no_return(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) + callback_metrics, eval_results = trainer.run_evaluation() assert len(callback_metrics) == 1 assert len(eval_results) == 0 @@ -218,7 +219,7 @@ def test_val_step_step_end(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) + callback_metrics, eval_results = trainer.run_evaluation() assert len(callback_metrics) == 1 assert len(callback_metrics[0]) == 6 @@ -264,7 +265,7 @@ def test_no_val_step_end(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) + callback_metrics, eval_results = trainer.run_evaluation() assert len(callback_metrics) == 1 assert len(callback_metrics[0]) == 6 assert len(eval_results) == 1 @@ -308,7 +309,7 @@ def test_full_val_loop(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) + callback_metrics, eval_results = trainer.run_evaluation() assert len(callback_metrics) == 1 assert 
len(callback_metrics[0]) == 7 assert len(eval_results) == 1 From 5d39e21e8ec67c5f053b91d9d9bf3c47600139df Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 01:21:12 +0530 Subject: [PATCH 07/11] pep --- pytorch_lightning/trainer/evaluation_loop.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 8a5732d94a451..44806c09c43b4 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -143,8 +143,14 @@ def build_args(self, batch, batch_idx, dataloader_idx): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] - multiple_val_loaders = (not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1) - multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) + multiple_val_loaders = ( + not self.trainer.testing + and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 + ) + multiple_test_loaders = ( + self.trainer.testing + and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1 + ) if multiple_test_loaders or multiple_val_loaders: args.append(dataloader_idx) From 27e35167befb94c827d5842636b25c9528150e77 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 03:17:26 +0530 Subject: [PATCH 08/11] updates --- pytorch_lightning/accelerators/accelerator.py | 4 ++++ pytorch_lightning/accelerators/cpu_accelerator.py | 8 -------- pytorch_lightning/accelerators/ddp2_accelerator.py | 1 - pytorch_lightning/accelerators/ddp_accelerator.py | 1 - .../accelerators/ddp_cpu_spawn_accelerator.py | 1 - pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 1 - pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 1 - pytorch_lightning/accelerators/dp_accelerator.py | 9 --------- pytorch_lightning/accelerators/gpu_accelerator.py | 8 -------- pytorch_lightning/accelerators/horovod_accelerator.py | 1 - pytorch_lightning/accelerators/tpu_accelerator.py | 1 - pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- 12 files changed, 6 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2b0240bd20fff..8bb335f2e7847 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -52,6 +52,10 @@ def __init__(self, def setup(self, model): pass + def train(self): + self.trainer.setup_trainer(self.trainer.model) + return self.train_or_test() + def teardown(self): # Ensure if necessary all processes are finished self.barrier() diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index edad8a5bfa4c7..990ad8af92694 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -52,14 +52,6 @@ def setup(self, model): self.trainer.model = model - def train(self): - # set up trainer - self.trainer.setup_trainer(self.trainer.model) - - # train or test - results = self.train_or_test() - return results - def _step(self, model_step: Callable, args): if self.trainer.amp_backend == AMPType.NATIVE: with torch.cuda.amp.autocast(): diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 9a158a3869161..96a58e76d1bb7 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ 
b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -197,7 +197,6 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up trainer self.trainer.setup_trainer(model) # train or test diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 1781dd5678e9d..579df6b7e0630 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -296,7 +296,6 @@ def ddp_train(self, process_idx, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up trainer self.barrier('ddp_setup') self.trainer.setup_trainer(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 20cf833a87a9d..5d92fc4906c83 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -157,7 +157,6 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up trainer self.trainer.setup_trainer(model) # train or test diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 65f4d13ef1e5c..68488e4e1d6c0 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -188,7 +188,6 @@ def ddp_train(self, process_idx, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up trainer self.trainer.setup_trainer(model) # train or test diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 6e4089b311ee1..3dfbeca88556a 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -172,7 +172,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # allow user to configure ddp model = self.configure_ddp(model, device_ids) - # set up trainer self.trainer.setup_trainer(model) # train or test diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 7527cd45dc114..7b661c1ed01a6 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -103,15 +103,6 @@ def __init_nvidia_apex(self, model): return model - def train(self): - # set up trainer - self.trainer.setup_trainer(self.trainer.model) - - # train or test - results = self.train_or_test() - - return results - def teardown(self): # replace the original fwd function self.trainer.model.forward = self.model_autocast_original_forward diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index 265bbdd821fba..0ea580eb2561f 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -58,14 +58,6 @@ def setup(self, model): self.trainer.model = model - def train(self): - # set up trainer - self.trainer.setup_trainer(self.trainer.model) - - # train or test - results = self.train_or_test() - return results - def _step(self, model_step: Callable, args): args[0] = self.to_device(args[0]) diff --git 
a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index fd6da290e009e..072e594433ec6 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -106,7 +106,6 @@ def train(self): # Synchronization will be performed explicitly following backward() stack.enter_context(optimizer.skip_synchronize()) - # set up trainer self.trainer.setup_trainer(self.trainer.model) # train or test diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index a88563b2a3745..531aeff69e0c4 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -134,7 +134,6 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # setup TPU training self.__setup_tpu_training(model, trainer) - # set up trainer self.trainer.setup_trainer(model) # train or test diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 44806c09c43b4..63f65bead2579 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -139,7 +139,7 @@ def on_evaluation_epoch_start(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - def build_args(self, batch, batch_idx, dataloader_idx): + def _build_args(self, batch, batch_idx, dataloader_idx): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] @@ -167,7 +167,7 @@ def _get_num_dataloaders(self, dataloaders): def evaluation_step(self, batch, batch_idx, dataloader_idx): # configure args - args = self.build_args(batch, batch_idx, dataloader_idx) + args = self._build_args(batch, batch_idx, dataloader_idx) model_ref = self.trainer.get_model() model_ref._results = Result() From 5d9e95f87343a4d9853eb30ca883d1dbfba369c6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 8 Jan 2021 04:02:05 +0530 Subject: [PATCH 09/11] more ref --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp2_accelerator.py | 8 ++------ pytorch_lightning/accelerators/ddp_accelerator.py | 10 +++------- .../accelerators/ddp_cpu_spawn_accelerator.py | 9 ++------- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 8 ++------ .../accelerators/ddp_spawn_accelerator.py | 9 ++------- pytorch_lightning/accelerators/horovod_accelerator.py | 8 +++----- pytorch_lightning/accelerators/tpu_accelerator.py | 9 ++------- pytorch_lightning/trainer/trainer.py | 2 +- 9 files changed, 19 insertions(+), 48 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8bb335f2e7847..7c4e30b3f62a9 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -52,8 +52,8 @@ def __init__(self, def setup(self, model): pass - def train(self): - self.trainer.setup_trainer(self.trainer.model) + def train(self, model: LightningModule): + self.trainer.setup_trainer(model) return self.train_or_test() def teardown(self): diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 96a58e76d1bb7..825f97b52f10b 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -58,8 +58,7 @@ def setup(self, model): self.trainer.model = model 
self.task_idx = self.cluster_environment.local_rank() - def train(self): - model = self.trainer.model + def train(self, model: LightningModule): return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) def training_step(self, args): @@ -197,10 +196,7 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - self.trainer.setup_trainer(model) - - # train or test - results = self.train_or_test() + results = super().train(model) # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 579df6b7e0630..2f3a51f167dec 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -146,10 +146,9 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) - def train(self): - model = self.trainer.model - + def train(self, model: LightningModule): results = self.ddp_train(process_idx=self.task_idx, model=model) + if 'WORLD_SIZE' in os.environ: del os.environ['WORLD_SIZE'] return results @@ -297,10 +296,7 @@ def ddp_train(self, process_idx, model): model = self.configure_ddp(model, device_ids) self.barrier('ddp_setup') - self.trainer.setup_trainer(model) - - # train or test - results = self.train_or_test() + results = super().train(model) # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 5d92fc4906c83..2d7b36333eec2 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -71,9 +71,7 @@ def setup(self, model): self.trainer.model = model - def train(self): - model = self.trainer.model - + def train(self, model: LightningModule): # train in children process mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) @@ -157,10 +155,7 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - self.trainer.setup_trainer(model) - - # train or test - results = self.train_or_test() + results = super().train(model) # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 68488e4e1d6c0..2f39080ae1331 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -59,8 +59,7 @@ def setup(self, model): self.trainer.model = model self.task_idx = self.cluster_environment.local_rank() - def train(self): - model = self.trainer.model + def train(self, model: LightningModule): self.ddp_train(process_idx=self.task_idx, model=model) def set_world_ranks(self, process_idx): @@ -188,10 +187,7 @@ def ddp_train(self, process_idx, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - self.trainer.setup_trainer(model) - - # train or test - results = self.train_or_test() + results = super().train(model) # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 3dfbeca88556a..372af24dcb17e 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ 
b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -75,9 +75,7 @@ def setup(self, model): self.trainer.model = model - def train(self): - model = self.trainer.model - + def train(self, model: LightningModule): # train in children process mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) @@ -172,10 +170,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # allow user to configure ddp model = self.configure_ddp(model, device_ids) - self.trainer.setup_trainer(model) - - # train or test - results = self.train_or_test() + results = super().train(model) # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 072e594433ec6..9c5a09f52a1c7 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -20,6 +20,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.cluster_environments import ClusterEnvironment +from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only @@ -100,16 +101,13 @@ def _filter_named_parameters(model, optimizer): self.trainer.model = model - def train(self): + def train(self, model: LightningModule): with ExitStack() as stack: for optimizer in self.trainer.optimizers: # Synchronization will be performed explicitly following backward() stack.enter_context(optimizer.skip_synchronize()) - self.trainer.setup_trainer(self.trainer.model) - - # train or test - results = self.train_or_test() + results = super().train(model) # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 531aeff69e0c4..35d9ebe36d3c6 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -99,9 +99,7 @@ def teardown(self): self.__load_weights_on_main_process() return results - def train(self): - model = self.trainer.model - + def train(self, model: LightningModule): # train if self.trainer.tpu_id is not None: self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue) @@ -134,10 +132,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # setup TPU training self.__setup_tpu_training(model, trainer) - self.trainer.setup_trainer(model) - - # train or test - results = self.train_or_test() + results = super().train(model) # save weights at the end of training self.__save_end_of_training_weights(model, trainer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1668b87bf7719..e7e4f9c044a5a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -507,7 +507,7 @@ def fit( # hook self.call_hook('on_fit_start') - results = self.accelerator_backend.train() + results = self.accelerator_backend.train(self.model) self.accelerator_backend.teardown() # ---------------------------- From b8554a1d715f19b36e71d5aee68caf29e4d80db3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 9 Jan 2021 00:04:11 +0530 Subject: [PATCH 10/11] Revert "more ref" This reverts commit 
5d9e95f87343a4d9853eb30ca883d1dbfba369c6. --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp2_accelerator.py | 8 ++++++-- pytorch_lightning/accelerators/ddp_accelerator.py | 10 +++++++--- .../accelerators/ddp_cpu_spawn_accelerator.py | 9 +++++++-- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 8 ++++++-- .../accelerators/ddp_spawn_accelerator.py | 9 +++++++-- pytorch_lightning/accelerators/horovod_accelerator.py | 8 +++++--- pytorch_lightning/accelerators/tpu_accelerator.py | 9 +++++++-- pytorch_lightning/trainer/trainer.py | 2 +- 9 files changed, 48 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 7c4e30b3f62a9..8bb335f2e7847 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -52,8 +52,8 @@ def __init__(self, def setup(self, model): pass - def train(self, model: LightningModule): - self.trainer.setup_trainer(model) + def train(self): + self.trainer.setup_trainer(self.trainer.model) return self.train_or_test() def teardown(self): diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 040853644f03c..e3e23d2ece7f2 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -58,7 +58,8 @@ def setup(self, model): self.trainer.model = model self.task_idx = self.cluster_environment.local_rank() - def train(self, model: LightningModule): + def train(self): + model = self.trainer.model return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) def training_step(self, args): @@ -196,7 +197,10 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - results = super().train(model) + self.trainer.setup_trainer(model) + + # train or test + results = self.train_or_test() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index b61d8f3d34109..4f0313faec664 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -146,9 +146,10 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) - def train(self, model: LightningModule): - results = self.ddp_train(process_idx=self.task_idx, model=model) + def train(self): + model = self.trainer.model + results = self.ddp_train(process_idx=self.task_idx, model=model) if 'WORLD_SIZE' in os.environ: del os.environ['WORLD_SIZE'] return results @@ -296,7 +297,10 @@ def ddp_train(self, process_idx, model): model = self.configure_ddp(model, device_ids) self.barrier('ddp_setup') - results = super().train(model) + self.trainer.setup_trainer(model) + + # train or test + results = self.train_or_test() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index dcc953c7de5ca..93228e7228376 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -71,7 +71,9 @@ def setup(self, model): self.trainer.model = model - def train(self, model: LightningModule): + def train(self): + model = self.trainer.model + # train in children process 
mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) @@ -155,7 +157,10 @@ def ddp_train(self, process_idx, mp_queue, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - results = super().train(model) + self.trainer.setup_trainer(model) + + # train or test + results = self.train_or_test() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index ea30fafc309ff..7c05cddac0d2f 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -59,7 +59,8 @@ def setup(self, model): self.trainer.model = model self.task_idx = self.cluster_environment.local_rank() - def train(self, model: LightningModule): + def train(self): + model = self.trainer.model self.ddp_train(process_idx=self.task_idx, model=model) def set_world_ranks(self, process_idx): @@ -187,7 +188,10 @@ def ddp_train(self, process_idx, model): # allow user to configure ddp model = self.configure_ddp(model, device_ids) - results = super().train(model) + self.trainer.setup_trainer(model) + + # train or test + results = self.train_or_test() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index b54c5b90c7e95..ce519bd644e42 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -75,7 +75,9 @@ def setup(self, model): self.trainer.model = model - def train(self, model: LightningModule): + def train(self): + model = self.trainer.model + # train in children process mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) @@ -170,7 +172,10 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # allow user to configure ddp model = self.configure_ddp(model, device_ids) - results = super().train(model) + self.trainer.setup_trainer(model) + + # train or test + results = self.train_or_test() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 9c5a09f52a1c7..072e594433ec6 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -20,7 +20,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only @@ -101,13 +100,16 @@ def _filter_named_parameters(model, optimizer): self.trainer.model = model - def train(self, model: LightningModule): + def train(self): with ExitStack() as stack: for optimizer in self.trainer.optimizers: # Synchronization will be performed explicitly following backward() stack.enter_context(optimizer.skip_synchronize()) - results = super().train(model) + self.trainer.setup_trainer(self.trainer.model) + + # train or test + results = self.train_or_test() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py 
b/pytorch_lightning/accelerators/tpu_accelerator.py index 35d9ebe36d3c6..531aeff69e0c4 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -99,7 +99,9 @@ def teardown(self): self.__load_weights_on_main_process() return results - def train(self, model: LightningModule): + def train(self): + model = self.trainer.model + # train if self.trainer.tpu_id is not None: self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue) @@ -132,7 +134,10 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # setup TPU training self.__setup_tpu_training(model, trainer) - results = super().train(model) + self.trainer.setup_trainer(model) + + # train or test + results = self.train_or_test() # save weights at the end of training self.__save_end_of_training_weights(model, trainer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e7e4f9c044a5a..1668b87bf7719 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -507,7 +507,7 @@ def fit( # hook self.call_hook('on_fit_start') - results = self.accelerator_backend.train(self.model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() # ---------------------------- From 469fd00ae7da646b3947c2f72ae57d5417bae54b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 9 Jan 2021 01:16:53 +0530 Subject: [PATCH 11/11] more refac --- .../trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 6 ++---- pytorch_lightning/trainer/training_loop.py | 9 +++------ 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index a01b7645caafa..03d46132fb177 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -44,7 +44,7 @@ def __init__(self, trainer): # used to validate checkpointing logic self.has_trained = False - def restore_weights(self, model: LightningModule) -> None: + def restore_weights(self) -> None: """ Attempt to restore a checkpoint (e.g. weights) in this priority: 1. from HPC weights diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1668b87bf7719..8f263a12c2b2d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -421,16 +421,14 @@ def setup_trainer(self, model: LightningModule): # -------------------------- # Setup?? # -------------------------- - ref_model = model - if self.data_parallel: - ref_model = model.module + ref_model = self.get_model() # set the ranks and devices self.accelerator_backend.dist.rank = self.global_rank self.accelerator_backend.dist.device = ref_model.device # set local properties on the model - self.model_connector.copy_trainer_model_properties(ref_model) + self.model_connector.copy_trainer_model_properties(model) # init amp. 
Must be done here instead of __init__ to allow ddp to work if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b2b44f9628dac..28d3bfa156e6d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -128,14 +128,11 @@ def setup_training(self): """ Sanity check a few things before starting actual training. """ - model = self.trainer.model - ref_model = model - if self.trainer.data_parallel: - ref_model = model.module - # -------------------------- # Pre-train # -------------------------- + ref_model = self.trainer.get_model() + # on pretrain routine start self.trainer.on_pretrain_routine_start(ref_model) if self.trainer.is_function_implemented("on_pretrain_routine_start"): @@ -146,7 +143,7 @@ def setup_training(self): ref_model.summarize(mode=self.trainer.weights_summary) # restore training state and model weights before hpc is called - self.trainer.checkpoint_connector.restore_weights(model) + self.trainer.checkpoint_connector.restore_weights() # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model)
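
Note on the end state of this series: after the revert in PATCH 10 and the cleanup in PATCH 11, every accelerator funnels through Trainer.setup_trainer() and Accelerator.train_or_test(), and the pre-train routine always works on the module returned by trainer.get_model(). The sketch below condenses that call order from the hunks above. It is not the actual library code: the DDP/TPU/Horovod train() overrides, AMP/apex initialisation, rank/device wiring, logging and most hooks are omitted, the guards around the model summary are dropped, and the class name "TrainLoop" is an abbreviation for the class in training_loop.py.

    class Accelerator:
        def train(self):
            # every accelerator sets up the trainer right before entering the loop
            self.trainer.setup_trainer(self.trainer.model)
            return self.train_or_test()

        def train_or_test(self):
            if self.trainer.testing:
                return self.trainer.run_test()
            # the pre-train sanity checks now live in the train loop,
            # not in the individual accelerators
            self.trainer.train_loop.setup_training()
            return self.trainer.train()


    class TrainLoop:
        def setup_training(self):
            # the unwrapped module always comes from get_model(), so the
            # explicit data_parallel / model.module special case is gone
            ref_model = self.trainer.get_model()

            self.trainer.on_pretrain_routine_start(ref_model)
            ref_model.summarize(mode=self.trainer.weights_summary)  # guards omitted

            # restore_weights() no longer takes the model as an argument
            self.trainer.checkpoint_connector.restore_weights()

            self.trainer.on_pretrain_routine_end(ref_model)

In short, setup_trainer() prepares trainer/model state, setup_training() runs the pretrain routine and checkpoint restore on the reference model, and the per-accelerator duplication of this sequence is gone.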