21 changes: 21 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -34,6 +34,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TPU_AVAILABLE:
    import torch_xla.core.xla_env_vars as xenv
@@ -282,6 +283,26 @@ def test_step(self, *args, **kwargs):
    def predict_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def _pod_progress_bar_force_stdout(self) -> None:
        # Why is this required? The way `pytorch_xla.distributed` streams logs
        # from the different VMs to the master worker doesn't work well with tqdm.
        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
        # The print statement seems to force tqdm to flush stdout.
        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
            print()

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

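For reference, a minimal standalone sketch of the stdout-flush trick introduced above. It assumes only a plain `TPUVM_MODE` environment variable and an integer rank passed in by hand (the plugin itself reads `xenv.TPUVM_MODE` via `torch_xla` and checks `self.tpu_global_core_rank`):

```python
import os


def force_progress_bar_flush(global_core_rank: int) -> None:
    # Hypothetical standalone version of `_pod_progress_bar_force_stdout`:
    # on a TPU pod, only the master worker (rank 0) emits a bare newline,
    # which nudges tqdm to flush its output through xla_dist's log streaming.
    # "TPUVM_MODE" stands in for xenv.TPUVM_MODE; "1" means a TPU VM pod run.
    if global_core_rank == 0 and int(os.getenv("TPUVM_MODE", 0)) == 1:
        print()


# Example: only the rank-0 worker prints; all other ranks are no-ops.
os.environ["TPUVM_MODE"] = "1"
for rank in range(4):
    force_progress_bar_flush(rank)
```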