From 9a37a62609df07fda3fbb34b36731393a4d21617 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 6 Feb 2021 13:47:59 +0000 Subject: [PATCH 1/3] resolve manual_optimization --- pytorch_lightning/accelerators/accelerator.py | 39 ++++++--------- pytorch_lightning/accelerators/tpu.py | 21 ++------- pytorch_lightning/plugins/base_plugin.py | 13 ++--- .../plugins/precision/native_amp.py | 47 +++++++++++-------- .../plugins/precision/precision_plugin.py | 18 +++++-- .../plugins/training_type/ddp.py | 2 +- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/horovod.py | 2 +- .../training_type/training_type_plugin.py | 12 ++--- .../optimization/test_manual_optimization.py | 14 +++--- 10 files changed, 78 insertions(+), 92 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 7377b89d7b5c4..a8e63776f93d8 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -15,6 +15,7 @@ import torch from torch.optim import Optimizer +from torch.utils.data import DataLoader from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ( @@ -228,8 +229,8 @@ def predict(self, args): return self.training_type_plugin.predict(*args) def process_dataloader( - self, dataloader: Union[Iterable, torch.utils.data.DataLoader] - ) -> Union[Iterable, torch.utils.data.DataLoader]: + self, dataloader: Union[Iterable, DataLoader] + ) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary Args: @@ -240,7 +241,7 @@ def process_dataloader( def backward( self, closure_loss: torch.Tensor, - optimizer: torch.optim.Optimizer, + optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args, @@ -254,17 +255,17 @@ def backward( opt_idx: the index of the optimizer should_accumulate: whether to accumulate gradients """ - self.training_type_plugin.pre_backward(closure_loss, optimizer, opt_idx) + self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, opt_idx) output = self.precision_plugin.backward( self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs ) - self.training_type_plugin.post_backward(closure_loss, optimizer, opt_idx) + self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, opt_idx) return output - def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs): + def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs): """performs the actual optimizer step. 
Args: @@ -273,33 +274,23 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_ lambda_closure: closure calculating the loss value """ - - self.precision_plugin.pre_optimizer_step(optimizer, opt_idx) - self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx) - - if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): - # apex does not support passing a closure to the optimizer, call it by itself - lambda_closure() - lambda_closure = None - - optimizer.step(closure=lambda_closure, **kwargs) - + make_optimizer_step = self.precision_plugin.pre_optimizer_step( + self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs) + if make_optimizer_step: + self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs) self.precision_plugin.post_optimizer_step(optimizer, opt_idx) - self.training_type_plugin.post_optimizer_step(optimizer, opt_idx) - if self.rpc_enabled and self.training_type_plugin.is_main_rpc_process: - - # Initialize optimizer step on main process - self.training_type_plugin.worker_optimizer_step(model=self.lightning_module, opt_idx=opt_idx, **kwargs) + def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): + optimizer.step(closure=lambda_closure, **kwargs) def optimizer_zero_grad( - self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int + self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int ) -> None: """Zeros all model parameter's gradients""" model_ref = self.lightning_module model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) - def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None: + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """clips all the optimizer parameters to the given value""" self.precision_plugin.clip_gradients(optimizer, clip_val) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 4843665ec4a0b..201c6e88d6f3b 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -1,7 +1,7 @@ from typing import Callable import torch - +from torch.optim import Optimizer from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin @@ -26,20 +26,5 @@ def setup(self, trainer, model): raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.") return super().setup(trainer, model) - def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs): - """performs the actual optimizer step. 
- - Args: - optimizer: the optimizer performing the step - opt_idx: index of the current optimizer - lambda_closure: closure calculating the loss value - - """ - - self.precision_plugin.pre_optimizer_step(optimizer, opt_idx) - self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx) - - xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) - - self.precision_plugin.post_optimizer_step(optimizer, opt_idx) - self.training_type_plugin.post_optimizer_step(optimizer, opt_idx) + def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): + xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) \ No newline at end of file diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index b316a8663f9ff..4a5bb7b00d913 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +from torch.nn import Module from abc import ABC, abstractmethod -from typing import Any, Generator, Optional, overload, Sequence, Tuple +from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple import torch @@ -22,18 +23,12 @@ class Plugin(ABC): """Basic Plugin class to derive precision and training type plugins from.""" @abstractmethod - def connect(self, model: torch.nn.Module, *args: Sequence, - **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]: + def connect(self, model: Module, *args: Sequence, + **kwargs: Sequence) -> Optional[Tuple[Module, Sequence, Sequence]]: """Connects the plugin with the accelerator (and thereby with trainer and model). Will be called by the accelerator. """ - def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: - """Hook to do something before each optimizer step.""" - - def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: - """Hook to do something after each optimizer step.""" - def pre_training(self) -> None: """Hook to do something before the training starts.""" diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 8cdaba833af85..2ed14b16b531d 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Generator +from typing import Callable, Generator import torch - +from torch.optim import Optimizer, LBFGS from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType @@ -33,25 +33,11 @@ def __init__(self): self.backend = AMPType.NATIVE self.scaler = torch.cuda.amp.GradScaler() - def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: - """always called before the optimizer step. - Checks that the optimizer is not LBFGS, as this one is not supported by native amp - """ - if isinstance(optimizer, torch.optim.LBFGS): - raise MisconfigurationException( - f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." 
- " To request, please file a Github issue in PyTorch and tag @mcarilli" - ) - - def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: - """Updates the GradScaler""" - self.scaler.update() - def backward( self, model: LightningModule, closure_loss: torch.Tensor, - optimizer: torch.optim.Optimizer, + optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args, @@ -69,16 +55,37 @@ def backward( """ closure_loss = self.scaler.scale(closure_loss) - automatic_optimization = model.automatic_optimization - closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs) # unscale gradient to allow analyze within `on_after_backward` - if not should_accumulate and automatic_optimization: + if not should_accumulate and model.automatic_optimization: self.scaler.unscale_(optimizer) return closure_loss + def pre_optimizer_step(self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs) -> bool: + """always called before the optimizer step. + Checks that the optimizer is not LBFGS, as this one is not supported by native amp + """ + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." + " To request, please file a Github issue in PyTorch and tag @mcarilli" + ) + lambda_closure() + + if not pl_module.automatic_optimization: + self.scaler.unscale_(optimizer) + + pl_module.trainer.call_hook("on_after_backward") + + return False + + def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: + """Updates the GradScaler""" + self.scaler.step(optimizer) + self.scaler.update() + @contextmanager def train_step_context(self) -> Generator[autocast, None, None]: """Enable autocast context""" diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 3e74442e92277..f46582be2d60b 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Generator, Sequence, Tuple, Union +from typing import Any, Generator, Sequence, Tuple, Union, Callable import torch +from torch.nn import Module from torch.optim import Optimizer from pytorch_lightning.core import LightningModule @@ -28,7 +29,7 @@ class PrecisionPlugin(Plugin): EPSILON = 1e-6 precision = 32 - def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Tensor, None, None]: + def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]: """The master params of the model. Returns the plain model params here. Maybe different in other precision plugins. 
@@ -37,8 +38,8 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten for p in group["params"]: yield p - def connect(self, model: torch.nn.Module, optimizers: Sequence, - lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]: + def connect(self, model: Module, optimizers: Sequence, + lr_schedulers: Sequence) -> Tuple[Module, Sequence, Sequence]: """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers @@ -46,7 +47,7 @@ def backward( self, model: LightningModule, closure_loss: torch.Tensor, - optimizer: torch.optim.Optimizer, + optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, @@ -75,6 +76,13 @@ def backward( return closure_loss + def pre_optimizer_step(self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs) -> bool: + """Hook to do something before each optimizer step.""" + return True + + def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: + """Hook to do something after each optimizer step.""" + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None: """Clips the gradients to a specific value""" # TODO: separate TPU case from here diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 29b35ef1ec0b2..274078d8a80d4 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -268,7 +268,7 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) - def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 34f64eee5cc36..a7e8e00fe55a5 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -239,7 +239,7 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 3deff8befde26..2393c040bcc8f 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -116,7 +116,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj = hvd.broadcast_object(obj, src) return obj - def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int): + def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): optimizer.synchronize() def 
model_to_device(self): diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 738bcc9347d94..c26f5fbc1b743 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,11 +13,11 @@ # limitations under the License. import os from abc import ABC, abstractmethod -from typing import Any, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch +from torch.nn import Module from torch.optim import Optimizer -from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin @@ -69,19 +69,19 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool: """Reduce the early stopping decision across all possibly spawned processes""" return should_stop - def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" - def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int): + def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run after precision plugin executes backward""" @property - def model(self) -> torch.nn.Module: + def model(self) -> Module: """Returns the potentially wrapped LightningModule""" return self._model @model.setter - def model(self, new_model: torch.nn.Module) -> None: + def model(self, new_model: Module) -> None: self._model = new_model @property diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 64558a71b59c9..8ab00817c4764 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -538,7 +538,7 @@ def training_step(self, batch, batch_idx): if self.should_update: self.manual_backward(loss, opt) - opt.step() + opt.step(make_optimizer_step=self.should_have_updated) return loss.detach() if self.detach else loss @@ -557,7 +557,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): assert torch.sum(self.layer.weight.grad) != 0 self.count += 1 - def on_train_end(self): + def on_train_epoch_end(self, *_, **__): assert self.called["training_step"] == 20 assert self.called["on_train_batch_start"] == 20 assert self.called["on_train_batch_end"] == 20 @@ -613,7 +613,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): assert torch.all(self.layer.weight.grad == 0) self.manual_backward(loss_1, opt_a) - opt_a.step() + opt_a.step(make_optimizer_step=True) # fake discriminator loss_2 = self(x) @@ -625,7 +625,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): self.manual_backward(loss_2, opt_a, retain_graph=True) assert self.layer.weight.grad is not None - opt_b.step() + opt_b.step(make_optimizer_step=True) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -707,7 +707,7 @@ def optimizer_closure(): weight_before = self.layer.weight.clone() - opt.step(closure=optimizer_closure) + opt.step(closure=optimizer_closure, make_optimizer_step=True) weight_after = 
self.layer.weight.clone() assert not torch.equal(weight_before, weight_after) @@ -767,7 +767,7 @@ def optimizer_closure(): weight_before = self.layer.weight.clone() - opt.step(closure=optimizer_closure) + opt.step(closure=optimizer_closure, make_optimizer_step=True) weight_after = self.layer.weight.clone() if not self.trainer.train_loop.should_accumulate(): @@ -828,7 +828,7 @@ def optimizer_closure(): retain_graph = num_backward != backward_idx # noqa E225 self.manual_backward(loss_1, opt, retain_graph=retain_graph) - opt.step(closure=optimizer_closure) + opt.step(closure=optimizer_closure, make_optimizer_step=True) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer From b447413a9f3670beecfe5087aec6aebc3a59c216 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 6 Feb 2021 13:53:44 +0000 Subject: [PATCH 2/3] update --- pytorch_lightning/accelerators/tpu.py | 4 ++-- pytorch_lightning/plugins/precision/native_amp.py | 9 ++++++--- pytorch_lightning/plugins/precision/precision_plugin.py | 6 ++++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 201c6e88d6f3b..abafc9f40a6bf 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -1,7 +1,7 @@ from typing import Callable -import torch from torch.optim import Optimizer + from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin @@ -27,4 +27,4 @@ def setup(self, trainer, model): return super().setup(trainer, model) def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): - xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) \ No newline at end of file + xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 2ed14b16b531d..e8a6511798664 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -15,7 +15,8 @@ from typing import Callable, Generator import torch -from torch.optim import Optimizer, LBFGS +from torch.optim import LBFGS, Optimizer + from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType @@ -63,7 +64,9 @@ def backward( return closure_loss - def pre_optimizer_step(self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs) -> bool: + def pre_optimizer_step( + self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs + ) -> bool: """always called before the optimizer step. 
Checks that the optimizer is not LBFGS, as this one is not supported by native amp """ @@ -76,7 +79,7 @@ def pre_optimizer_step(self, pl_module: LightningModule, optimizer: Optimizer, o if not pl_module.automatic_optimization: self.scaler.unscale_(optimizer) - + pl_module.trainer.call_hook("on_after_backward") return False diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index f46582be2d60b..2216d3ae46d53 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Generator, Sequence, Tuple, Union, Callable +from typing import Any, Callable, Generator, Sequence, Tuple, Union import torch from torch.nn import Module @@ -76,7 +76,9 @@ def backward( return closure_loss - def pre_optimizer_step(self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs) -> bool: + def pre_optimizer_step( + self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs + ) -> bool: """Hook to do something before each optimizer step.""" return True From a55fd4a47c9bdcf93e22bf2f231f48d7f6af38d7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 6 Feb 2021 13:57:42 +0000 Subject: [PATCH 3/3] update --- tests/trainer/optimization/test_manual_optimization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 8ab00817c4764..30fc4d4ed08e8 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -613,7 +613,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): assert torch.all(self.layer.weight.grad == 0) self.manual_backward(loss_1, opt_a) - opt_a.step(make_optimizer_step=True) + opt_a.step() # fake discriminator loss_2 = self(x) @@ -625,7 +625,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): self.manual_backward(loss_2, opt_a, retain_graph=True) assert self.layer.weight.grad is not None - opt_b.step(make_optimizer_step=True) + opt_b.step() def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -707,7 +707,7 @@ def optimizer_closure(): weight_before = self.layer.weight.clone() - opt.step(closure=optimizer_closure, make_optimizer_step=True) + opt.step(closure=optimizer_closure) weight_after = self.layer.weight.clone() assert not torch.equal(weight_before, weight_after) @@ -767,7 +767,7 @@ def optimizer_closure(): weight_before = self.layer.weight.clone() - opt.step(closure=optimizer_closure, make_optimizer_step=True) + opt.step(closure=optimizer_closure) weight_after = self.layer.weight.clone() if not self.trainer.train_loop.should_accumulate():
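
Taken together, these three patches change the optimizer-step contract: PrecisionPlugin.pre_optimizer_step now returns a bool that tells the accelerator whether to perform the step at all, Accelerator.optimizer_step delegates the actual call to run_optimizer_step (which the TPU accelerator overrides with xm.optimizer_step), and the native AMP plugin runs the closure itself, returns False, and steps through its GradScaler in post_optimizer_step. Below is a minimal, standalone sketch of that control flow. The class and method names follow the patch, but the constructors, imports, the LBFGS guard and the on_after_backward hook call are simplified or omitted for illustration, so treat it as a reading aid rather than the actual Lightning implementation.

from typing import Callable

from torch.optim import Optimizer


class PrecisionPlugin:
    """Default behaviour: let the accelerator call optimizer.step(closure=...) itself."""

    def pre_optimizer_step(self, pl_module, optimizer: Optimizer, optimizer_idx: int,
                           lambda_closure: Callable, **kwargs) -> bool:
        # Returning True means the accelerator should still run the optimizer step.
        return True

    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
        pass


class NativeMixedPrecisionPlugin(PrecisionPlugin):
    """Native AMP runs the closure itself and steps through the GradScaler instead."""

    def __init__(self, scaler):
        self.scaler = scaler  # e.g. torch.cuda.amp.GradScaler()

    def pre_optimizer_step(self, pl_module, optimizer, optimizer_idx, lambda_closure, **kwargs) -> bool:
        lambda_closure()  # forward/backward must run before scaler.step
        if not pl_module.automatic_optimization:
            self.scaler.unscale_(optimizer)  # expose unscaled grads to user hooks
        return False  # skip the plain optimizer.step() in the accelerator

    def post_optimizer_step(self, optimizer, optimizer_idx) -> None:
        self.scaler.step(optimizer)
        self.scaler.update()


class Accelerator:
    def __init__(self, precision_plugin: PrecisionPlugin, lightning_module):
        self.precision_plugin = precision_plugin
        self.lightning_module = lightning_module

    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
        make_optimizer_step = self.precision_plugin.pre_optimizer_step(
            self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
        )
        if make_optimizer_step:
            self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
        self.precision_plugin.post_optimizer_step(optimizer, opt_idx)

    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
        optimizer.step(closure=lambda_closure, **kwargs)


class TPUAccelerator(Accelerator):
    def run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs):
        import torch_xla.core.xla_model as xm  # only available in a TPU environment
        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})

The bool return value matters because GradScaler.step() may legitimately skip the underlying optimizer.step() when it detects inf/NaN gradients, so the AMP plugin has to own the step itself, while the default precision path still passes the closure straight to optimizer.step().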