
Commit f278ac4

Revert/Fix: epoch indexing from 1, to be from 0 (#2289)
* Revert "deprecated: epoch indexing from 1 (#2206)" This reverts commit f94b919 * chlog * grad index * Apply suggestions from code review * tests * fix * test
1 parent 554fb47 commit f278ac4

11 files changed (+27, -25 lines changed)

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289))
+
 ### Deprecated

 ### Removed

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

Lines changed: 4 additions & 8 deletions
@@ -17,10 +17,6 @@ class GradientAccumulationScheduler(Callback):
     Args:
         scheduling: scheduling in format {epoch: accumulation_factor}

-    .. warning::
-        Epochs indexing starts from "1" until v0.6.x,
-        but will start from "0" in v0.8.0.
-
     Example::

         >>> from pytorch_lightning import Trainer

@@ -42,13 +38,13 @@ def __init__(self, scheduling: dict):

         for key in scheduling:
             if not isinstance(key, int) or not isinstance(scheduling[key], int):
-                raise TypeError("All epochs and accumulation factor must be integers")
+                raise TypeError("All epoches and accumulation factor must be integers")

         minimal_epoch = min(scheduling.keys())
-        if minimal_epoch < 1:
+        if minimal_epoch < 0:
             raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
-        if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
-            scheduling.update({1: 1})
+        if minimal_epoch != 0:  # if user didnt define first epoch accumulation factor
+            scheduling.update({0: 1})

         self.scheduling = scheduling
         self.epochs = sorted(scheduling.keys())
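The scheduler now accepts epoch 0 as the first key and fills in an accumulation factor of 1 for epoch 0 when the user starts the schedule later. A minimal standalone sketch of that normalization (the helper name is illustrative, not a pytorch_lightning API):

# Minimal sketch of the 0-based scheduling validation above; normalize_scheduling
# is a hypothetical helper, not part of pytorch_lightning.
def normalize_scheduling(scheduling: dict) -> dict:
    for epoch, factor in scheduling.items():
        if not isinstance(epoch, int) or not isinstance(factor, int):
            raise TypeError("All epochs and accumulation factors must be integers")

    minimal_epoch = min(scheduling.keys())
    if minimal_epoch < 0:
        raise IndexError(f"Epochs index from 0, epoch {minimal_epoch} cannot be interpreted")
    if minimal_epoch != 0:  # user did not define a factor for the first epoch
        scheduling.update({0: 1})
    return scheduling

# {4: 2} becomes {0: 1, 4: 2}: accumulate 1 batch for epochs 0-3, 2 batches from epoch 4 on.
assert normalize_scheduling({4: 2}) == {0: 1, 4: 2}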

pytorch_lightning/callbacks/progress.py

Lines changed: 2 additions & 2 deletions
@@ -96,7 +96,7 @@ def total_val_batches(self) -> int:
         if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
         elif not self.trainer.disable_validation:
-            is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
+            is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
             total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
         return total_val_batches

@@ -318,7 +318,7 @@ def on_epoch_start(self, trainer, pl_module):
         total_batches = total_train_batches + total_val_batches
         if not self.main_progress_bar.disable:
             self.main_progress_bar.reset(convert_inf(total_batches))
-        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
+        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

     def on_batch_end(self, trainer, pl_module):
         super().on_batch_end(trainer, pl_module)
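Both hunks compensate for the 0-based counter: the progress-bar label prints current_epoch + 1 so the first epoch still reads "Epoch 1", and the validation-batch total uses (current_epoch + 1) % check_val_every_n_epoch. A small sketch of the latter check (standalone function with an illustrative name):

# Sketch of the 0-based validation-epoch check above; is_val_epoch is an
# illustrative helper, not a pytorch_lightning API.
def is_val_epoch(current_epoch: int, check_val_every_n_epoch: int) -> bool:
    # current_epoch starts at 0, so epoch N is the (N + 1)-th epoch
    return (current_epoch + 1) % check_val_every_n_epoch == 0

# With check_val_every_n_epoch=2, validation runs on epochs 1, 3, 5 (the 2nd, 4th, 6th epochs).
assert [e for e in range(6) if is_val_epoch(e, 2)] == [1, 3, 5]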

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -518,7 +518,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):

         # AMP
         # run through amp wrapper before going to distributed DP
-        # TODO: remove in v0.8.0
+        # TODO: remove with dropping NVIDIA AMP support
         if self.use_amp and not self.use_native_amp:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 2 additions & 2 deletions
@@ -174,7 +174,7 @@ def single_gpu_train(self, model):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

-        # TODO: update for 0.8.0
+        # TODO: remove with dropping NVIDIA AMP support
         if self.use_amp and not self.use_native_amp:
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)

@@ -240,7 +240,7 @@ def dp_train(self, model):
         # wrap the user's forward in autocast and give it back at the end
         model.forward = torch.cuda.amp.autocast()(model.forward)

-        # TODO: remove in v0.8.0
+        # TODO: remove with dropping NVIDIA AMP support
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
         if self.use_dp and self.use_amp and not self.use_native_amp:

pytorch_lightning/trainer/training_io.py

Lines changed: 1 addition & 1 deletion
@@ -323,7 +323,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             structured dictionary
         """
         checkpoint = {
-            'epoch': self.current_epoch,
+            'epoch': self.current_epoch + 1,
             'global_step': self.global_step + 1,
             'pytorch-ligthning_version': pytorch_lightning.__version__,
         }
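With the counter starting at 0, the checkpoint stores current_epoch + 1, i.e. the next epoch to run rather than the one just finished. A rough sketch of that round trip, assuming the restore path reads 'epoch' straight back into the trainer's current epoch (standalone helpers, not Trainer methods):

# Sketch of the epoch save/resume convention above; dump_epoch/restore_epoch
# are hypothetical helpers, and the restore behaviour is an assumption.
def dump_epoch(current_epoch: int) -> dict:
    return {'epoch': current_epoch + 1}  # next 0-based epoch to run

def restore_epoch(checkpoint: dict) -> int:
    return checkpoint['epoch']

# Finishing epoch 4 (the 5th epoch) saves 5, so training would resume at epoch 5.
assert restore_epoch(dump_epoch(4)) == 5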

pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 4 deletions
@@ -346,8 +346,8 @@ def train(self):
         model.on_train_start()

         try:
-            # run all epochs from actual + 1 till the maximal
-            for epoch in range(self.current_epoch + 1, self.max_epochs + 1):
+            # run all epochs
+            for epoch in range(self.current_epoch, self.max_epochs):
                 # reset train dataloader
                 if self.reload_dataloaders_every_epoch:
                     self.reset_train_dataloader(model)

@@ -382,7 +382,7 @@ def train(self):
                 self.update_learning_rates(interval='epoch')

                 # early stopping
-                met_min_epochs = epoch >= self.min_epochs
+                met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                 # TODO wrap this logic into the callback

@@ -476,7 +476,7 @@ def run_training_epoch(self):
             # RUN VAL STEP
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = self.current_epoch % self.check_val_every_n_epoch == 0
+            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
             can_check_val = not self.disable_validation and can_check_epoch
             should_check_val = is_val_check_batch or early_stop_epoch
             should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
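The loop and early-stopping bounds shift by one to match: epochs iterate over range(current_epoch, max_epochs), and the minimum-epoch check subtracts 1 because epoch is now 0-based. A compact sketch of those boundaries (illustrative helpers, not Trainer code):

# Sketch of the 0-based loop boundaries above; helper names are illustrative.
def epoch_indices(current_epoch: int, max_epochs: int) -> list:
    # max_epochs=3 now yields epochs 0, 1, 2 instead of 1, 2, 3
    return list(range(current_epoch, max_epochs))

def met_min_epochs(epoch: int, min_epochs: int) -> bool:
    # epoch is 0-based while min_epochs counts whole epochs, hence the "- 1"
    return epoch >= min_epochs - 1

assert epoch_indices(0, 3) == [0, 1, 2]
assert met_min_epochs(2, min_epochs=3) and not met_min_epochs(1, min_epochs=3)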

tests/callbacks/test_callbacks.py

Lines changed: 1 addition & 1 deletion
@@ -286,7 +286,7 @@ def training_step(self, *args, **kwargs):
     result = trainer.fit(model)

     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch <= trainer.max_epochs
+    assert trainer.current_epoch < trainer.max_epochs


 def test_pickling(tmpdir):

tests/models/test_hooks.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def training_epoch_end(self, outputs):
     # a metric shared in both methods gets overwritten by epoch_end
     assert metrics['shared_metric'] == 111
     # metrics are kept after each epoch
-    for i in range(1, num_epochs + 1):
+    for i in range(num_epochs):
         assert metrics[f'epoch_metric_{i}'] == i
tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def test_dp_resume(tmpdir):
     result = trainer.fit(model)

     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
-    real_global_epoch = trainer.current_epoch
+    real_global_epoch = trainer.current_epoch + 1

     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'
