Skip to content

Commit 72097ba

Browse files
author
SeanNaren
committed
Merge branch 'master' into fix/setup_ddp_hook
2 parents d7ec33e + 297e438 commit 72097ba

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+688
-703
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4040
- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
4141

4242

43+
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
44+
45+
46+
4347
### Changed
4448

4549
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
@@ -68,6 +72,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6872
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
6973

7074
[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
75+
76+
[#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),
7177

7278
)
7379

docs/source/benchmarking/performance.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,19 @@ Most UNIX-based operating systems provide direct access to tmpfs through a mount
181181
.. code-block:: python
182182
183183
datamodule = MyDataModule(data_root="/dev/shm/my_data")
184+
185+
186+
Zero Grad ``set_to_none=True``
187+
------------------------------
188+
189+
In order to modestly improve performance, once can override :meth:`~pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad`.
190+
191+
For a more detailed explanation of pros / cons of this technique,
192+
read `this <https://pytorch.org/docs/master/optim.html#torch.optim.Optimizer.zero_grad>`_ documentation by the PyTorch team.
193+
194+
.. testcode::
195+
196+
class Model(LightningModule):
197+
198+
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
199+
optimizer.zero_grad(set_to_none=True)

pytorch_lightning/callbacks/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""
1818

1919
import abc
20-
from typing import Any, Dict, Optional
20+
from typing import Any, Dict, List, Optional
2121

2222
from pytorch_lightning.core.lightning import LightningModule
2323

@@ -81,23 +81,23 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None:
8181
"""Called when the train epoch begins."""
8282
pass
8383

84-
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None:
84+
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
8585
"""Called when the train epoch ends."""
8686
pass
8787

8888
def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None:
8989
"""Called when the val epoch begins."""
9090
pass
9191

92-
def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None:
92+
def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
9393
"""Called when the val epoch ends."""
9494
pass
9595

9696
def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None:
9797
"""Called when the test epoch begins."""
9898
pass
9999

100-
def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
100+
def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
101101
"""Called when the test epoch ends."""
102102
pass
103103

pytorch_lightning/core/hooks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def on_train_epoch_start(self) -> None:
240240
"""
241241
# do something when the epoch starts
242242

243-
def on_train_epoch_end(self, outputs) -> None:
243+
def on_train_epoch_end(self, outputs: List[Any]) -> None:
244244
"""
245245
Called in the training loop at the very end of the epoch.
246246
"""
@@ -252,7 +252,7 @@ def on_validation_epoch_start(self) -> None:
252252
"""
253253
# do something when the epoch starts
254254

255-
def on_validation_epoch_end(self) -> None:
255+
def on_validation_epoch_end(self, outputs: List[Any]) -> None:
256256
"""
257257
Called in the validation loop at the very end of the epoch.
258258
"""
@@ -264,7 +264,7 @@ def on_test_epoch_start(self) -> None:
264264
"""
265265
# do something when the epoch starts
266266

267-
def on_test_epoch_end(self) -> None:
267+
def on_test_epoch_end(self, outputs: List[Any]) -> None:
268268
"""
269269
Called in the test loop at the very end of the epoch.
270270
"""

pytorch_lightning/core/step_result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import torch
2222
from torch import Tensor
23+
from torchmetrics import Metric
2324

24-
from pytorch_lightning.metrics import Metric
2525
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
2626

2727

pytorch_lightning/metrics/classification/accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from typing import Any, Callable, Optional
1515

1616
import torch
17+
from torchmetrics import Metric
1718

1819
from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update
19-
from pytorch_lightning.metrics.metric import Metric
2020

2121

2222
class Accuracy(Metric):

pytorch_lightning/metrics/classification/auc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from typing import Any, Callable, Optional
1515

1616
import torch
17+
from torchmetrics import Metric
1718

1819
from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update
19-
from pytorch_lightning.metrics.metric import Metric
2020
from pytorch_lightning.utilities import rank_zero_warn
2121

2222

pytorch_lightning/metrics/classification/auroc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from typing import Any, Callable, Optional
1616

1717
import torch
18+
from torchmetrics import Metric
1819

1920
from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update
20-
from pytorch_lightning.metrics.metric import Metric
2121
from pytorch_lightning.utilities import rank_zero_warn
2222

2323

pytorch_lightning/metrics/classification/average_precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from typing import Any, List, Optional, Union
1515

1616
import torch
17+
from torchmetrics import Metric
1718

1819
from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update
19-
from pytorch_lightning.metrics.metric import Metric
2020
from pytorch_lightning.utilities import rank_zero_warn
2121

2222

pytorch_lightning/metrics/classification/confusion_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from typing import Any, Optional
1515

1616
import torch
17+
from torchmetrics import Metric
1718

1819
from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
19-
from pytorch_lightning.metrics.metric import Metric
2020

2121

2222
class ConfusionMatrix(Metric):

0 commit comments

Comments
 (0)