Commit d422ef2

clean up unused distributed sampler logic in trainer (#5975)
* clean up sampler unused logic
* undo cached
* imports
1 parent 4f63942 commit d422ef2

File tree

1 file changed: +5 additions, -38 deletions


pytorch_lightning/trainer/properties.py

Lines changed: 5 additions & 38 deletions
@@ -23,24 +23,18 @@
 from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
+from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
 from pytorch_lightning.utilities.argparse import (
     add_argparse_args,
     from_argparse_args,
     parse_argparser,
     parse_env_variables,
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
-if _HOROVOD_AVAILABLE:
-    import horovod.torch as hvd
-
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -420,37 +414,10 @@ def __setstate__(self, state):
         self.__dict__ = state
 
     @property
-    def require_distributed_sampler(self):
-        if self.accelerator_backend is not None:
-            return self.accelerator_backend.require_distributed_sampler
-        return self._distrib_type in (
-            DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
-        ) or self._device_type == DeviceType.TPU
-
-    @property
-    def distributed_sampler_kwargs(self):
-        if self.accelerator_backend is not None:
+    def distributed_sampler_kwargs(self) -> Optional[dict]:
+        if isinstance(self.training_type_plugin, ParallelPlugin):
             return self.training_type_plugin.distributed_sampler_kwargs
 
-        # TODO: make sure the cases below are handled by the training_type_plugin
-        if self._device_type == DeviceType.TPU:
-            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
-
-        elif self._distrib_type == DistributedType.HOROVOD:
-            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
-
-        else:
-            world_size = {
-                "ddp": self.num_nodes * self.num_processes,
-                "ddp_spawn": self.num_nodes * self.num_processes,
-                "ddp2": self.num_nodes,
-                "ddp_cpu": self.num_processes * self.num_nodes
-            }
-            assert self.distributed_backend is not None
-            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
-
-        return kwargs
-
 
 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
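With this change the trainer no longer special-cases TPU, Horovod, or the DDP variants when computing sampler arguments; it forwards whatever the active ParallelPlugin reports through distributed_sampler_kwargs, and the property falls through to None for non-parallel runs. As a rough illustration of how such a property can be consumed, here is a minimal sketch; the build_dataloader helper, its arguments, and the exact call site inside Lightning are hypothetical and not shown in this diff, while torch.utils.data.DistributedSampler and the property itself are real.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_dataloader(trainer, dataset, batch_size=32):
    # Hypothetical helper: after this commit the trainer property either
    # forwards the ParallelPlugin's kwargs (e.g. num_replicas / rank) or
    # returns None when no parallel plugin is active.
    kwargs = trainer.distributed_sampler_kwargs
    sampler = DistributedSampler(dataset, **kwargs) if kwargs else None
    # DistributedSampler shuffles internally, so only ask the DataLoader
    # to shuffle when no sampler is attached.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=sampler is None)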
