
Commit f1d0b65

Add _pod_progress_bar_force_stdout
1 parent 6fda7d9 commit f1d0b65


pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 15 additions & 12 deletions
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -152,9 +153,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         # set warning rank
         rank_zero_only.rank = self.global_rank
 
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
-
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()
 
@@ -285,21 +283,26 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
-    def training_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output
 
-    def validation_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output
 
-    def test_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output
 
+    def _pod_progress_bar_force_stdout(self) -> None:
+        # Why is it required? The way `pytorch_xla.distributed` streams logs
+        # from different vms to the master worker doesn't work well with tqdm
+        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
+        # The print statement seems to force tqdm to flush stdout.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
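For context, here is a minimal standalone sketch (not part of the commit) of the pattern the new `_pod_progress_bar_force_stdout` hook implements: on the master worker of a TPU VM pod, a bare `print()` after each step forces tqdm to flush its carriage-return updates through stdout, which the `xla_dist` log streaming can then forward. The literal env-var name "TPUVM_MODE" and the `global_rank` parameter are assumptions standing in for `xenv.TPUVM_MODE` and `self.tpu_global_core_rank`; the sketch assumes only that tqdm is installed.

import os
import time

from tqdm import tqdm


def pod_progress_bar_force_stdout(global_rank: int) -> None:
    # Mirrors the hook added in this commit: only rank 0, and only when
    # running in TPU VM mode, emits a newline so tqdm flushes stdout.
    # "TPUVM_MODE" is assumed to be the string behind xenv.TPUVM_MODE.
    if global_rank == 0 and int(os.getenv("TPUVM_MODE", 0)) == 1:
        print()


if __name__ == "__main__":
    os.environ["TPUVM_MODE"] = "1"  # simulate running on a TPU VM pod
    for _ in tqdm(range(50)):
        time.sleep(0.02)  # stand-in for a training step
        pod_progress_bar_force_stdout(global_rank=0)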