diff --git a/.gitignore b/.gitignore
index 390551b8f6e60..cd0ba22453512 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,3 +155,5 @@ cifar-10-batches-py
 # ctags
 tags
 data
+MNIST
+runs
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d78578fc1fb2..c28e1fa2f202e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))
 
+- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))
+
+
 ## [1.1.8] - 2021-02-08
 
 ### Fixed
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 2e8e31139dda2..967b6a85c878b 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import (
@@ -388,3 +389,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
             A tensor of shape (world_size, batch, ...)
         """
         return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
+    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
+        """Wraps the dataloader if necessary.
+
+        Args:
+            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
+        """
+        return self.training_type_plugin.process_dataloader(dataloader)
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a8024bef2a539..83d86b619c7c9 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -554,6 +554,14 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         epoch = metrics.get("epoch")
         step = metrics.get("step")
 
+        # when `val_loss` is being logged and no ModelCheckpoint monitor is provided,
+        # `val_loss` will be selected as the monitor and needs to be reduced to
+        # prevent the processes from diverging
+        # TODO: Move this logic to logger_connector. This also needs to be fixed for any
+        # other monitored value which isn't produced from a Metric.
+        if self.monitor == "val_loss":
+            current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
+
         if self.check_monitor_top_k(current):
             self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
         elif self.verbose:
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index d4374d0ef9c6a..0136e78a4381f 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -10,7 +10,8 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 
 if _TPU_AVAILABLE:
@@ -46,10 +47,6 @@ def create_mp_queue(self):
     def distributed_sampler_kwargs(self) -> dict:
         return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
 
-    @property
-    def should_finalize(self):
-        return self.world_size == 1
-
     @property
     def is_distributed(self):
         return self.world_size != 1
@@ -179,6 +176,24 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
         should_stop = int(stop.item()) == self.world_size
         return should_stop
 
+    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+        if not isinstance(output, torch.Tensor):
+            output = torch.tensor(output, device=self.device)
+
+        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+        if _invalid_reduce_op or _invalid_reduce_op_str:
+            raise MisconfigurationException(
+                "Currently, TPUSpawn TrainingTypePlugin only supports `sum`, `mean`, `avg` reduce operations."
+            )
+
+        output = xm.mesh_reduce('reduce', output, sum)
+
+        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
+            output = output / self.world_size
+
+        return output
+
     def post_dispatch(self) -> None:
         # TODO: Check if trainer references can be resolved otherwise
         model = self.lightning_module
@@ -213,6 +228,10 @@ def __load_weights_on_main_process(self) -> None:
 
         self._model = model
 
+    def _close_logger(self, trainer) -> None:
+        if hasattr(trainer, "logger"):
+            trainer.logger.finalize("success")
+
     @property
     def xmp_spawn_kwargs(self):
         return {
@@ -225,9 +244,11 @@ def start_training(self, trainer) -> None:
         # todo: precision pluging is call in accelerator setup and should be moved
         if 'XLA_USE_BF16' in os.environ:
             del os.environ["XLA_USE_BF16"]
+        self._close_logger(trainer)
         xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
 
     def start_testing(self, trainer) -> None:
+        self._close_logger(trainer)
         xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
 
     def start_predicting(self, trainer) -> None:
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cede3e5f98b43..938a17249e9f6 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -35,10 +35,6 @@ def __init__(self) -> None:
         self._results = None
         self.global_rank = 0
 
-    @property
-    def should_finalize(self):
-        return True
-
     @property
     @abstractmethod
     def on_gpu(self) -> bool:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 2453a08ba9067..2b2b2f92dce59 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -711,7 +711,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
         for dataloader_idx, dataloader in enumerate(dataloaders):
             # bookkeeping
             dl_outputs = []
-            dataloader = self.training_type_plugin.process_dataloader(dataloader)
+            dataloader = self.accelerator.process_dataloader(dataloader)
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
 
             for batch_idx, batch in enumerate(dataloader):
@@ -823,7 +823,7 @@ def run_predict(self):
 
         # run validation/testing
         for dataloader_idx, dataloader in enumerate(dataloaders):
-            dataloader = self.training_type_plugin.process_dataloader(dataloader)
+            dataloader = self.accelerator.process_dataloader(dataloader)
             dl_max_batches = self.predict_loop.max_batches[dataloader_idx]
 
             for batch_idx, batch in enumerate(dataloader):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 0908e96bd1c17..57c0b10f12412 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -140,7 +140,7 @@ def on_train_end(self):
         # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
         # It might be related to xla tensors blocked when moving the cpu
         # kill loggers
-        if self.trainer.logger is not None and self.trainer.training_type_plugin.should_finalize:
+        if self.trainer.logger is not None:
             self.trainer.logger.finalize("success")
 
         # summarize profile results
@@ -502,7 +502,7 @@ def tbptt_split_batch(self, batch):
 
     def run_training_epoch(self):
         # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
+        train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
 
         # track epoch output
         epoch_output = [[] for _ in range(self.num_optimizers)]
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index d9ea8a9917d2b..4c6620b07b74a 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.utils import pl_multi_process_test
@@ -264,9 +265,6 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
 
 
 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
-@pytest.mark.skipif(
-    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
-)
 @pl_multi_process_test
 def test_broadcast_on_tpu():
     """ Checks if an object from the master process is broadcasted to other processes correctly"""
@@ -327,3 +325,26 @@ def test_tpu_cores_with_argparse(cli_args, expected):
     for k, v in expected.items():
         assert getattr(args, k) == v
     assert Trainer.from_argparse_args(args)
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_tpu_reduce():
+    """Test TPU spawn reduce operation."""
+
+    def test_reduce(rank):
+        trainer = Trainer(tpu_cores=8)
+        # running all reduce ops inside a single spawned process is faster
+        reduce_ops = ["mean", "AVG", "undefined", "sum", ReduceOp.SUM, ReduceOp.MAX]
+        for reduce_op in reduce_ops:
+            if reduce_op == "undefined" or reduce_op == ReduceOp.MAX:
+                with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"):
+                    result = trainer.training_type_plugin.reduce(1, reduce_op=reduce_op)
+            else:
+                result = trainer.training_type_plugin.reduce(1, reduce_op=reduce_op)
+                if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
+                    assert result.item() == 1
+                else:
+                    assert result.item() == 8
+
+    xmp.spawn(test_reduce, nprocs=8, start_method='fork')
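For context, here is a minimal CPU-only sketch (not part of the diff above) of why `_save_top_k_checkpoints` now reduces the `val_loss` monitor with `reduce_op="mean"`: without the reduction, each TPU process compares its own local `val_loss` against the checkpointing threshold, so the processes can take different branches, which is the divergence this PR addresses. The world size, loss values, and best-score threshold below are invented for illustration; on an actual TPU the reduction goes through `xm.mesh_reduce` as added in `tpu_spawn.py`.

```python
# Illustrative sketch only -- world size, losses, and threshold are made up.
import torch

world_size = 8
# each spawned process ends the epoch with a slightly different local val_loss
local_val_losses = [torch.tensor(0.50 + 0.01 * rank) for rank in range(world_size)]
best_score = torch.tensor(0.54)  # hypothetical current best monitored value

# without a reduce: processes disagree on whether the checkpoint should be saved
decisions = [bool(loss < best_score) for loss in local_val_losses]
print(decisions)  # e.g. [True, True, True, True, False, False, False, False]

# with a mean reduce (what `training_type_plugin.reduce(current, reduce_op="mean")`
# achieves across processes): every process sees the same value and takes the same branch
reduced = torch.stack(local_val_losses).mean()
decisions = [bool(reduced < best_score) for _ in range(world_size)]
print(decisions)  # all identical: [True, True, True, True, True, True, True, True]
```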