From 0e9b03e7598a3fb9b9d4c954b5b56b58eed2f4fb Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 19 Jan 2021 14:38:18 +0000 Subject: [PATCH 01/11] fix toggle_optimizer --- pytorch_lightning/core/lightning.py | 33 ++++++++-- pytorch_lightning/trainer/training_loop.py | 4 ++ tests/core/test_lightning_module.py | 77 +++++++++++++++++++++- 3 files changed, 108 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f750c8aff7caf..bc453e1d712a5 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1170,12 +1170,35 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): optimizer: optimizer_idx: """ - for param in self.parameters(): - param.requires_grad = False + param_requires_grad_state = {} + for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): + if optimizer_idx != opt_idx: + for group in opt.param_groups: + for param in group['params']: + param_requires_grad_state[param] = param.requires_grad + param.requires_grad = False + + self._param_requires_grad_state = param_requires_grad_state + + def untoggle_optimizer(self, optimizer_idx: int): + """ + Makes sure only the gradients of the current optimizer's parameters are calculated + in the training step to prevent dangling gradients in multiple-optimizer setup. - for group in optimizer.param_groups: - for param in group['params']: - param.requires_grad = True + .. note:: Only called when using multiple optimizers + + Override for your own behavior + + Args: + optimizer_idx: + """ + for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): + if optimizer_idx != opt_idx: + for group in opt.param_groups: + for param in group['params']: + param.requires_grad = self._param_requires_grad_state[param] + # save memory + del self._param_requires_grad_state def optimizer_step( self, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 47e254606af93..0925bc78a9533 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -798,6 +798,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(result.loss) + if len(self.trainer.optimizers) > 1: + # revert back to previous state + self.trainer.get_model().untoggle_optimizer(opt_idx) + return result def backward(self, result, optimizer, opt_idx, *args, **kwargs): diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 64b68245ba66e..ea98a1a5d943c 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from argparse import ArgumentParser import pickle +from argparse import ArgumentParser from typing import Optional from unittest.mock import MagicMock, patch import pytest import torch +from torch import nn from torch.optim import Adam, SGD from torch.utils.data import DataLoader, random_split @@ -139,3 +140,77 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos ) trainer.fit(model) + + +def test_toggle_untoggle(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer_1 = nn.Sequential( + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + ) + + self.layer_2 = nn.Sequential( + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 2) + ) + + # set some weights to False to check untoggle works as expected. + self.layer_1[2].weight.requires_grad = False + self.layer_1[4].weight.requires_grad = False + + self.layer_2[1].weight.requires_grad = False + self.layer_2[3].weight.requires_grad = False + + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def configure_optimizers(self): + optimizer = SGD(self.layer_1.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + if optimizer_idx == 0: + assert self.layer_1[0].weight.requires_grad is True + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is False + optimizer.step(closure=closure) + + if optimizer_idx == 1: + assert self.layer_1[0].weight.requires_grad is False + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is True + optimizer.step(closure=closure) + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + accumulate_grad_batches=1, + ) + + trainer.fit(model) From 400530d481eb657a10d40b2849a4ec6795b4fdd1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 19 Jan 2021 14:41:54 +0000 Subject: [PATCH 02/11] update doc --- pytorch_lightning/core/lightning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bc453e1d712a5..d435a40d0f116 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1166,6 +1166,8 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): Override for your own behavior + It works with `untoggle_optimizer` to make sure param_requires_grad_state is properly reset. + Args: optimizer: optimizer_idx: @@ -1182,9 +1184,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): def untoggle_optimizer(self, optimizer_idx: int): """ - Makes sure only the gradients of the current optimizer's parameters are calculated - in the training step to prevent dangling gradients in multiple-optimizer setup. 
- .. note:: Only called when using multiple optimizers Override for your own behavior From b25921a43643552e8e4a0a737a71b3532ea64e8d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 19 Jan 2021 16:36:46 +0000 Subject: [PATCH 03/11] resolve bug --- pytorch_lightning/core/lightning.py | 8 +++++++- pytorch_lightning/trainer/training_loop.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d435a40d0f116..679b6f5c8f903 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1180,6 +1180,11 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): param_requires_grad_state[param] = param.requires_grad param.requires_grad = False + for group in optimizer.param_groups: + for param in group['params']: + if param in param_requires_grad_state: + param.requires_grad = param_requires_grad_state[param] + self._param_requires_grad_state = param_requires_grad_state def untoggle_optimizer(self, optimizer_idx: int): @@ -1195,7 +1200,8 @@ def untoggle_optimizer(self, optimizer_idx: int): if optimizer_idx != opt_idx: for group in opt.param_groups: for param in group['params']: - param.requires_grad = self._param_requires_grad_state[param] + if param in self._param_requires_grad_state: + param.requires_grad = self._param_requires_grad_state[param] # save memory del self._param_requires_grad_state diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0925bc78a9533..0de43fcf4755e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -26,6 +26,7 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import AMPType, parsing, TPU_AVAILABLE +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -945,6 +946,13 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # use to track metrics internally self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + # set requieres_grad to True for all tensors + def _convert(value): + value.requieres_grad = True + return value + + split_batch = apply_to_collection(split_batch, torch.Tensor, _convert) + def update_running_loss(self): accumulated_loss = self.accumulated_loss.mean() From 8252877aa8b4bb0a00d85f8a7232e8fc6b3b154c Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 19 Jan 2021 16:38:57 +0000 Subject: [PATCH 04/11] update --- pytorch_lightning/trainer/training_loop.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0de43fcf4755e..0925bc78a9533 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -26,7 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import AMPType, parsing, TPU_AVAILABLE -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import 
rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -946,13 +945,6 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # use to track metrics internally self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) - # set requieres_grad to True for all tensors - def _convert(value): - value.requieres_grad = True - return value - - split_batch = apply_to_collection(split_batch, torch.Tensor, _convert) - def update_running_loss(self): accumulated_loss = self.accumulated_loss.mean() From 64314fa590480f4373c2e455578946e554ef041a Mon Sep 17 00:00:00 2001 From: chaton Date: Tue, 19 Jan 2021 19:32:44 +0000 Subject: [PATCH 05/11] Update pytorch_lightning/core/lightning.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 679b6f5c8f903..6a63f7015712a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1166,7 +1166,7 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): Override for your own behavior - It works with `untoggle_optimizer` to make sure param_requires_grad_state is properly reset. + It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset. Args: optimizer: From be15bf15a9f731b907c39fd11aff896ce213a291 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 10:37:23 +0000 Subject: [PATCH 06/11] update on comments --- pytorch_lightning/core/lightning.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6a63f7015712a..ba476c997be65 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1169,22 +1169,25 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset. Args: - optimizer: - optimizer_idx: + optimizer: Current optimizer used in training_loop + optimizer_idx: Current optimizer idx in training_loop """ param_requires_grad_state = {} - for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): - if optimizer_idx != opt_idx: - for group in opt.param_groups: - for param in group['params']: + # make sure current optimizer is latest to be iterated over. 
+ optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer] + num_optimizers = len(optimizers) - 1 + for opt_idx, opt in enumerate(optimizers): + for group in opt.param_groups: + for param in group['params']: + if num_optimizers == opt_idx: + # If a param appears in 2 optimizer, revert `requires_grad_state` to before toggle + if param in param_requires_grad_state: + param.requires_grad = param_requires_grad_state[param] + else: + # save requires_grad for later restoration param_requires_grad_state[param] = param.requires_grad param.requires_grad = False - for group in optimizer.param_groups: - for param in group['params']: - if param in param_requires_grad_state: - param.requires_grad = param_requires_grad_state[param] - self._param_requires_grad_state = param_requires_grad_state def untoggle_optimizer(self, optimizer_idx: int): @@ -1194,7 +1197,7 @@ def untoggle_optimizer(self, optimizer_idx: int): Override for your own behavior Args: - optimizer_idx: + optimizer_idx: Current optimizer idx in training_loop """ for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): if optimizer_idx != opt_idx: From c195e4cc9a97f962571be69f09b413a728dfe053 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 19:55:09 +0000 Subject: [PATCH 07/11] update on comments --- pytorch_lightning/core/lightning.py | 2 +- tests/core/test_lightning_module.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ba476c997be65..f3a52eea9539c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1180,7 +1180,7 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): for group in opt.param_groups: for param in group['params']: if num_optimizers == opt_idx: - # If a param appears in 2 optimizer, revert `requires_grad_state` to before toggle + # If a param appears in 2 optimizers, revert `requires_grad` to before toggle. 
if param in param_requires_grad_state: param.requires_grad = param_requires_grad_state[param] else: diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index ea98a1a5d943c..48b7ff63143c1 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -172,11 +172,6 @@ def __init__(self): self.layer_2[1].weight.requires_grad = False self.layer_2[3].weight.requires_grad = False - def training_step(self, batch, batch_idx, optimizer_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"loss": loss} - def configure_optimizers(self): optimizer = SGD(self.layer_1.parameters(), lr=0.1) optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1) @@ -191,7 +186,6 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos assert self.layer_2[1].weight.requires_grad is False assert self.layer_2[3].weight.requires_grad is False assert self.layer_2[5].weight.requires_grad is False - optimizer.step(closure=closure) if optimizer_idx == 1: assert self.layer_1[0].weight.requires_grad is False @@ -201,7 +195,7 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos assert self.layer_2[1].weight.requires_grad is False assert self.layer_2[3].weight.requires_grad is False assert self.layer_2[5].weight.requires_grad is True - optimizer.step(closure=closure) + optimizer.step(closure=closure) model = TestModel() model.training_epoch_end = None From beb42741cf2f47863fd158df0fffa6872dd45302 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 20:33:58 +0000 Subject: [PATCH 08/11] update --- tests/core/test_lightning_module.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 48b7ff63143c1..6c4416da380e0 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -146,6 +146,9 @@ def test_toggle_untoggle(tmpdir): class TestModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx=None): + return super().training_step(batch, batch_idx) + def __init__(self): super().__init__() self.layer_1 = nn.Sequential( From eda5e448492b14f06fe7b32c1107e77c6b30a738 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 21 Jan 2021 08:14:56 +0000 Subject: [PATCH 09/11] update changelog --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fd70e3583c01..64dad1925dc0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,12 +47,23 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed - Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195)) + + - Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417)) + + - Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406)) + + - Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743)) + + - Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505)) +- Fixed `toggle_optimizer` to reset `requieres_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574)) + + ## [1.1.3] - 2021-01-05 ### Added From 20340939262b7f33cfb1f5875ad7b10c81676832 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 21 Jan 2021 08:15:34 +0000 Subject: [PATCH 10/11] update changelog --- CHANGELOG.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64dad1925dc0a..877302bea1f16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,11 +29,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) + + - Fixed logging `on_train_batch_end` in a callback with multiple optimizers ([#5521](https://github.com/PyTorchLightning/pytorch-lightning/pull/5521)) + + - Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519)) + + - Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540)) +- Fixed `toggle_optimizer` to reset `requieres_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574)) + + ## [1.1.4] - 2021-01-12 ### Added @@ -60,10 +69,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505)) - -- Fixed `toggle_optimizer` to reset `requieres_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574)) - - ## [1.1.3] - 2021-01-05 ### Added From 1d36cc9131b92fc589c82b07fe4181045130a97e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 22 Jan 2021 09:25:29 +0000 Subject: [PATCH 11/11] update changelog --- CHANGELOG.md | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 877302bea1f16..788807cab4345 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `toggle_optimizer` to reset `requieres_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574)) + @@ -29,20 +31,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed - Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) - - - Fixed logging `on_train_batch_end` in a callback with multiple optimizers ([#5521](https://github.com/PyTorchLightning/pytorch-lightning/pull/5521)) - - - Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519)) - - - Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540)) -- Fixed `toggle_optimizer` to reset `requieres_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574)) - - ## [1.1.4] - 2021-01-12 ### Added @@ -56,17 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195)) - - - Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417)) - - - Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406)) - - - Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743)) - - - Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505)) ## [1.1.3] - 2021-01-05
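
Taken together, the patches above change `toggle_optimizer` from unconditionally setting `requires_grad = False` on every parameter (and back to `True` for the current optimizer's parameters) to recording each parameter's existing flag and restoring it in the new `untoggle_optimizer` hook. Below is a minimal, framework-free sketch of that save/restore behaviour using plain `torch.optim` objects; the `Toggler` class and its method names are illustrative stand-ins, not part of the Lightning API.

```python
from torch import nn
from torch.optim import SGD, Adam


class Toggler:
    """Saves and restores per-parameter requires_grad around one optimizer's turn."""

    def __init__(self, optimizers):
        self.optimizers = list(optimizers)
        self._saved = {}

    def toggle(self, optimizer_idx):
        current = self.optimizers[optimizer_idx]
        saved = {}
        # Visit the *other* optimizers first: remember each parameter's state,
        # then freeze it. Visiting the current optimizer last lets parameters
        # shared between optimizers fall back to their pre-toggle state.
        for opt in (o for o in self.optimizers if o is not current):
            for group in opt.param_groups:
                for param in group["params"]:
                    saved[param] = param.requires_grad
                    param.requires_grad = False
        for group in current.param_groups:
            for param in group["params"]:
                if param in saved:  # parameter shared with another optimizer
                    param.requires_grad = saved[param]
        self._saved = saved

    def untoggle(self):
        # Restore every saved flag, then drop the bookkeeping to save memory.
        for param, state in self._saved.items():
            param.requires_grad = state
        self._saved = {}


# A layer frozen by hand stays frozen through a full toggle/untoggle cycle --
# the guarantee the old "set everything False, then True for the current
# optimizer" implementation could not give.
gen, disc = nn.Linear(4, 4), nn.Linear(4, 1)
disc.bias.requires_grad = False  # deliberately frozen by the user
toggler = Toggler([SGD(gen.parameters(), lr=0.1), Adam(disc.parameters(), lr=0.1)])

toggler.toggle(0)    # generator's turn: the discriminator's params are frozen
assert disc.weight.requires_grad is False and gen.weight.requires_grad is True
toggler.untoggle()   # the hand-frozen bias comes back as False, not True
assert disc.weight.requires_grad is True and disc.bias.requires_grad is False
```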
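
As for where the new hook fires: the `training_loop.py` change calls `untoggle_optimizer(opt_idx)` right after the loss and backward pass inside `training_step_and_backward`, and only when more than one optimizer is configured (`toggle_optimizer` is already invoked before the step in that case). The following is a schematic of that ordering with the surrounding Trainer machinery stripped away; the `run_batch` helper and its signature are invented for illustration.

```python
def run_batch(module, batch, batch_idx, optimizers):
    """Schematic multi-optimizer batch: toggle -> forward/backward -> untoggle."""
    for opt_idx, optimizer in enumerate(optimizers):
        if len(optimizers) > 1:
            # freeze parameters that belong only to the other optimizers
            module.toggle_optimizer(optimizer, opt_idx)

        loss = module.training_step(batch, batch_idx, optimizer_idx=opt_idx)["loss"]
        loss.backward()

        if len(optimizers) > 1:
            # restore the saved requires_grad flags before the next optimizer's turn
            module.untoggle_optimizer(opt_idx)

        optimizer.step()
        optimizer.zero_grad()
```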