|
23 | 23 | from pytorch_lightning.accelerators.accelerator_connector import BackendConnector |
24 | 24 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase |
25 | 25 | from pytorch_lightning.core.lightning import LightningModule |
| 26 | +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger |
| 27 | +from pytorch_lightning.plugins import ParallelPlugin |
26 | 28 | from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector |
27 | 29 | from pytorch_lightning.trainer.states import TrainerState |
28 | | -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn |
| 30 | +from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn |
29 | 31 | from pytorch_lightning.utilities.argparse import ( |
30 | 32 | add_argparse_args, |
31 | 33 | from_argparse_args, |
32 | 34 | parse_argparser, |
33 | 35 | parse_env_variables, |
34 | 36 | ) |
35 | 37 | from pytorch_lightning.utilities.cloud_io import get_filesystem |
36 | | - |
37 | | -if _TPU_AVAILABLE: |
38 | | - import torch_xla.core.xla_model as xm |
39 | | - |
40 | | -if _HOROVOD_AVAILABLE: |
41 | | - import horovod.torch as hvd |
42 | | - |
43 | | -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger |
44 | 38 | from pytorch_lightning.utilities.model_helpers import is_overridden |
45 | 39 |
|
46 | 40 |
|
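The import hunk above drops the module-level `_TPU_AVAILABLE` / `_HOROVOD_AVAILABLE` guards from the trainer properties mixin; after the refactor, the plugin that actually needs an optional backend is expected to guard that import itself. A minimal sketch of that pattern, reusing the `_TPU_AVAILABLE` flag and the `xm` calls from the removed code below; the mixin name is hypothetical and not a class from the library:

from pytorch_lightning.utilities import _TPU_AVAILABLE

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm


class HypotheticalTPUSamplerMixin:
    # Plugin-side stand-in: the sampler kwargs are computed from the XLA runtime
    # here instead of inside TrainerProperties.
    @property
    def distributed_sampler_kwargs(self) -> dict:
        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())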
@@ -420,37 +414,10 @@ def __setstate__(self, state): |
420 | 414 | self.__dict__ = state |
421 | 415 |
|
422 | 416 | @property |
423 | | - def require_distributed_sampler(self): |
424 | | - if self.accelerator_backend is not None: |
425 | | - return self.accelerator_backend.require_distributed_sampler |
426 | | - return self._distrib_type in ( |
427 | | - DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2 |
428 | | - ) or self._device_type == DeviceType.TPU |
429 | | - |
430 | | - @property |
431 | | - def distributed_sampler_kwargs(self): |
432 | | - if self.accelerator_backend is not None: |
| 417 | + def distributed_sampler_kwargs(self) -> Optional[dict]: |
| 418 | + if isinstance(self.training_type_plugin, ParallelPlugin): |
433 | 419 | return self.training_type_plugin.distributed_sampler_kwargs |
434 | 420 |
|
435 | | - # TODO: make sure the cases below are handled by the training_type_plugin |
436 | | - if self._device_type == DeviceType.TPU: |
437 | | - kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) |
438 | | - |
439 | | - elif self._distrib_type == DistributedType.HOROVOD: |
440 | | - kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) |
441 | | - |
442 | | - else: |
443 | | - world_size = { |
444 | | - "ddp": self.num_nodes * self.num_processes, |
445 | | - "ddp_spawn": self.num_nodes * self.num_processes, |
446 | | - "ddp2": self.num_nodes, |
447 | | - "ddp_cpu": self.num_processes * self.num_nodes |
448 | | - } |
449 | | - assert self.distributed_backend is not None |
450 | | - kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) |
451 | | - |
452 | | - return kwargs |
453 | | - |
454 | 421 |
|
455 | 422 | # Used to represent the concrete type TrainerProperties class methods are called on. |
456 | 423 | _T = TypeVar('_T', bound=TrainerProperties) |
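After this change, `Trainer.distributed_sampler_kwargs` returns a dict only when the active training type plugin is a `ParallelPlugin`, and `None` otherwise, so call sites must handle both cases. A minimal consumption sketch, using illustrative values (2 replicas, rank 0) in place of whatever a real plugin would report:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

# Stand-in for trainer.distributed_sampler_kwargs; None would mean single-process training.
sampler_kwargs = dict(num_replicas=2, rank=0)

# Build a DistributedSampler only when kwargs are available; otherwise fall back to shuffling.
sampler = DistributedSampler(dataset, **sampler_kwargs) if sampler_kwargs is not None else None
loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=sampler is None)

for (batch,) in loader:
    print(batch)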