
Commit 0bf5408

Fix global step update when the epoch is skipped (#7677)
* Fix global step update when the epoch is skipped
* Update CHANGELOG
* Move test
1 parent 2db6b5a commit 0bf5408
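
For context: returning -1 from the `on_train_batch_start` hook makes the training loop skip the remaining batches of the current epoch. Before this fix, `run_training_epoch` still bumped the accumulated-grad global step for that aborted pass. Below is a minimal sketch of the now-expected behaviour, mirroring the new test added in this commit (the `BoringModel` import path and the model class name are illustrative assumptions, not part of the commit):

    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel  # import path assumed for this sketch

    class SkipAfterTwoBatches(BoringModel):  # hypothetical model, for illustration only
        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            if batch_idx == 2:
                return -1  # abort the rest of this epoch

    model = SkipAfterTwoBatches()
    trainer = Trainer(max_epochs=3, limit_train_batches=10)
    trainer.fit(model)

    # Only batches 0 and 1 run in each epoch, so two optimizer steps are taken per
    # epoch; with the fix, global_step reflects that instead of over-counting by one.
    assert trainer.global_step == 2 * 3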

4 files changed, +24 −23 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
 - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
 - Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))

 ## [1.3.2] - 2021-05-18


pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 3 deletions
@@ -574,9 +574,8 @@ def run_training_epoch(self):
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        if batch_output.signal != -1:
+            self.increment_accumulated_grad_global_step()

     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
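
Why the guard sits here: earlier in `run_training_epoch`, the batch loop breaks once `batch_output.signal` is -1 (the signal set when `on_train_batch_start` asks to skip), so the previously unconditional `increment_accumulated_grad_global_step()` at the end of the method advanced `global_step` once more even though no optimizer step ran for that batch. Checking the signal keeps the counter aligned with the steps that actually happened.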

tests/models/test_hooks.py

Lines changed: 0 additions & 20 deletions
@@ -225,26 +225,6 @@ def train_dataloader(self):
     trainer.fit(model)


-@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)])
-def test_on_train_batch_start_hook(max_epochs, batch_idx_):
-
-    class CurrentModel(BoringModel):
-
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
-            if batch_idx == batch_idx_:
-                return -1
-
-    model = CurrentModel()
-    trainer = Trainer(max_epochs=max_epochs)
-    trainer.fit(model)
-    if batch_idx_ > len(model.val_dataloader()) - 1:
-        assert trainer.batch_idx == len(model.val_dataloader()) - 1
-        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
-    else:
-        assert trainer.batch_idx == batch_idx_
-        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
-
-
 def test_trainer_model_hook_system(tmpdir):
     """Test the LightningModule hook system."""

tests/trainer/loops/test_training_loop.py

Lines changed: 21 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch

 from pytorch_lightning import seed_everything, Trainer
@@ -201,3 +202,23 @@ def run_training(**trainer_kwargs):
         num_sanity_val_steps=2,
     )
     assert torch.allclose(sequence0, sequence1)
+
+
+@pytest.mark.parametrize(['max_epochs', 'batch_idx_'], [(2, 5), (3, 8), (4, 12)])
+def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
+
+    class CurrentModel(BoringModel):
+
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
+    trainer.fit(model)
+    if batch_idx_ > trainer.num_training_batches - 1:
+        assert trainer.batch_idx == trainer.num_training_batches - 1
+        assert trainer.global_step == trainer.num_training_batches * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == batch_idx_ * max_epochs
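
Reading the parametrized cases: with `limit_train_batches=10`, `trainer.num_training_batches` is 10, so for `batch_idx_ = 12` the hook never returns -1, every epoch runs to completion, and `global_step == 10 * max_epochs`. For `batch_idx_ = 5` or `8`, only the batches before `batch_idx_` complete in each epoch, so `global_step == batch_idx_ * max_epochs`. The deleted assertion in `test_hooks.py` expected one extra step per epoch (`(batch_idx_ + 1) * max_epochs`), which is exactly the over-count this commit removes.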
