diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1239f349e8f5f..d2a2bc5bef6f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
 
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+
+
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a555146875eb5..ca50b088c665b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -577,9 +577,8 @@ def run_training_epoch(self):
             self.trainer._run_evaluation(on_epoch=True)
             self.trainer.training = True
 
-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        if batch_output.signal != -1:
+            self.increment_accumulated_grad_global_step()
 
     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 78f8d2c0a94e9..e8351072d2cc0 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -229,26 +229,6 @@ def train_dataloader(self):
     trainer.fit(model)
 
 
-@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)])
-def test_on_train_batch_start_hook(max_epochs, batch_idx_):
-
-    class CurrentModel(BoringModel):
-
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
-            if batch_idx == batch_idx_:
-                return -1
-
-    model = CurrentModel()
-    trainer = Trainer(max_epochs=max_epochs)
-    trainer.fit(model)
-    if batch_idx_ > len(model.val_dataloader()) - 1:
-        assert trainer.train_loop.batch_idx == len(model.val_dataloader()) - 1
-        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
-    else:
-        assert trainer.train_loop.batch_idx == batch_idx_
-        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
-
-
 def test_trainer_model_hook_system(tmpdir):
     """Test the LightningModule hook system."""
 
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index 2d32d8c8878e4..db87a0baabe9d 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
@@ -201,3 +202,23 @@ def run_training(**trainer_kwargs):
         num_sanity_val_steps=2,
     )
     assert torch.allclose(sequence0, sequence1)
+
+
+@pytest.mark.parametrize(['max_epochs', 'batch_idx_'], [(2, 5), (3, 8), (4, 12)])
+def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
+
+    class CurrentModel(BoringModel):
+
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
+    trainer.fit(model)
+    if batch_idx_ > trainer.num_training_batches - 1:
+        assert trainer.train_loop.batch_idx == trainer.num_training_batches - 1
+        assert trainer.global_step == trainer.num_training_batches * max_epochs
+    else:
+        assert trainer.train_loop.batch_idx == batch_idx_
+        assert trainer.global_step == batch_idx_ * max_epochs
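For reference, a minimal, self-contained sketch (not part of the patch) of the behaviour the new test covers: returning `-1` from `on_train_batch_start` skips the rest of the epoch, and with this change the epoch-end global-step increment is no longer applied on the skipped pass, so `trainer.global_step` counts only the batches that actually ran (the removed test expected `(batch_idx_ + 1) * max_epochs`, the new one `batch_idx_ * max_epochs`). The dataset and module names below are illustrative stand-ins, assuming the Lightning API of this release line (~1.3), where the hook still takes `dataloader_idx`.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Tiny synthetic dataset, only there to drive the training loop."""

    def __init__(self, size=32, length=10):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class SkipAfterThreeBatches(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # Returning -1 tells Lightning to skip the remainder of the epoch.
        if batch_idx == 3:
            return -1

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=1)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # logger/checkpointing disabled only to keep the sketch side-effect free
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=5,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(SkipAfterThreeBatches())
    # Per the new test's expectation: only batches 0-2 run each epoch,
    # so 3 steps per epoch * 2 epochs = 6 (the old behaviour gave 8).
    assert trainer.global_step == 3 * 2
```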