34 | 34 | from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed |
35 | 35 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
36 | 36 | from pytorch_lightning.utilities.seed import reset_seed |
| 37 | +from pytorch_lightning.utilities.types import STEP_OUTPUT |
37 | 38 |
38 | 39 | if _TPU_AVAILABLE: |
39 | 40 | import torch_xla.core.xla_env_vars as xenv |
@@ -152,9 +153,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: |
152 | 153 | # set warning rank |
153 | 154 | rank_zero_only.rank = self.global_rank |
154 | 155 |
155 | | - if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: |
156 | | - print(' ', end='', flush=True) |
157 | | - |
158 | 156 | if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: |
159 | 157 | trainer.progress_bar_callback.disable() |
160 | 158 |
@@ -285,21 +283,26 @@ def test_step(self, *args, **kwargs): |
285 | 283 | def predict_step(self, *args, **kwargs): |
286 | 284 | return self.model(*args, **kwargs) |
287 | 285 |
288 | | - def training_step_end(self, output): |
289 | | - if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: |
290 | | - print(' ', end='', flush=True) |
| 286 | + def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: |
| 287 | + self._pod_progress_bar_force_stdout() |
291 | 288 | return output |
292 | 289 |
293 | | - def validation_step_end(self, output): |
294 | | - if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: |
295 | | - print(' ', end='', flush=True) |
| 290 | + def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: |
| 291 | + self._pod_progress_bar_force_stdout() |
296 | 292 | return output |
297 | 293 |
298 | | - def test_step_end(self, output): |
299 | | - if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: |
300 | | - print(' ', end='', flush=True) |
| 294 | + def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: |
| 295 | + self._pod_progress_bar_force_stdout() |
301 | 296 | return output |
302 | 297 |
| 298 | + def _pod_progress_bar_force_stdout(self) -> None: |
|  | 299 | +        # Why is this required? The way `pytorch_xla.distributed` streams logs
|  | 300 | +        # from different VMs to the master worker doesn't work well with tqdm.
| 301 | + # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140 |
| 302 | + # The print statement seems to force tqdm to flush stdout. |
| 303 | + if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: |
| 304 | + print() |
| 305 | + |
303 | 306 | def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: |
304 | 307 | """Save model/training states as a checkpoint file through state-dump and file-write. |
305 | 308 |
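
For context on the refactor above: the three `*_step_end` hooks now delegate to a single `_pod_progress_bar_force_stdout` helper instead of repeating the TPU pod check inline. Below is a minimal, self-contained sketch of that pattern in isolation. It is not code from this repository; the helper names `is_tpu_pod_master` and `force_progress_bar_stdout`, and the string fallback for the environment-variable name, are assumptions for illustration.

```python
# Minimal sketch of the "force tqdm to flush stdout on a TPU pod" pattern.
# Not repository code; the helper names here are hypothetical.
import os

try:
    # torch_xla exposes the env-var names used by xla_dist; only present on TPU hosts.
    import torch_xla.core.xla_env_vars as xenv
    TPUVM_MODE_VAR = xenv.TPUVM_MODE
except ImportError:
    TPUVM_MODE_VAR = "TPUVM_MODE"  # assumption: raw variable name when torch_xla is absent


def is_tpu_pod_master(global_core_rank: int) -> bool:
    # Only the master worker (global core rank 0) in TPU-VM pod mode should print;
    # otherwise every worker's stray output gets interleaved by xla_dist.
    return global_core_rank == 0 and int(os.getenv(TPUVM_MODE_VAR, 0)) == 1


def force_progress_bar_stdout(global_core_rank: int) -> None:
    # An empty print appears to make tqdm flush its buffered progress-bar line
    # when xla_dist streams logs from the worker VMs to the master.
    if is_tpu_pod_master(global_core_rank):
        print()
```

Centralizing the check in one helper, as the diff does, keeps `training_step_end`, `validation_step_end`, and `test_step_end` free of duplicated environment probing.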