|
17 | 17 | """ |
18 | 18 |
|
19 | 19 | import abc |
20 | | -from typing import Any, Dict, Optional |
| 20 | +from typing import Any, Dict, List, Optional |
21 | 21 |
|
22 | 22 | from pytorch_lightning.core.lightning import LightningModule |
23 | 23 |
|
@@ -81,23 +81,23 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: |
81 | 81 | """Called when the train epoch begins.""" |
82 | 82 | pass |
83 | 83 |
|
84 | | - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: |
| 84 | + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: |
85 | 85 | """Called when the train epoch ends.""" |
86 | 86 | pass |
87 | 87 |
|
88 | 88 | def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: |
89 | 89 | """Called when the val epoch begins.""" |
90 | 90 | pass |
91 | 91 |
|
92 | | - def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None: |
| 92 | + def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: |
93 | 93 | """Called when the val epoch ends.""" |
94 | 94 | pass |
95 | 95 |
|
96 | 96 | def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: |
97 | 97 | """Called when the test epoch begins.""" |
98 | 98 | pass |
99 | 99 |
|
100 | | - def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None: |
| 100 | + def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: |
101 | 101 | """Called when the test epoch ends.""" |
102 | 102 | pass |
103 | 103 |
|
|
0 commit comments