5 changes: 5 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for returning an object of type `Mapping` from `LightningModule.training_step()` ([#18657](https://github.com/Lightning-AI/lightning/pull/18657))


- Added the hook `LightningModule.on_validation_model_zero_grad()` to allow overriding the behavior of zeroing the gradients before entering the validation loop ([#18710](https://github.com/Lightning-AI/lightning/pull/18710))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
@@ -289,6 +292,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed numerical issues when reducing values in low precision with `self.log` ([#18686](https://github.com/Lightning-AI/lightning/pull/18686))


- Fixed an issue that would cause the gradients to be erased if validation happened in the middle of a gradient accumulation phase ([#18710](https://github.com/Lightning-AI/lightning/pull/18710))


## [2.0.9] - 2023-09-14

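For users, the practical effect of the new hook is that the pre-validation gradient clearing becomes overridable. A minimal sketch, assuming a `BoringModel`-style module (the class name is illustrative, not part of this PR):

```python
from lightning.pytorch.demos.boring_classes import BoringModel


class KeepGradsDuringVal(BoringModel):
    def on_validation_model_zero_grad(self) -> None:
        # The default implementation calls self.zero_grad(); doing nothing
        # here keeps .grad tensors inspectable inside validation hooks.
        pass
```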
6 changes: 6 additions & 0 deletions src/lightning/pytorch/core/hooks.py
@@ -19,6 +19,7 @@
from torch import Tensor
from torch.optim.optimizer import Optimizer

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities import move_data_to_device
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
@@ -151,6 +152,11 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in

"""

def on_validation_model_zero_grad(self) -> None:
"""Called by the training loop to release gradients before entering the validation loop."""
zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
self.zero_grad(**zero_grad_kwargs)

def on_validation_model_eval(self) -> None:
"""Sets the model to eval during the val loop."""
self.trainer.model.eval()
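The `_TORCH_GREATER_EQUAL_2_0` guard in the hook above exists because PyTorch 2.0 flipped the default of `set_to_none` in `zero_grad()` from `False` to `True`, so the kwarg only needs to be passed explicitly on older versions to release the gradient tensors instead of zeroing them in place. A standalone sketch of the same pattern (the crude major-version check is an assumption standing in for Lightning's flag):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)
model(torch.randn(1, 4)).sum().backward()  # populate .grad tensors

# PyTorch 2.0 made set_to_none=True the default for zero_grad().
torch_ge_2 = int(torch.__version__.split(".")[0]) >= 2
zero_grad_kwargs = {} if torch_ge_2 else {"set_to_none": True}

model.zero_grad(**zero_grad_kwargs)
assert all(p.grad is None for p in model.parameters())  # released, not zeroed
```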
2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -239,9 +239,7 @@ def on_run_start(self) -> None:
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
self._verify_dataloader_idx_requirement()

self._on_evaluation_model_eval()
self.trainer.lightning_module.zero_grad()
self._on_evaluation_start()
self._on_evaluation_epoch_start()

5 changes: 1 addition & 4 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -191,10 +191,7 @@ def reset(self) -> None:
def on_run_start(self) -> None:
"""Calls ``_on_predict_model_eval``, ``_on_predict_start`` and ``_on_predict_epoch_start`` hooks."""
self._verify_dataloader_idx_requirement()

trainer = self.trainer
call._call_lightning_module_hook(trainer, "on_predict_model_eval")
trainer.lightning_module.zero_grad()
call._call_lightning_module_hook(self.trainer, "on_predict_model_eval")
self._on_predict_start()
self._on_predict_epoch_start()

5 changes: 5 additions & 0 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -277,6 +277,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
self.trainer.validating = True
# save and reset this state in case validation runs inside training loop (val_check_interval<1.0)
first_loop_iter = self.trainer._logger_connector._first_loop_iter

if not self._should_accumulate():
# clear gradients to not leave any unused memory during validation
call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")

self.val_loop.run()
self.trainer.training = True
self.trainer._logger_connector._first_loop_iter = first_loop_iter
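The `_should_accumulate()` guard above is the heart of the fix: gradients are only cleared before validation when the just-finished batch closed an accumulation window, so mid-accumulation partial sums survive. A simplified stand-in for that boundary check (the real `_should_accumulate` also special-cases the last batch of an epoch):

```python
def should_accumulate(completed_batches: int, accumulate_grad_batches: int) -> bool:
    """True while gradients are still accumulating toward the next optimizer step."""
    return completed_batches % accumulate_grad_batches != 0


# With accumulate_grad_batches=3, validation after batch 2 must keep gradients:
assert should_accumulate(2, 3)      # mid-window: skip the zero-grad hook
assert not should_accumulate(3, 3)  # window closed: safe to clear
```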
1 change: 1 addition & 0 deletions src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
@@ -141,6 +141,7 @@ class _LogOptions(TypedDict):
"test_dataloader": None,
"prepare_data": None,
"configure_callbacks": None,
"on_validation_model_zero_grad": None,
"on_validation_model_eval": None,
"on_test_model_eval": None,
"on_validation_model_train": None,
3 changes: 3 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
@@ -1022,6 +1022,9 @@ def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
# wait for all to join if on distributed
self.strategy.barrier("run-stage")

zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
self.lightning_module.zero_grad(**zero_grad_kwargs)

if self.evaluating:
return self._evaluation_loop.run()
if self.predicting:
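Together with the loop changes above, the unconditional `zero_grad()` calls before every evaluation and prediction run are replaced by this single clear at the start of the stage. A plain-PyTorch illustration of the failure mode the conditional clearing avoids (toy model and loop, not Lightning code):

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for batch_idx in range(1, 7):  # two accumulation windows of 3 micro-batches
    model(torch.randn(8, 2)).sum().backward()
    if batch_idx % 2 == 0 and batch_idx % 3 != 0:
        # Validation runs here, mid-window: calling model.zero_grad() now
        # (the old behavior) would erase the partial gradient sum.
        pass
    if batch_idx % 3 == 0:
        optimizer.step()
        optimizer.zero_grad()
```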
31 changes: 31 additions & 0 deletions tests/tests_pytorch/loops/test_loops.py
@@ -851,3 +851,34 @@ def _get_iterator(self):
3, # teardown on epoch 2, workers from epoch 2 get destroyed
]
assert val_dataloader.shutdown_workers_epochs == expected


def test_validation_during_gradient_accumulation_window(tmp_path):
"""Test that gradients don't get erased when the validation interval falls within the gradient accumulation
phase."""

class ValidationModel(BoringModel):
def on_validation_start(self):
batch_idx = self.trainer.fit_loop.epoch_loop.batch_progress.current.completed
grad_expected = batch_idx % self.trainer.accumulate_grad_batches != 0
if grad_expected:
assert batch_idx in (2, 4)
assert all(p.grad is not None for p in self.parameters())
else:
assert batch_idx == 6
assert all(p.grad is None for p in self.parameters())
self.ran_assert = True

model = ValidationModel()
trainer = Trainer(
default_root_dir=tmp_path,
limit_train_batches=6,
limit_val_batches=1,
accumulate_grad_batches=3,
# validation happens in the middle of the first two accumulations, and at the end of the third
val_check_interval=2,
max_epochs=1,
num_sanity_val_steps=0,
)
trainer.fit(model)
assert model.ran_assert
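The constants in the assertions follow from the Trainer arguments: with `val_check_interval=2`, validation fires after batches 2, 4, and 6, while `accumulate_grad_batches=3` means the optimizer steps (and gradients are cleared) after batches 3 and 6. A quick sketch of the schedule the test encodes:

```python
accumulate_grad_batches, val_check_interval, limit_train_batches = 3, 2, 6

for completed in range(1, limit_train_batches + 1):
    if completed % val_check_interval == 0:
        mid_accumulation = completed % accumulate_grad_batches != 0
        print(f"batch {completed}: validate, grads present={mid_accumulation}")
# batch 2: validate, grads present=True
# batch 4: validate, grads present=True
# batch 6: validate, grads present=False  (step + zero_grad just ran)
```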
12 changes: 8 additions & 4 deletions tests/tests_pytorch/models/test_hooks.py
@@ -18,6 +18,7 @@

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from torch import Tensor
@@ -465,11 +466,11 @@ def training_step(self, batch, batch_idx):
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
{"name": "Callback.on_sanity_check_start", "args": (trainer, model)},
{"name": "val_dataloader"},
{"name": "train", "args": (False,)},
{"name": "on_validation_model_eval"},
{"name": "zero_grad"},
{"name": "Callback.on_validation_start", "args": (trainer, model)},
{"name": "on_validation_start"},
*model._eval_epoch("validation", trainer, model, val_batches, "x", device=device),
@@ -486,9 +487,10 @@
{"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
{"name": "on_train_epoch_start"},
*model._train_batch(trainer, model, train_batches, device=device, **kwargs),
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
{"name": "on_validation_model_zero_grad"},
{"name": "train", "args": (False,)},
{"name": "on_validation_model_eval"},
{"name": "zero_grad"},
{"name": "Callback.on_validation_start", "args": (trainer, model)},
{"name": "on_validation_start"},
*model._eval_epoch("validation", trainer, model, val_batches, "x", device=device),
@@ -566,6 +568,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
{"name": "train_dataloader"},
{"name": "train", "args": (True,)},
{"name": "Callback.on_train_start", "args": (trainer, model)},
@@ -644,6 +647,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
{"name": "train_dataloader"},
{"name": "train", "args": (True,)},
{"name": "Callback.on_train_start", "args": (trainer, model)},
@@ -690,7 +694,6 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
{"name": f"{dataloader}_dataloader"},
{"name": "train", "args": (False,)},
{"name": f"on_{noun}_model_eval"},
{"name": "zero_grad"},
{"name": f"Callback.on_{noun}_start", "args": (trainer, model)},
{"name": f"on_{noun}_start"},
*model._eval_epoch(noun, trainer, model, batches, key, trainer.strategy.root_device),
@@ -705,6 +708,7 @@
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
{"name": "setup", "kwargs": {"stage": verb}},
{"name": "configure_model"},
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
*(hooks if batches else []),
{"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
{"name": "teardown", "kwargs": {"stage": verb}},
@@ -727,10 +731,10 @@ def test_trainer_model_hook_system_predict(tmpdir):
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
{"name": "setup", "kwargs": {"stage": "predict"}},
{"name": "configure_model"},
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
{"name": "predict_dataloader"},
{"name": "train", "args": (False,)},
{"name": "on_predict_model_eval"},
{"name": "zero_grad"},
{"name": "Callback.on_predict_start", "args": (trainer, model)},
{"name": "on_predict_start"},
{"name": "Callback.on_predict_epoch_start", "args": (trainer, model)},
1 change: 1 addition & 0 deletions tests/tests_pytorch/trainer/logging_/test_fx_validator.py
@@ -211,6 +211,7 @@ def test_fx_validator_integration(tmpdir):
"on_sanity_check_end": "You can't",
"prepare_data": "You can't",
"configure_callbacks": "You can't",
"on_validation_model_zero_grad": "You can't",
"on_validation_model_eval": "You can't",
"on_validation_model_train": "You can't",
"lr_scheduler_step": "You can't",