
Commit 2c3d43d

Authored by tchaton, maxjeblick, SeanNaren and s-rog
Initialize trainer with None in DDPAccelerator (#4915)
* Initialize trainer with None
* add typing to all accelerators
* resolve imports
* update
* add typing
* removed typo
* update
* Fix formatting and imports in accelerator

Co-authored-by: maxjeblick <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
1 parent d5fa02e commit 2c3d43d

12 files changed (+76 / −27 lines)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 6 additions & 2 deletions
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from enum import Enum
 from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
 from torch.optim import Optimizer
 
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.parsing import AttributeDict
@@ -33,7 +34,10 @@ class ReduceOp:
 
 class Accelerator(object):
 
-    def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer: Optional = None,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         self.trainer = trainer
         self.nickname = None
         self.cluster_environment = cluster_environment
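For orientation, a rough sketch of how the newly typed base constructor can be exercised; the bare DDPPlugin() value here is purely illustrative, not something this commit prescribes:

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

# trainer is typed as Optional and may be omitted entirely at construction time;
# cluster_environment and ddp_plugin are likewise Optional and default to None.
accelerator = Accelerator(trainer=None, cluster_environment=None, ddp_plugin=DDPPlugin())
assert accelerator.trainer is None

The bare Optional on trainer (with no type argument) matches the diff above; annotating it with the concrete Trainer class would presumably pull in a circular import.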

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
import torch
1717

1818
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
19+
from pytorch_lightning.cluster_environments import ClusterEnvironment
1920
from pytorch_lightning.utilities import AMPType, rank_zero_warn
2021
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2122

2223

2324
class CPUAccelerator(Accelerator):
2425

25-
def __init__(self, trainer, cluster_environment=None):
26+
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
2627
"""
2728
Runs training on CPU
2829

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 7 additions & 2 deletions
@@ -20,12 +20,14 @@
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -34,7 +36,10 @@
 
 class DDP2Accelerator(Accelerator):
 
-    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP2 strategy on a cluster

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 12 additions & 3 deletions
@@ -25,12 +25,18 @@
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
-from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    all_gather_ddp_if_available,
+    find_free_network_port,
+    rank_zero_only,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
@@ -41,7 +47,10 @@
 
 class DDPAccelerator(Accelerator):
 
-    def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer: Optional = None,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP strategy on a single machine (manually, not via cluster start)
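This is the change the commit title refers to: like the base class, DDPAccelerator now defaults trainer to None, so the accelerator can be built before any Trainer exists. A minimal sketch of what that allows, with the later attachment shown only schematically:

from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator

# The accelerator no longer needs a trainer up front.
accelerator = DDPAccelerator()
assert accelerator.trainer is None

# Later, whoever owns the Trainer attaches it; inside Lightning this wiring is
# done by the accelerator connector, sketched here only as a plain assignment.
# accelerator.trainer = trainer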

pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py

Lines changed: 10 additions & 2 deletions
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
+from typing import Optional
+
+from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
+from pytorch_lightning.cluster_environments import ClusterEnvironment
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE
 
 if HYDRA_AVAILABLE:
-    from hydra.utils import to_absolute_path, get_original_cwd
     from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
 
 
 class DDPCPUHPCAccelerator(DDPHPCAccelerator):
 
-    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP (with CPUs) strategy on a cluster

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 10 additions & 5 deletions
@@ -16,22 +16,23 @@
 
 import torch
 import torch.distributed as torch_distrib
-import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import (
+    all_gather_ddp_if_available,
     find_free_network_port,
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
-    all_gather_ddp_if_available,
 )
 
 if HYDRA_AVAILABLE:
@@ -41,7 +42,11 @@
 
 class DDPCPUSpawnAccelerator(Accelerator):
 
-    def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 nprocs: int,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn
@@ -197,8 +202,8 @@ def broadcast(self, obj, src=0):
 
     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
+        torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
+        torch_distrib.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
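The last hunk only swaps the duplicate dist alias for the existing torch_distrib one; the early-stopping consensus logic itself is unchanged: each rank votes 0 or 1, the votes are summed with all_reduce, and training stops only when the total equals the world size, i.e. the vote is unanimous. A standalone sketch of that pattern, assuming the default process group is already initialised (the helper name is hypothetical):

import torch
import torch.distributed as torch_distrib

def all_ranks_want_to_stop(local_should_stop: bool, world_size: int, device: torch.device) -> bool:
    # Each rank contributes 1 if it wants to stop early, 0 otherwise.
    stop = torch.tensor(int(local_should_stop), device=device)
    # Sum the votes across all ranks; afterwards every rank holds the same total.
    torch_distrib.all_reduce(stop, op=torch_distrib.ReduceOp.SUM)
    torch_distrib.barrier()
    # Stop only when every rank voted to stop.
    return int(stop.item()) == world_size

(The sketch uses torch_distrib.ReduceOp, the current spelling of the legacy reduce_op alias that appears in the diff.)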

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 7 additions & 2 deletions
@@ -21,11 +21,13 @@
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -34,7 +36,10 @@
 
 class DDPHPCAccelerator(Accelerator):
 
-    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP on an HPC cluster

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

Lines changed: 10 additions & 5 deletions
@@ -17,24 +17,25 @@
 
 import torch
 import torch.distributed as torch_distrib
-import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed import LightningDistributed
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import (
+    all_gather_ddp_if_available,
     find_free_network_port,
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
-    all_gather_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import seed_everything
@@ -45,7 +46,11 @@
 
 class DDPSpawnAccelerator(Accelerator):
 
-    def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 nprocs: int,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP using mp.spawn via manual launch (not cluster launch)
@@ -226,8 +231,8 @@ def barrier(self, name: Optional[str] = None):
 
     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
+        torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
+        torch_distrib.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop

pytorch_lightning/accelerators/dp_accelerator.py

Lines changed: 4 additions & 2 deletions
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union
 
 import torch
 from torch import optim
 
+from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed import LightningDistributed
@@ -27,7 +29,7 @@
 
 class DataParallelAccelerator(Accelerator):
 
-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training using DP via manual start (not HPC cluster)

pytorch_lightning/accelerators/gpu_accelerator.py

Lines changed: 3 additions & 1 deletion
@@ -15,15 +15,17 @@
 
 import torch
 
+from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
+from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 
 
 class GPUAccelerator(Accelerator):
     amp_backend: AMPType
 
-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training using a single GPU
