diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index b7146f58c60d9..0a678f07e043e 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -23,9 +23,11 @@
 from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
+from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
 from pytorch_lightning.utilities.argparse import (
     add_argparse_args,
     from_argparse_args,
@@ -33,14 +35,6 @@
     parse_env_variables,
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
-if _HOROVOD_AVAILABLE:
-    import horovod.torch as hvd
-
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.utilities.model_helpers import is_overridden
 
 
@@ -420,37 +414,10 @@ def __setstate__(self, state):
         self.__dict__ = state
 
     @property
-    def require_distributed_sampler(self):
-        if self.accelerator_backend is not None:
-            return self.accelerator_backend.require_distributed_sampler
-        return self._distrib_type in (
-            DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
-        ) or self._device_type == DeviceType.TPU
-
-    @property
-    def distributed_sampler_kwargs(self):
-        if self.accelerator_backend is not None:
+    def distributed_sampler_kwargs(self) -> Optional[dict]:
+        if isinstance(self.training_type_plugin, ParallelPlugin):
             return self.training_type_plugin.distributed_sampler_kwargs
 
-        # TODO: make sure the cases below are handled by the training_type_plugin
-        if self._device_type == DeviceType.TPU:
-            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
-
-        elif self._distrib_type == DistributedType.HOROVOD:
-            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
-
-        else:
-            world_size = {
-                "ddp": self.num_nodes * self.num_processes,
-                "ddp_spawn": self.num_nodes * self.num_processes,
-                "ddp2": self.num_nodes,
-                "ddp_cpu": self.num_processes * self.num_nodes
-            }
-            assert self.distributed_backend is not None
-            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
-
-        return kwargs
-
 
 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
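
For context, a minimal sketch (not part of the patch) of how the kwargs returned by the
delegated `distributed_sampler_kwargs` property are typically consumed: a ParallelPlugin
reports `num_replicas` and `rank`, which map directly onto the arguments of
`torch.utils.data.DistributedSampler`. The `build_distributed_dataloader` helper and the
hard-coded two-process values below are illustrative assumptions only; in Lightning the
dict would come from `trainer.distributed_sampler_kwargs`.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def build_distributed_dataloader(dataset, sampler_kwargs: dict) -> DataLoader:
    # ``sampler_kwargs`` is expected to have the shape returned by
    # ParallelPlugin.distributed_sampler_kwargs, i.e.
    # {"num_replicas": world_size, "rank": global_rank}.
    sampler = DistributedSampler(dataset, **sampler_kwargs)
    return DataLoader(dataset, sampler=sampler, batch_size=32)


if __name__ == "__main__":
    ds = TensorDataset(torch.randn(128, 4))
    # Hypothetical kwargs for a 2-process job, rank 0; passing both values
    # explicitly means no process group needs to be initialised here.
    loader = build_distributed_dataloader(ds, {"num_replicas": 2, "rank": 0})
    print(len(loader))  # this replica sees 64 of 128 samples -> 2 batches of 32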