Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618))


- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))


- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))


Expand Down Expand Up @@ -213,6 +210,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))


- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))


- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))


Expand Down
8 changes: 3 additions & 5 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT


class Callback(abc.ABC):
Expand Down Expand Up @@ -108,17 +108,15 @@ def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
"""Called when the val epoch begins."""
pass

def on_validation_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT
) -> None:
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the val epoch ends."""
pass

def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test epoch begins."""
pass

def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test epoch ends."""
pass

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT


class ModelHooks:
Expand Down Expand Up @@ -245,7 +245,7 @@ def on_validation_epoch_start(self) -> None:
Called in the validation loop at the very beginning of the epoch.
"""

def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
Expand All @@ -255,7 +255,7 @@ def on_test_epoch_start(self) -> None:
Called in the test loop at the very beginning of the epoch.
"""

def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
Expand Down
36 changes: 6 additions & 30 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,44 +111,20 @@ def on_validation_epoch_start(self):
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.lightning_module)

def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.

Args:
outputs: List of outputs on each ``validation`` epoch
"""
def on_validation_epoch_end(self):
"""Called when the validation epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_validation_epoch_end(self, self.lightning_module)
callback.on_validation_epoch_end(self, self.lightning_module)

def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.lightning_module)

def on_test_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.

Args:
outputs: List of outputs on each ``test`` epoch
"""
def on_test_epoch_end(self):
"""Called when the test epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
callback.on_test_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_test_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_test_epoch_end(self, self.lightning_module)
callback.on_test_epoch_end(self, self.lightning_module)

def on_predict_epoch_start(self) -> None:
"""Called when the epoch begins."""
Expand Down
24 changes: 12 additions & 12 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

from torch.utils.data import DataLoader

Expand All @@ -20,7 +20,6 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

Expand Down Expand Up @@ -76,6 +75,7 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool:
return sum(max_batches) == 0

def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()
if self.trainer.testing:
self.trainer.call_hook('on_test_start', *args, **kwargs)
else:
Expand Down Expand Up @@ -188,6 +188,13 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
return output

def _should_track_batch_outputs_for_epoch_end(self) -> bool:
model = self.trainer.lightning_module
if self.trainer.testing:
return is_overridden('test_epoch_end', model=model)
else:
return is_overridden('validation_epoch_end', model=model)

def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
# unset dataloder_idx in model
self.trainer.logger_connector.evaluation_epoch_end()
Expand Down Expand Up @@ -241,7 +248,7 @@ def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, datal
# track debug metrics
self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
def on_evaluation_epoch_end(self) -> None:
model_ref = self.trainer.lightning_module
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

Expand All @@ -251,18 +258,11 @@ def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]])

if hasattr(self.trainer, hook_name):
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
on_evaluation_epoch_end_hook(outputs)
on_evaluation_epoch_end_hook()

if is_overridden(hook_name, model_ref):
model_hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(model_hook_fx, "outputs"):
model_hook_fx(outputs)
else:
self.warning_cache.warn(
f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
model_hook_fx()
model_hook_fx()

self.trainer._cache_logged_metrics()

Expand Down
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,22 +972,23 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)

# store batch level output per dataloader
self.evaluation_loop.outputs.append(dl_outputs)
if self.evaluation_loop.should_track_batch_outputs_for_epoch_end:
self.evaluation_loop.outputs.append(dl_outputs)

outputs = self.evaluation_loop.outputs

# reset outputs
self.evaluation_loop.outputs = []

# with a single dataloader don't pass a 2D list
if self.evaluation_loop.num_dataloaders == 1:
if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1:
outputs = outputs[0]

# lightning module method
self.evaluation_loop.evaluation_epoch_end(outputs)

# hook
self.evaluation_loop.on_evaluation_epoch_end(outputs)
self.evaluation_loop.on_evaluation_epoch_end()

# update epoch-level lr_schedulers
if on_epoch:
Expand Down Expand Up @@ -1212,8 +1213,8 @@ def _cache_logged_metrics(self):

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
# Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
# This was done to manage the deprecation of an argument to on_train_epoch_end
# If making chnages to this function, ensure that those changes are also made to
# This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
# If making changes to this function, ensure that those changes are also made to
# TrainLoop._on_train_epoch_end_hook

# set hook_name to model + reset Result obj
Expand Down
42 changes: 0 additions & 42 deletions tests/callbacks/test_callback_hook_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,48 +65,6 @@ def training_epoch_end(self, outputs) -> None:
trainer.fit(model)


def test_on_val_epoch_end_outputs(tmpdir):

class CB(Callback):

def on_validation_epoch_end(self, trainer, pl_module, outputs):
if trainer.running_sanity_check:
assert len(outputs) == trainer.num_sanity_val_batches[0]
else:
assert len(outputs) == trainer.num_val_batches[0]

model = BoringModel()

trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)

trainer.fit(model)


def test_on_test_epoch_end_outputs(tmpdir):

class CB(Callback):

def on_test_epoch_end(self, trainer, pl_module, outputs):
assert len(outputs) == trainer.num_test_batches[0]

model = BoringModel()

trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
weights_summary=None,
)

trainer.test(model)


def test_free_memory_on_eval_outputs(tmpdir):

class CB(Callback):
Expand Down
8 changes: 4 additions & 4 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_sanity_check_end(trainer, model),
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_test_batch_start(trainer, model, ANY, 1, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_test_epoch_end(trainer, model, ANY),
call.on_test_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_test_end(trainer, model),
call.teardown(trainer, model, 'test'),
Expand Down Expand Up @@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.teardown(trainer, model, 'validate'),
Expand Down
56 changes: 0 additions & 56 deletions tests/core/test_hooks.py

This file was deleted.

Loading