
Commit 07fdd95

tchaton, root, Ubuntu, and justusschock authored
[accelerator][BugFix] Resolve some test for 1 gpu (#5863)
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* update
* resolve flake8
* update
* update
* update
* update
* update
* all_gather
* update
* make plugins work, add misconfig for RPC
* update
* update
* remove breaking test
* resolve some tests
* resolve flake8
* revert to ddp_spawn

Co-authored-by: root <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
1 parent 0ac5fc4 commit 07fdd95

28 files changed: +153 −79 lines

.drone.yml

Lines changed: 3 additions & 1 deletion
@@ -47,7 +47,9 @@ steps:
   - unzip -o legacy/checkpoints.zip -d legacy/
   - ls -l legacy/checkpoints/
   # testing...
-  - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
+  - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests --ignore tests/plugins/test_sharded_plugin.py --ignore tests/trainer/test_dataloaders.py -v --durations=25 # --flake8
+  # Todo: Find why those tests are failing when run in the main pytest.
+  - python -m coverage run -a --source pytorch_lightning -m pytest tests/plugins/test_sharded_plugin.py tests/trainer/test_dataloaders.py -v --durations=25 # --flake8
   # Running special tests
   - sh tests/special_tests.sh
   - coverage report

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -151,3 +151,4 @@ wandb
 
 # dataset generated from bolts in examples.
 cifar-10-batches-py
+*.pt

pytorch_lightning/accelerators/accelerator.py

Lines changed: 13 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
 
 import torch
@@ -374,3 +375,15 @@ def on_save(self, checkpoint):
 
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
+
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
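As a usage note, the sketch below shows how this new hook could be reached from user code; the module, the layer shape, and the trainer.accelerator_backend access path are illustrative assumptions, not part of the diff.

import torch
import pytorch_lightning as pl

class GatherExample(pl.LightningModule):
    # Illustrative module: only validation_step matters for this sketch.

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def validation_step(self, batch, batch_idx):
        # Per-process predictions of shape (batch, 2).
        preds = self.layer(batch)
        # Per the docstring above, the gathered result has shape
        # (world_size, batch, 2); in a single-process run,
        # all_gather_ddp_if_available should return the tensor unchanged.
        return self.trainer.accelerator_backend.all_gather(preds, sync_grads=False)

Per the docstring, sync_grads=True synchronizes gradients through the gather, which matters when the gathered tensor feeds a loss.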

pytorch_lightning/accelerators/accelerator_connector.py

File mode changed from 100644 to 100755
Lines changed: 22 additions & 9 deletions
@@ -33,7 +33,6 @@
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
     PrecisionPlugin,
-    RPCPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
     SingleTPUPlugin,
@@ -116,11 +115,11 @@ def __init__(
         self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
         self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)
 
-        self.handle_given_plugins(plugins)
-
         self.set_distributed_mode()
         self.configure_slurm_ddp()
 
+        self.handle_given_plugins(plugins)
+
         self.accelerator = self.select_accelerator()
 
         # override dist backend when using tpus
@@ -147,8 +146,10 @@ def __init__(
         self.replace_sampler_ddp = replace_sampler_ddp
 
     def handle_given_plugins(self, plugins: Optional[Sequence]):
-        if plugins is None:
-            return
+        plugins = plugins if plugins is not None else []
+
+        if isinstance(plugins, str):
+            plugins = [plugins]
 
         if not isinstance(plugins, Sequence):
             plugins = [plugins]
@@ -158,7 +159,10 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
         cluster_environment = None
 
         for plug in plugins:
-            if isinstance(plug, TrainingTypePlugin):
+            if isinstance(plug, str):
+                self.set_distributed_mode(plug)
+
+            elif isinstance(plug, TrainingTypePlugin):
                 if training_type is None:
                     training_type = plug
 
@@ -191,6 +195,7 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
             )
 
         self._training_type_plugin = training_type
+        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
@@ -206,6 +211,7 @@ def training_type_plugin(self) -> TrainingTypePlugin:
             self._training_type_plugin = self.select_training_type_plugin()
         else:
            self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
+
         return self._training_type_plugin
 
     @property
@@ -327,7 +333,7 @@ def select_precision_plugin(self):
 
     def select_training_type_plugin(self):
         if self.use_ddp2:
-            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self._cluster_environment)
+            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
@@ -359,7 +365,7 @@ def select_training_type_plugin(self):
             plugin = ddp_plugin_cls(
                 parallel_devices=self.parallel_devices,
                 num_nodes=self.num_nodes,
-                cluster_environment=self.select_cluster_environment(),
+                cluster_environment=self.cluster_environment,
                 sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_dp:
@@ -425,7 +431,11 @@ def select_cluster_environment(self):
             env = TorchElasticEnvironment()
         return env
 
-    def set_distributed_mode(self):
+    def set_distributed_mode(self, distributed_backend: Optional[str] = None):
+
+        if distributed_backend is not None:
+            self.distributed_backend = distributed_backend
+
         if isinstance(self.distributed_backend, Accelerator):
             return
 
@@ -484,6 +494,9 @@ def set_distributed_mode(self):
         ):
             self.num_processes = self.num_gpus
 
+        if (self._device_type == DeviceType.GPU and self._distrib_type == DistributedType.DDP2):
+            self.num_processes = self.num_nodes
+
         # Horovod is an extra case...
         if self.distributed_backend == "horovod":
             self._set_horovod_backend()
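Taken together, the handle_given_plugins changes mean plugins may now be None, a bare string, or a sequence, with string entries routed through set_distributed_mode. Below is a condensed, standalone sketch of just the normalization step; the function name and asserts are illustrative, not Lightning API.

from collections.abc import Sequence
from typing import Optional, Union

def normalize_plugins(plugins: Optional[Union[str, Sequence]]) -> list:
    # None becomes an empty list, so the caller can always iterate.
    plugins = plugins if plugins is not None else []
    # A bare string such as "ddp_spawn" is wrapped so it is handled as one
    # plugin reference, not iterated character by character.
    if isinstance(plugins, str):
        plugins = [plugins]
    # Any other non-sequence (e.g. a single plugin object) is wrapped too.
    if not isinstance(plugins, Sequence):
        plugins = [plugins]
    return list(plugins)

assert normalize_plugins(None) == []
assert normalize_plugins("ddp_spawn") == ["ddp_spawn"]
assert normalize_plugins(["ddp_spawn"]) == ["ddp_spawn"]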

pytorch_lightning/accelerators/tpu.py

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,5 @@
-from typing import Callable
+from typing import Any, Callable, Optional, Union
+import torch
 
 from torch.optim import Optimizer
 
@@ -28,3 +29,15 @@ def setup(self, trainer, model):
 
     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
         xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
+
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -540,9 +540,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
 
         accelerator_backend = trainer.accelerator_backend
 
-        if accelerator_backend is not None and accelerator_backend.rpc_enabled:
+        if accelerator_backend.training_type_plugin.rpc_enabled:
             # RPCPlugin manages saving all model states
-            accelerator_backend.ddp_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+            accelerator_backend.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
         else:
             self._save_model(last_filepath, trainer, pl_module)
         if (
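This change follows the commit's broader refactor: RPC handling now hangs off the training-type plugin rather than a ddp_plugin attribute on the accelerator. A minimal sketch of the resulting dispatch shape; all class and function names here are illustrative stand-ins, not Lightning API.

class PlainPlugin:
    rpc_enabled = False

class RpcLikePlugin:
    rpc_enabled = True

    def rpc_save_model(self, save_fn, filepath):
        # The real RPCPlugin coordinates saving model state across RPC workers.
        save_fn(filepath)

def save_last_checkpoint(training_type_plugin, save_fn, filepath):
    # Mirror of the branch above: the plugin decides how the save happens.
    if training_type_plugin.rpc_enabled:
        training_type_plugin.rpc_save_model(save_fn, filepath)
    else:
        save_fn(filepath)

save_last_checkpoint(PlainPlugin(), lambda path: print(f"saved {path}"), "last.ckpt")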

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 14 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Tuple
+from typing import List, Tuple, Callable
 
 import torch
 from torch.optim import Optimizer
@@ -38,6 +38,8 @@ def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
         """Connects the precision plugin to the training process,
         configures apex and reinits the schedulers
         """
+        if model.device.type != "cuda":
+            return model, optimizers, lr_schedulers
         model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
         self.reinit_scheduler_properties(optimizers, lr_schedulers)
         return model, optimizers, lr_schedulers
@@ -71,7 +73,7 @@ def backward(
         # do backward pass
         # TODO: not entirely sure, why we need this
         if model is not None and isinstance(model, LightningModule):
-            model.backward(closure_loss, optimizer, opt_idx)
+            model.backward(closure_loss, optimizer, opt_idx, **kwargs)
 
         # TODO: avoid dev_debugger and track these calls with mock
         model.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
@@ -90,6 +92,16 @@ def backward(
         closure_loss = closure_loss.detach()
         return closure_loss
 
+    def pre_optimizer_step(
+        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs
+    ) -> bool:
+        """Hook to do something before each optimizer step."""
+        # Apex: Amp does not support closure use with optimizers
+        closure()
+        optimizer.step()
+        return False
+
+
     def configure_apex(
         self,
         amp: object,
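The comment in pre_optimizer_step is the key: torch.optim optimizers accept an optional closure in step(closure), but Apex amp's scaled-loss flow does not compose with it, so the plugin evaluates the closure itself and then steps without arguments. A self-contained sketch of the two call patterns, with a plain model standing in for an Apex-initialized one:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(8, 4)

def closure():
    # Recompute the loss and its gradients for this step.
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    return loss

# Native PyTorch: the optimizer may run the closure internally.
optimizer.step(closure)

# The pattern pre_optimizer_step falls back to: evaluate the closure
# first, then step with no arguments.
closure()
optimizer.step()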

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _PYTORCH_GREATER_EQUAL_THAN_1_7_0, rank_zero_warn
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _PYTORCH_GREATER_EQUAL_1_7_0, rank_zero_warn
 from pytorch_lightning.utilities.distributed import (
     find_free_network_port,
     rank_zero_only,
@@ -181,7 +181,7 @@ def set_world_ranks(self):
 
     def pre_configure_ddp(self):
         # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``` breaking manual_optimization
-        if _PYTORCH_GREATER_EQUAL_THAN_1_7_0 and not self.lightning_module.automatic_optimization:
+        if _PYTORCH_GREATER_EQUAL_1_7_0 and not self.lightning_module.automatic_optimization:
             rank_zero_warn(
                 "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
                 "to properly work with DDP."

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 3 additions & 2 deletions
@@ -27,7 +27,7 @@
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_THAN_1_7_0
+from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_7_0
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import (
@@ -91,6 +91,7 @@ def setup(self, model):
     def set_world_ranks(self, process_idx):
         self.local_rank = process_idx
         self.node_rank = self.cluster_environment.node_rank()
+        self.task_idx = self.cluster_local_rank
         self.global_rank = self.node_rank * self.num_processes + self.local_rank
         self.world_size = self.num_nodes * self.num_processes
 
@@ -164,7 +165,7 @@ def post_training(self):
 
     def pre_configure_ddp(self):
         # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``` breaking manual_optimization
-        if _PYTORCH_GREATER_EQUAL_THAN_1_7_0 and not self.lightning_module.automatic_optimization:
+        if _PYTORCH_GREATER_EQUAL_1_7_0 and not self.lightning_module.automatic_optimization:
             rank_zero_warn(
                 "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
                 "to properly work with DDP."

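Beyond the constant rename (_PYTORCH_GREATER_EQUAL_THAN_1_7_0 becomes _PYTORCH_GREATER_EQUAL_1_7_0, in both ddp.py and ddp_spawn.py), the warning describes a real gate: from PyTorch 1.7, DDP's bucket rebuilding breaks manual optimization unless unused parameters are tracked. A rough sketch of an equivalent check, computed locally here rather than imported from pytorch_lightning.utilities:

import torch
from packaging.version import Version

# Local stand-in for pytorch_lightning.utilities._PYTORCH_GREATER_EQUAL_1_7_0;
# the "+cuXXX" local-version suffix is stripped before comparing.
_PYTORCH_GREATER_EQUAL_1_7_0 = Version(torch.__version__.split("+")[0]) >= Version("1.7.0")

ddp_kwargs = {}
if _PYTORCH_GREATER_EQUAL_1_7_0:
    # PyTorch 1.7 DDP calls self.reducer._rebuild_buckets(), which breaks
    # manual optimization unless unused parameters are tracked.
    ddp_kwargs["find_unused_parameters"] = True

model = torch.nn.Linear(4, 1)
# Wrapping requires an initialized process group, so it is left commented:
# ddp_model = torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs)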
pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 8 additions & 1 deletion
@@ -36,10 +36,17 @@ def __init__(
     ):
         super().__init__()
         self.parallel_devices = parallel_devices
-        self.local_rank = 0
         self.world_size = 1
+        self.local_rank = 0
         self.cluster_environment = cluster_environment
 
+    @property
+    def cluster_local_rank(self):
+        try:
+            return self.cluster_environment.local_rank()
+        except KeyError:
+            return 0
+
     @property
     @abstractmethod
     def root_device(self):
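The new cluster_local_rank property covers environments that read the local rank from a launcher-provided variable: with no launcher (the single-GPU case this commit targets), that lookup raises KeyError and the property falls back to 0. A standalone sketch of the same pattern; the environment class is an illustrative stand-in:

import os

class EnvReadingClusterEnvironment:
    # Illustrative stand-in: reads the local rank exported by a launcher.

    def local_rank(self) -> int:
        return int(os.environ["LOCAL_RANK"])  # KeyError when no launcher set it

def cluster_local_rank(env: EnvReadingClusterEnvironment) -> int:
    try:
        return env.local_rank()
    except KeyError:
        # Single-process run: behave as the only (zeroth) local process.
        return 0

print(cluster_local_rank(EnvReadingClusterEnvironment()))  # 0 unless LOCAL_RANK is set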
