3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -419,6 +419,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `Strategy.on_tpu` property ([#11536](https://github.com/PyTorchLightning/pytorch-lightning/pull/11536))


- Removed `FitLoop.current_epoch` getter and setter ([#11562](https://github.com/PyTorchLightning/pytorch-lightning/pull/11562))


- Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517))


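For the `FitLoop.current_epoch` removal above, a hypothetical migration sketch; the replacement attribute path is the one used throughout the rest of this diff:

```python
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=3)

# Before this release (removed):
#   trainer.fit_loop.current_epoch = 1
# After: write the progress-tracking state directly
trainer.fit_loop.epoch_progress.current.completed = 1

# Reading through the Trainer property is unchanged
assert trainer.current_epoch == 1
```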
6 changes: 2 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -366,16 +366,14 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
"""
epoch = trainer.current_epoch
global_step = trainer.global_step

self._validate_monitor_key(trainer)

# track epoch when ckpt was last checked
global_step = trainer.global_step
self._last_global_step_saved = global_step

# what can be monitored
monitor_candidates = self._monitor_candidates(trainer, epoch=epoch, step=global_step)
monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
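A hedged sketch of what the monitor candidates may contain; the exact keys and the merge with `trainer.callback_metrics` are assumptions, only the `epoch`/`step` arguments come from the call above:

```python
import torch

# Hypothetical contents for a run that logged ``val_loss`` and sits at epoch 3, step 120
monitor_candidates = {"val_loss": torch.tensor(0.4231)}
monitor_candidates.update(epoch=3, step=120)

# ``ModelCheckpoint`` looks up its ``monitor`` key in this dict and can also
# interpolate "epoch"/"step" into the checkpoint filename, e.g. "epoch=3.ckpt"
```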
5 changes: 1 addition & 4 deletions pytorch_lightning/core/lightning.py
@@ -199,10 +199,7 @@ def example_input_array(self, example: Any) -> None:

@property
def current_epoch(self) -> int:
"""The current epoch in the Trainer.

If no Trainer is attached, this propery is 0.
"""
"""The current epoch in the ``Trainer``, or 0 if not attached."""
return self.trainer.current_epoch if self.trainer else 0

@property
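A minimal sketch of the fallback documented in the new docstring, assuming the Lightning version in this diff:

```python
import pytorch_lightning as pl

model = pl.LightningModule()     # any module not yet attached to a Trainer
assert model.current_epoch == 0  # self.trainer is None, so the property falls back to 0
```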
1 change: 1 addition & 0 deletions pytorch_lightning/loops/base.py
@@ -206,6 +206,7 @@ def run(self, *args, **kwargs):
self._restarting = False
except StopIteration:
break
self._restarting = False

output = self.on_run_end()
return output
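A self-contained sketch (not the real `Loop` class) of why the flag is also cleared after the `while` loop; the scenario and names are illustrative assumptions:

```python
class SketchLoop:
    def __init__(self, batches):
        self.batches = iter(batches)
        self._restarting = True  # pretend we are resuming from a checkpoint

    @property
    def done(self):
        return False  # in this sketch we only ever stop via StopIteration

    def advance(self):
        next(self.batches)  # raises StopIteration once the iterator is exhausted

    def run(self):
        while not self.done:
            try:
                self.advance()
                self._restarting = False  # cleared once an iteration completes
            except StopIteration:
                break
        # the line added in this diff: without it, a restarted loop whose first
        # iteration raises StopIteration would exit with the flag still True
        self._restarting = False


loop = SketchLoop(batches=[])  # restarted with nothing left to consume
loop.run()
assert not loop._restarting
```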
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -98,11 +98,7 @@ def _is_validation_done(self) -> bool:

@property
def done(self) -> bool:
"""Returns whether the training should be stopped.

The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer
signals to stop (e.g. by early stopping).
"""
"""Evaluates when to leave the loop."""
return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop

def connect( # type: ignore[override]
24 changes: 5 additions & 19 deletions pytorch_lightning/loops/fit_loop.py
@@ -56,16 +56,6 @@ def __init__(
self._is_fresh_start_epoch: bool = True
self._outputs: _EPOCH_OUTPUTS_TYPE = []

@property
def current_epoch(self) -> int:
"""Return the current epoch."""
return self.epoch_progress.current.completed

@current_epoch.setter
def current_epoch(self, value: int) -> None:
"""Setter for the current epoch."""
self.epoch_progress.current.completed = value

@property
def global_step(self) -> int:
"""Returns the global step."""
@@ -149,19 +139,15 @@ def _results(self) -> _ResultCollection:

@property
def done(self) -> bool:
"""Evaluates when to leave the loop.

Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or epochs
is reached.
"""
"""Evaluates when to leave the loop."""
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)

should_stop = False
if self.trainer.should_stop:
# early stopping
met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
if met_min_epochs and met_min_steps:
should_stop = True
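A hypothetical illustration of how these criteria interact with `min_epochs` (parameter values made up; behaviour per the property above):

```python
from pytorch_lightning import Trainer

trainer = Trainer(min_epochs=3, max_epochs=10)
trainer.should_stop = True  # e.g. requested by EarlyStopping during an early epoch

# epoch_progress.current.completed is still below min_epochs, so the minimums are
# not met and training would keep going
print(trainer.fit_loop.done)  # expected: False
```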
@@ -219,7 +205,7 @@ def on_advance_start(self) -> None: # type: ignore[override]
getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
):
# set seed for distributed sampler (enables shuffling for each epoch)
self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)

# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
@@ -307,7 +293,7 @@ def on_run_end(self) -> None:
# Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
# To simulate that current behavior, we decrement here.
# TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
self.current_epoch = max(self.current_epoch - 1, 0)
self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)

# hook
self.trainer._call_callback_hooks("on_train_end")
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -218,7 +218,9 @@ def restore_loops(self) -> None:
return

self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"]
self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]
# set the `current_epoch` value for old checkpoints without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]

assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
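Illustrative only, a rough sketch of the checkpoint fields the connector touches here; the exact layout is an assumption and other entries are omitted:

```python
loaded_checkpoint = {
    "global_step": 8,
    "epoch": 2,        # legacy counter, used as the fallback set above
    # "loops": {...},  # newer checkpoints also carry the loops' progress-tracking
                       # state, which overwrites that fallback when restored
}

print(loaded_checkpoint.get("loops"))  # None for old checkpoints -> fallback applies
```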
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -336,7 +336,6 @@ def __init__(
To enable infinite training, set ``max_epochs = -1``.

min_epochs: Force training for at least these many epochs. Disabled by default (None).
If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
@@ -2349,7 +2348,8 @@ def global_step(self) -> int:

@property
def current_epoch(self) -> int:
return self.fit_loop.current_epoch
"""The current epoch, updated after the epoch end hooks are run."""
return self.fit_loop.epoch_progress.current.completed

@property
def max_epochs(self) -> int:
5 changes: 2 additions & 3 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -60,10 +60,10 @@ def scale_batch_size(

# Save initial model, that is loaded after batch size is found
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
trainer.fit_loop.current_epoch -= 1
trainer.fit_loop.epoch_progress.current.completed -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.current_epoch += 1
trainer.fit_loop.epoch_progress.current.completed += 1
trainer.fit_loop.global_step += 1
params = __scale_batch_dump_params(trainer)

@@ -110,7 +110,6 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
trainer.fit_loop.current_epoch = 0
trainer.fit_loop.max_steps = steps_per_trial # take few steps
trainer.logger = DummyLogger() if trainer.logger is not None else None
trainer.callbacks = [] # not needed before full run
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/lr_finder.py
@@ -204,10 +204,10 @@ def lr_find(

# Save initial model, that is loaded after learning rate is found
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
trainer.fit_loop.current_epoch -= 1
trainer.fit_loop.epoch_progress.current.completed -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.current_epoch += 1
trainer.fit_loop.epoch_progress.current.completed += 1
trainer.fit_loop.global_step += 1
params = __lr_finder_dump_params(trainer)

9 changes: 3 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
@@ -162,9 +162,9 @@ def on_validation_epoch_end(self):
for epoch in range(max_epochs):
score = model.scores[epoch]
expected_score = getattr(model, f"{monitor}s")[epoch].mean().item()
expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
assert math.isclose(score, expected_score, rel_tol=1e-4)

expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk["epoch"] == epoch + 1
assert chk["global_step"] == limit_train_batches * (epoch + 1)
@@ -462,7 +462,6 @@ class ModelCheckpointExtensionTest(ModelCheckpoint):

def test_model_checkpoint_file_extension(tmpdir):
"""Test ModelCheckpoint with different file extension."""

model = LogInTwoMethods()
model_checkpoint = ModelCheckpointExtensionTest(
monitor="early_stop_on", dirpath=tmpdir, save_top_k=1, save_last=True
@@ -613,7 +612,7 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
)
trainer.fit(model)

# check that the correct ckpts were created
# check that the correct ckpts were created, the modulo condition is checked in `ModelCheckpoint`
expected = [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)

@@ -967,15 +966,13 @@ def assert_checkpoint_log_dir(idx):
assert_checkpoint_content(ckpt_dir)

# load from checkpoint
trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
trainer = pl.Trainer(**trainer_config)
assert_trainer_init(trainer)

model = ExtendedBoringModel()

trainer.test(model)
assert trainer.global_step == 0
assert trainer.current_epoch == 0
assert_trainer_init(trainer)

trainer.fit(model, ckpt_path=chk)
assert trainer.global_step == epochs * limit_train_batches
4 changes: 0 additions & 4 deletions tests/checkpointing/test_trainer_checkpoint.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy

import torch

@@ -53,8 +52,6 @@ def validation_step(self, batch, batch_idx):
assert os.listdir(tmpdir) == ["epoch=00.ckpt"]

best_model_paths = [checkpoint_callback.best_model_path]
results = []

for idx in range(3, 6):
# load from checkpoint
trainer = pl.Trainer(
@@ -67,7 +64,6 @@ def validation_step(self, batch, batch_idx):
)
trainer.fit(model, ckpt_path=best_model_paths[-1])
trainer.test()
results.append(deepcopy(trainer.callback_metrics))
best_model_paths.append(trainer.checkpoint_callback.best_model_path)

for idx, best_model_path in enumerate(best_model_paths):
5 changes: 0 additions & 5 deletions tests/core/test_metric_result_integration.py
@@ -33,7 +33,6 @@
_ResultMetric,
_Sync,
)
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -373,13 +372,9 @@ def __repr__(self) -> str:


def result_collection_reload(accelerator="auto", devices=1, **kwargs):

"""This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault
Tolerant Training is correct."""

if not _fault_tolerant_training():
pytest.skip("Fault tolerant not available")

class CustomException(Exception):
pass

69 changes: 69 additions & 0 deletions tests/models/test_restore.py
@@ -222,6 +222,75 @@ def on_test_start(self):
trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)


def test_correct_step_and_epoch(tmpdir):
model = BoringModel()
first_max_epochs = 2
train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0
)
assert trainer.current_epoch == 0
assert trainer.global_step == 0

trainer.fit(model)
# TODO(@carmocca): should not need `-1`
assert trainer.current_epoch == first_max_epochs - 1
assert trainer.global_step == first_max_epochs * train_batches

# save checkpoint after loop ends, training end called, epoch count increased
ckpt_path = str(tmpdir / "model.ckpt")
trainer.save_checkpoint(ckpt_path)

ckpt = torch.load(ckpt_path)
assert ckpt["epoch"] == first_max_epochs
# TODO(@carmocca): should not need `+1`
assert ckpt["global_step"] == first_max_epochs * train_batches + 1

max_epochs = first_max_epochs + 2
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0
)
# the ckpt state is not loaded at this point
assert trainer.current_epoch == 0
assert trainer.global_step == 0

class TestModel(BoringModel):
def on_pretrain_routine_end(self) -> None:
assert self.trainer.current_epoch == first_max_epochs
# TODO(@carmocca): should not need `+1`
assert self.trainer.global_step == first_max_epochs * train_batches + 1

trainer.fit(TestModel(), ckpt_path=ckpt_path)
# TODO(@carmocca): should not need `-1`
assert trainer.current_epoch == max_epochs - 1
# TODO(@carmocca): should not need `+1`
assert trainer.global_step == max_epochs * train_batches + 1


def test_fit_twice(tmpdir):
epochs = []

class TestModel(BoringModel):
def on_train_epoch_end(self, *_):
epochs.append(self.current_epoch)

trainer = Trainer(
max_epochs=2,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
logger=False,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
trainer.fit(TestModel())
trainer.fit_loop.max_epochs = 4
trainer.fit(TestModel())
# TODO(@carmocca): 1 should not be duplicated
assert epochs == [0, 1, 1, 2, 3]


def test_try_resume_from_non_existing_checkpoint(tmpdir):
"""Test that trying to resume from non-existing `ckpt_path` fails with an error."""
model = BoringModel()
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -331,7 +331,7 @@ def mock_save_function(filepath, *args):

# emulate callback's calls during the training
for i, loss in enumerate(losses):
trainer.fit_loop.current_epoch = i
trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`
trainer.fit_loop.global_step = i
trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)