diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8e126d948090c..f4159606cddbf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [1.3.3] - 2021-05-27
+
+### Changed
+
+- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
+
+### Fixed
+
+- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
+- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
+- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+
+
 ## [1.3.2] - 2021-05-18
 
 ### Changed
@@ -18,9 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed setting correct `DistribType` for `ddp_cpu` (spawn) backend ([#7492](https://github.com/PyTorchLightning/pytorch-lightning/pull/7492))
 - Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))
 
-
 ## [1.3.1] - 2021-05-11
 
+
 ### Fixed
 
 - Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362))
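For context on the `ProgressBar` entry above (#7608), the scenario being fixed is pickling the progress-bar callback once `trainer.predict` has created its own tqdm bar. A minimal sketch, assuming `model` is a `LightningModule` with a predict dataloader (the regression test added later in this patch exercises the same flow):

import pickle

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

bar = ProgressBar()
trainer = Trainer(fast_dev_run=True, callbacks=[bar])
trainer.predict(model)  # `model` is an assumed LightningModule with a predict dataloader
pickle.dumps(bar)       # used to fail: the predict tqdm bar was not cleared in __getstate__
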
diff --git a/dockers/base-xla/Dockerfile b/dockers/base-xla/Dockerfile
index 5049cb4b8d8f4..3d25b3b5a935a 100644
--- a/dockers/base-xla/Dockerfile
+++ b/dockers/base-xla/Dockerfile
@@ -14,7 +14,7 @@
 
 FROM google/cloud-sdk:slim
 
-MAINTAINER PyTorchLightning 
+LABEL maintainer="PyTorchLightning "
 
 # CALL: docker image build -t pytorch-lightning:XLA-extras-py3.6 -f dockers/base-xla/Dockerfile . --build-arg PYTHON_VERSION=3.6
 # This Dockerfile installs pytorch/xla 3.7 wheels. There are also 3.6 wheels available; see below.
diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile
index 2027c4f3d5d7d..fbfd2224a66a9 100644
--- a/dockers/nvidia/Dockerfile
+++ b/dockers/nvidia/Dockerfile
@@ -15,7 +15,7 @@
 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
 FROM nvcr.io/nvidia/pytorch:21.04-py3
 
-MAINTAINER PyTorchLightning 
+LABEL maintainer="PyTorchLightning "
 
 ARG LIGHTNING_VERSION=""
diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile
index 5cd53385f660b..3e2edee0466b2 100644
--- a/dockers/release/Dockerfile
+++ b/dockers/release/Dockerfile
@@ -17,7 +17,7 @@ ARG PYTORCH_VERSION=1.5
 
 FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION}
 
-MAINTAINER PyTorchLightning 
+LABEL maintainer="PyTorchLightning "
 
 ARG LIGHTNING_VERSION=""
diff --git a/dockers/tpu-tests/Dockerfile b/dockers/tpu-tests/Dockerfile
index 93d6244121891..b7ec2445786fd 100644
--- a/dockers/tpu-tests/Dockerfile
+++ b/dockers/tpu-tests/Dockerfile
@@ -17,7 +17,7 @@ ARG PYTORCH_VERSION=1.6
 
 FROM pytorchlightning/pytorch_lightning:base-xla-py${PYTHON_VERSION}-torch${PYTORCH_VERSION}
 
-MAINTAINER PyTorchLightning 
+LABEL maintainer="PyTorchLightning "
 
 #SHELL ["/bin/bash", "-c"]
diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py
index 8073e34802df2..f62cef02182c1 100644
--- a/pytorch_lightning/__about__.py
+++ b/pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = '1.3.2'
+__version__ = '1.3.3'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index be9d2f44356f5..60edcc339d671 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -283,6 +283,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
         self.main_progress_bar = None
         self.val_progress_bar = None
         self.test_progress_bar = None
+        self.predict_progress_bar = None
 
     def __getstate__(self):
         # can't pickle the tqdm objects
@@ -290,6 +291,7 @@ def __getstate__(self):
         state['main_progress_bar'] = None
         state['val_progress_bar'] = None
         state['test_progress_bar'] = None
+        state['predict_progress_bar'] = None
         return state
 
     @property
@@ -471,12 +473,14 @@ def print(
     ):
         active_progress_bar = None
 
-        if not self.main_progress_bar.disable:
+        if self.main_progress_bar is not None and not self.main_progress_bar.disable:
             active_progress_bar = self.main_progress_bar
-        elif not self.val_progress_bar.disable:
+        elif self.val_progress_bar is not None and not self.val_progress_bar.disable:
             active_progress_bar = self.val_progress_bar
-        elif not self.test_progress_bar.disable:
+        elif self.test_progress_bar is not None and not self.test_progress_bar.disable:
             active_progress_bar = self.test_progress_bar
+        elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable:
+            active_progress_bar = self.predict_progress_bar
 
         if active_progress_bar is not None:
             s = sep.join(map(str, args))
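With the `None` guards added to `print` above, `self.print` in a `LightningModule` no longer raises when one of the bars was never created (for example, calling `trainer.validate` or `trainer.predict` without a prior `fit`), and output can now be routed through the new predict bar. A hedged sketch of the usage the new tests further down cover (`BoringModel` is the test helper used elsewhere in this patch):

class PrintingModel(BoringModel):

    def predict_step(self, *args, **kwargs):
        self.print("predict_step")  # written through predict_progress_bar when it is active
        return super().predict_step(*args, **kwargs)
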
diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index 51c602add9541..b4fc1cb200fd7 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -71,12 +71,11 @@ def auto_transfer_args(self, *args, **kwargs):
 
 def parameter_validation(fn: Callable) -> Callable:
     """
-    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
     Validates that the module parameter lengths match after moving to the device. It is useful
     when tying weights on TPU's.
 
     Args:
-        fn: ``.to`` method
+        fn: ``model_to_device`` method
 
     Note:
         TPU's require weights to be tied/shared after moving the module to the device.
@@ -90,10 +89,10 @@ def parameter_validation(fn: Callable) -> Callable:
 
     @wraps(fn)
     def inner_fn(self, *args, **kwargs):
-        pre_layer_count = len(list(self.parameters()))
+        pre_layer_count = len(list(self.model.parameters()))
         module = fn(self, *args, **kwargs)
-        self.on_post_move_to_device()
-        post_layer_count = len(list(self.parameters()))
+        self.model.on_post_move_to_device()
+        post_layer_count = len(list(self.model.parameters()))
 
         if not pre_layer_count == post_layer_count:
             rank_zero_warn(
diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py
index 67b64c046dc18..ebee83828e0d7 100644
--- a/pytorch_lightning/overrides/torch_distributed.py
+++ b/pytorch_lightning/overrides/torch_distributed.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 
 log = logging.getLogger(__name__)
 
@@ -88,7 +88,7 @@ def _broadcast_object_list(object_list, src=0, group=None):
         object_list[i] = _tensor_to_object(obj_view, obj_size)
 
 
-if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
+if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.is_available():
     from torch.distributed.distributed_c10d import broadcast_object_list
 else:
     broadcast_object_list = _broadcast_object_list
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index fce325f322cc3..7e971de8bf181 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -15,6 +15,7 @@
 
 import torch
 
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -43,6 +44,7 @@ def on_tpu(self) -> bool:
     def is_distributed(self) -> bool:
         return False
 
+    @parameter_validation
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 1d4a38498b20d..c232d86ed713c 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,6 +23,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.global_rank == 0:
             time.sleep(2)
 
+    @parameter_validation
     def model_to_device(self) -> None:
         self.device = xm.xla_device()
         self.model = self.wrapped_model.to(self.device)
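Applying `parameter_validation` to the TPU plugins' `model_to_device` (rather than to `DeviceDtypeModuleMixin.to`, whose decorator is removed below) means the parameter-count check runs right after the module lands on the XLA device. The pattern it validates is weight tying re-applied in `on_post_move_to_device`, sketched here; `WeightSharingModule` is the helper with `layer_1`/`layer_3` used by the re-enabled TPU test near the end of this patch:

class TiedModel(WeightSharingModule):

    def on_post_move_to_device(self):
        # re-tie the weights once the module has been moved to the TPU device;
        # parameter_validation compares parameter counts before and after the move
        # and warns with "The model layers do not match" when they differ
        self.layer_3.weight = self.layer_1.weight
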
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index b3621ee176677..943864138e371 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -526,6 +526,8 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True
 
+            self.trainer.total_batch_idx += 1
+
             # max steps reached, end training
             if (
                 self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
@@ -539,8 +541,6 @@ def run_training_epoch(self):
             if self.trainer.should_stop:
                 break
 
-            self.trainer.total_batch_idx += 1
-
             # stop epoch if we limited the number of training batches
             if self._num_training_batches_reached(is_last_batch):
                 break
@@ -574,9 +574,8 @@ def run_training_epoch(self):
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True
 
-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        if batch_output.signal != -1:
+            self.increment_accumulated_grad_global_step()
 
     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
@@ -727,7 +726,9 @@ def train_step_and_backward_closure():
 
                     # optimizer step
                     self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
-
+                    if len(self.trainer.optimizers) > 1:
+                        # revert back to previous state
+                        self.trainer.lightning_module.untoggle_optimizer(opt_idx)
                 else:
                     self._curr_step_result = self.training_step(
                         split_batch, batch_idx, opt_idx, self.trainer.hiddens
@@ -838,10 +839,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                         "training_step returned None. If this was on purpose, ignore this warning..."
                     )
 
-        if len(self.trainer.optimizers) > 1:
-            # revert back to previous state
-            self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-
         return result
 
     def _check_finite(self, loss: torch.Tensor) -> None:
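Moving `untoggle_optimizer` next to `optimizer_step` matters because, with multiple optimizers, Lightning toggles `requires_grad` for the current optimizer before the closure is built, while the untoggle used to sit inside `training_step_and_backward`, which only runs when the closure is executed. If a user-defined `optimizer_step` skips `optimizer.step(closure)` for some optimizer, the closure never runs and parameters previously stayed toggled. A trimmed sketch of that situation (the full version is the new test near the end of this patch; `BoringModel` is the test helper, and the model is assumed to configure two optimizers):

class AlternatingModel(BoringModel):

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
        if optimizer_idx == 0:
            optimizer.step(closure=optimizer_closure)
        elif batch_idx % 2 == 0:
            # the second optimizer only steps every other batch; when it is skipped the
            # closure never runs, so the untoggle has to happen outside of it
            optimizer.step(closure=optimizer_closure)
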
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 303af3a117d81..40173a0cd3aca 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -160,7 +160,10 @@ def _run_power_scaling(
             else:
                 raise  # some other error not memory related
 
-        if not changed:
+        if changed:
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+        else:
             break
     return new_size
 
@@ -192,7 +195,10 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
 
-            if not changed:
+            if changed:
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+            else:
                 break
 
         except RuntimeError as exception:
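With the reset added above, every scaling trial that actually changes `batch_size` now rebuilds the train dataloader before the next fit attempt. A rough usage sketch, assuming `model` exposes a `batch_size` attribute that its `train_dataloader` reads (mirroring the tuner tests at the end of this patch):

trainer = Trainer(max_epochs=1)
new_batch_size = trainer.tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2)
# each successful trial adjusts model.batch_size and, with this fix, also calls
# trainer.reset_train_dataloader(model) so the next attempt really runs at the new size
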
diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py
index a6a26b142bc16..13f16d9b426ac 100644
--- a/pytorch_lightning/utilities/device_dtype_mixin.py
+++ b/pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,8 +17,6 @@
 import torch
 from torch.nn import Module
 
-from pytorch_lightning.core.decorators import parameter_validation
-
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
@@ -47,7 +45,6 @@ def device(self) -> Union[str, torch.device]:
 
         return device
 
-    @parameter_validation
     def to(self, *args, **kwargs) -> Module:
         """Moves and/or casts the parameters and buffers.
 
@@ -84,9 +81,6 @@ def to(self, *args, **kwargs) -> Module:
            ...     def __init__(self, weight: torch.Tensor):
            ...         super().__init__()
            ...         self.register_buffer('weight', weight)
-           ...
-           ...     def on_post_move_to_device(self):
-           ...         pass
            >>> _ = torch.manual_seed(0)
            >>> module = ExampleModule(torch.rand(3, 4))
            >>> module.weight #doctest: +ELLIPSIS
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index ecb0101a2279e..b25cbbac467b8 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -1,12 +1,14 @@
-from typing import Any, Dict, Iterator, List, Union
-
-import torch
-from torchmetrics import Metric
 """
 Convention:
  - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
 """
+
+from typing import Any, Dict, Iterator, List, Union
+
+import torch
+from torchmetrics import Metric
+
 _METRIC = Union[Metric, torch.Tensor, int, float]
 STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
 EPOCH_OUTPUT = List[STEP_OUTPUT]
diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py
index 2d096ee6be2a7..b57894816090d 100644
--- a/tests/accelerators/test_tpu_backend.py
+++ b/tests/accelerators/test_tpu_backend.py
@@ -95,25 +95,21 @@ def test_weight_tying_warning(tmpdir, capsys=None):
         trainer.fit(model)
 
 
-# @RunIf(tpu=True)
-# @pl_multi_process_test
-# def test_if_weights_tied(tmpdir, capsys=None):
-#     """
-#     Test if weights are properly tied on `on_post_move_to_device`.
-#     Ensure no warning for parameter mismatch is thrown.
-#     """
-
-#     # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
-#     class Model(WeightSharingModule):
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test if weights are properly tied on `on_post_move_to_device`.
+    Ensure no warning for parameter mismatch is thrown.
+    """
 
-#         def on_post_move_to_device(self):
-#             self.layer_3.weight = self.layer_1.weight
+    class Model(WeightSharingModule):
 
-#     model = Model()
-#     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight
 
-#     with pytest.warns(UserWarning) as warnings:
-#         trainer.fit(model)
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
 
-#     assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
-#     assert len(trainer.test(model)) == 1
+    with pytest.warns(UserWarning, match="The model layers do not match"):
+        trainer.fit(model)
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 76f1e4cb0570f..70e20737e4a47 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import pickle
 import sys
 from typing import Optional, Union
 from unittest import mock
@@ -432,6 +433,10 @@ def test_step(self, *args, **kwargs):
         self.print("test_step")
         return super().test_step(*args, **kwargs)
 
+    def predict_step(self, *args, **kwargs):
+        self.print("predict_step")
+        return super().predict_step(*args, **kwargs)
+
 
 @mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
 def test_progress_bar_print(tqdm_write, tmpdir):
@@ -444,16 +449,45 @@ def test_progress_bar_print(tqdm_write, tmpdir):
         limit_train_batches=1,
         limit_val_batches=1,
         limit_test_batches=1,
+        limit_predict_batches=1,
         max_steps=1,
         callbacks=[bar],
     )
     trainer.fit(model)
     trainer.test(model)
-    assert tqdm_write.call_count == 3
+    trainer.predict(model)
+    assert tqdm_write.call_count == 4
     assert tqdm_write.call_args_list == [
         call("training_step", end="", file=None, nolock=False),
         call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
         call("test_step", end=os.linesep, file=None, nolock=False),
+        call("predict_step", end=os.linesep, file=None, nolock=False),
+    ]
+
+
+@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
+def test_progress_bar_print_no_train(tqdm_write, tmpdir):
+    """ Test that printing in the LightningModule redirects arguments to the progress bar without training. """
+    model = PrintModel()
+    bar = ProgressBar()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        limit_predict_batches=1,
+        max_steps=1,
+        callbacks=[bar],
+    )
+
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
+    assert tqdm_write.call_count == 3
+    assert tqdm_write.call_args_list == [
+        call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
+        call("test_step", end=os.linesep, file=None, nolock=False),
+        call("predict_step", end=os.linesep, file=None, nolock=False),
     ]
 
 
@@ -469,16 +503,33 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
         limit_train_batches=1,
         limit_val_batches=1,
         limit_test_batches=1,
+        limit_predict_batches=1,
         max_steps=1,
         callbacks=[bar],
     )
     bar.disable()
     trainer.fit(model)
-    trainer.test(model)
+    trainer.test(model, verbose=False)
+    trainer.predict(model)
 
     mock_print.assert_has_calls([
         call("training_step", end=""),
         call("validation_step", file=ANY),
         call("test_step"),
+        call("predict_step"),
     ])
     tqdm_write.assert_not_called()
+
+
+def test_progress_bar_can_be_pickled():
+    bar = ProgressBar()
+    trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1)
+    model = BoringModel()
+
+    pickle.dumps(bar)
+    trainer.fit(model)
+    pickle.dumps(bar)
+    trainer.test(model)
+    pickle.dumps(bar)
+    trainer.predict(model)
+    pickle.dumps(bar)
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 24bf29a9e2eac..58ebc8e271be6 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -225,26 +225,6 @@ def train_dataloader(self):
     trainer.fit(model)
 
 
-@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)])
-def test_on_train_batch_start_hook(max_epochs, batch_idx_):
-
-    class CurrentModel(BoringModel):
-
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
-            if batch_idx == batch_idx_:
-                return -1
-
-    model = CurrentModel()
-    trainer = Trainer(max_epochs=max_epochs)
-    trainer.fit(model)
-    if batch_idx_ > len(model.val_dataloader()) - 1:
-        assert trainer.batch_idx == len(model.val_dataloader()) - 1
-        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
-    else:
-        assert trainer.batch_idx == batch_idx_
-        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
-
-
 def test_trainer_model_hook_system(tmpdir):
     """Test the LightningModule hook system."""
 
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index 2d32d8c8878e4..94becf6488fc3 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
@@ -201,3 +202,23 @@ def run_training(**trainer_kwargs):
         num_sanity_val_steps=2,
     )
     assert torch.allclose(sequence0, sequence1)
+
+
+@pytest.mark.parametrize(['max_epochs', 'batch_idx_'], [(2, 5), (3, 8), (4, 12)])
+def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
+
+    class CurrentModel(BoringModel):
+
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
+    trainer.fit(model)
+    if batch_idx_ > trainer.num_training_batches - 1:
+        assert trainer.batch_idx == trainer.num_training_batches - 1
+        assert trainer.global_step == trainer.num_training_batches * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == batch_idx_ * max_epochs
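The changed `global_step` expectation follows from the loop fix earlier in this patch: returning `-1` from `on_train_batch_start` ends the epoch at that batch, and the trailing `increment_accumulated_grad_global_step` is no longer applied for it, so each epoch contributes `batch_idx_` optimizer steps instead of `batch_idx_ + 1`. A quick check for one parametrization of the test above:

max_epochs, batch_idx_ = 3, 8
steps_per_epoch = batch_idx_                  # batches 0..7 step; batch 8 returns -1 and is skipped
assert steps_per_epoch * max_epochs == 24     # the trainer.global_step value the test expects
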
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
index 24b32c8725963..aba3b53248a57 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -168,3 +168,68 @@ def training_step(self, batch, batch_idx):
 
     with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
         trainer.fit(TestModel())
+
+
+def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
+    """
+    This tests ensures custom optimizer_step works,
+    even when optimizer.step is not called for a particular optimizer
+    """
+
+    class TestModel(BoringModel):
+        training_step_called = [0, 0]
+        optimizer_step_called = [0, 0]
+
+        def __init__(self):
+            super().__init__()
+            self.layer_a = torch.nn.Linear(32, 2)
+            self.layer_b = torch.nn.Linear(32, 2)
+
+        def configure_optimizers(self):
+            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
+            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
+            return opt_a, opt_b
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.training_step_called[optimizer_idx] += 1
+            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
+            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+            return loss
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def optimizer_step(
+            self,
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_idx,
+            optimizer_closure,
+            **_,
+        ):
+            # update first optimizer every step
+            if optimizer_idx == 0:
+                self.optimizer_step_called[optimizer_idx] += 1
+                optimizer.step(closure=optimizer_closure)
+
+            # update second optimizer every 2 steps
+            if optimizer_idx == 1:
+                if batch_idx % 2 == 0:
+                    self.optimizer_step_called[optimizer_idx] += 1
+                    optimizer.step(closure=optimizer_closure)
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=4,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert model.training_step_called == [4, 2]
+    assert model.optimizer_step_called == [4, 2]
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index 608cb8c6778bf..a74af3862c473 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -197,31 +197,24 @@ def test_datamodule_parameter(tmpdir):
 
 
 def test_accumulation_and_early_stopping(tmpdir):
-    """ Test that early stopping of learning rate finder works, and that
-    accumulation also works for this feature """
+    """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """
 
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
+    class TestModel(BoringModel):
 
-    before_lr = hparams.get('learning_rate')
-    # logger file to get meta
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         accumulate_grad_batches=2,
     )
-
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
-    after_lr = lrfinder.suggestion()
 
-    expected_num_lrs = 100
-    expected_batch_idx = 200 - 1
-
-    assert before_lr != after_lr, \
-        'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == expected_num_lrs, \
-        'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == expected_batch_idx, \
-        'Accumulation parameter did not work'
+    assert lrfinder.suggestion() != 1e-3
+    assert len(lrfinder.results['lr']) == 100
+    assert lrfinder._total_batch_idx == 200
 
 
 def test_suggestion_parameters_work(tmpdir):
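The constants in the rewritten lr-finder test follow from the `total_batch_idx` fix in the training loop: `lr_find` tries its default 100 learning-rate candidates, and with `accumulate_grad_batches=2` each optimizer step consumes two batches, so the per-batch counter now ends at 200 (the old test expected `200 - 1` because the counter used to be bumped one position later in the loop). As arithmetic:

num_lr_candidates = 100        # lr_find default number of steps
accumulate_grad_batches = 2
assert num_lr_candidates * accumulate_grad_batches == 200   # asserted lrfinder._total_batch_idx
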
""" trainer = Trainer( default_root_dir=tmpdir, @@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod max_epochs=1, ) tuner = Tuner(trainer) - new_batch_size = tuner.scale_batch_size( - model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule - ) + + model = BatchSizeModel(model_bs) + datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None + + new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule) assert new_batch_size == 16 - if hasattr(model, "batch_size"): - assert model.batch_size == 16 - if datamodule is not None and hasattr(datamodule, "batch_size"): - assert datamodule.batch_size == 16 + + if model_bs is not None: + assert model.batch_size == new_batch_size + if dm_bs == -1: + # datamodule batch size takes precedence + assert trainer.train_dataloader.loaders.batch_size == new_batch_size + if dm_bs not in (-1, None): + assert datamodule.batch_size == new_batch_size + assert trainer.train_dataloader.loaders.batch_size == new_batch_size def test_model_reset_correctly(tmpdir):