@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -282,6 +283,26 @@ def test_step(self, *args, **kwargs): |
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
+    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def _pod_progress_bar_force_stdout(self) -> None:
+        # Why is this required? The way `pytorch_xla.distributed` streams logs
+        # from the worker VMs to the master worker doesn't play well with tqdm.
+        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
+        # An empty print statement seems to force tqdm to flush stdout.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
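For reviewers who want to exercise the workaround outside the plugin, here is a minimal standalone sketch of the same check. It hardcodes the env-var name (which `torch_xla.core.xla_env_vars` exposes as `xenv.TPUVM_MODE`; the assumption here is that the variable is literally named `TPUVM_MODE`) and takes the rank as a parameter so it runs without a TPU. The helper name is hypothetical:

```python
import os

# torch_xla.core.xla_env_vars exposes this as xenv.TPUVM_MODE; hardcoded
# here (an assumption about its literal value) so the sketch runs without
# torch_xla installed.
TPUVM_MODE = "TPUVM_MODE"


def pod_progress_bar_force_stdout(tpu_global_core_rank: int) -> None:
    """Hypothetical standalone version of the plugin's stdout-flush workaround.

    xla_dist streams each worker VM's logs back to the master over a pipe,
    and tqdm's carriage-return updates can sit in that pipe unflushed.
    Printing an empty line seems to force tqdm to flush stdout.
    """
    # Mirror the plugin's condition: only the rank-0 core on a TPU VM pod.
    if tpu_global_core_rank == 0 and int(os.getenv(TPUVM_MODE, 0)) == 1:
        print()


# Example: simulate running on the master worker of a TPU VM pod.
os.environ[TPUVM_MODE] = "1"
pod_progress_bar_force_stdout(tpu_global_core_rank=0)
```

Calling `print()` rather than `sys.stdout.flush()` also emits a newline, which presumably pushes the streamed log past tqdm's carriage-return line rather than just flushing the buffer.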