benchmarks/test_sharded_parity.py (11 additions, 11 deletions)
@@ -21,6 +21,7 @@
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.plugins import DDPSpawnShardedPlugin
 from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.legacy.sharded_plugin import DDPShardedPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
@@ -32,7 +33,7 @@
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_gpu():
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=1,
         model_cls=SeedTrainLoaderModel,
     )
@@ -43,7 +44,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=1,
         precision=16,
         model_cls=SeedTrainLoaderModel,
@@ -55,7 +56,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu():
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=2,
         model_cls=SeedTrainLoaderModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
@@ -67,7 +68,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=2,
         precision=16,
         model_cls=SeedTrainLoaderModel,
@@ -80,7 +81,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=2,
         precision=16,
         model_cls=SeedTrainLoaderModel,
@@ -95,7 +96,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
 )
 @DDPLauncher.run("--accelerator ddp --gpus 2 --precision 32")
 def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=args.gpus,
         precision=args.precision,
         model_cls=SeedTrainLoaderModel,
@@ -109,7 +110,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
 )
 @DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=args.gpus,
         precision=args.precision,
         model_cls=SeedTrainLoaderModel,
@@ -124,7 +125,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
     Ensures same results using multiple optimizers across multiple GPUs
     """
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=2,
         model_cls=SeedTrainLoaderMultipleOptimizersModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
@@ -139,7 +140,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
     Ensures using multiple optimizers across multiple GPUs with manual optimization
     """
-    sharded_parity_test(
+    plugin_parity_test(
         gpus=2,
         model_cls=SeedTrainLoaderManualModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
@@ -242,9 +243,7 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda):
 
 def plugin_parity_test(
     model_cls: Type[SeedTrainLoaderModel],
-    plugin: Union[str, DDPPlugin],
     seed: int = 42,
-    accelerator: str = 'ddp_spawn',
     gpus: int = 0,
     precision: int = 32,
     max_percent_speed_diff: float = 0.1,
@@ -289,6 +288,7 @@ def plugin_parity_test(
         precision=precision,
         accelerator='ddp_sharded_spawn',
     )
+    assert isinstance(trainer.training_type_plugin, DDPSpawnShardedPlugin)
 
     max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
         trainer=trainer, model=custom_plugin_model, use_cuda=use_cuda
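For context (not part of the diff): a minimal sketch of the plugin selection the benchmark now relies on. Instead of passing a plugin object, plugin_parity_test builds the Trainer with the 'ddp_sharded_spawn' accelerator string and checks the resolved training-type plugin, as the added assertion shows. The snippet assumes a machine with at least 2 GPUs.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnShardedPlugin

# Sketch of the Trainer construction performed inside plugin_parity_test:
# sharded DDP is selected purely via the accelerator string, and the
# resolved training-type plugin is verified before fitting.
trainer = Trainer(gpus=2, precision=16, accelerator='ddp_sharded_spawn')
assert isinstance(trainer.training_type_plugin, DDPSpawnShardedPlugin)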
pytorch_lightning/accelerators/gpu.py (0 additions, 1 deletion)
@@ -16,7 +16,6 @@ def setup(self, trainer, model):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
         torch.cuda.set_device(self.root_device)
-        model.to(self.root_device)
         return super().setup(trainer, model)
 
     def on_train_start(self):
pytorch_lightning/plugins/training_type/dp.py (2 additions, 0 deletions)
@@ -27,6 +27,8 @@ def __init__(self, parallel_devices: List[torch.device]):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
 
     def setup(self, model):
+        # model needs to be moved to the device before it is wrapped
+        model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
 
     def reduce(self, output, *args, **kwargs):
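Illustrative note (not part of the diff): the model.to(...) call removed from GPUAccelerator.setup now happens in the plugin, because torch.nn.DataParallel expects the wrapped module's parameters to already sit on the source device before it replicates them. A minimal standalone sketch, assuming a machine with 2 GPUs:

import torch
from torch.nn import DataParallel, Linear

# The module must live on the first parallel device before DataParallel wraps it;
# DataParallel then replicates it onto the remaining devices at each forward pass.
model = Linear(4, 4)
model.to(torch.device("cuda:0"))  # what DDPPlugin/DataParallelPlugin setup now does explicitly
dp_model = DataParallel(model, device_ids=[0, 1])
out = dp_model(torch.randn(8, 4, device="cuda:0"))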
pytorch_lightning/plugins/training_type/sharded_spawn.py (0 additions, 6 deletions)
@@ -23,16 +23,13 @@ def configure_ddp(self):
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
-            if is_lightning_optimizer(optimizer):
-                optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
         trainer = self.lightning_module.trainer
         trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
 
     def _wrap_optimizers(self):
         trainer = self.model.trainer
@@ -41,9 +38,6 @@ def _wrap_optimizers(self):
         self._reinit_optimizers_with_oss()
 
     def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
-        if is_lightning_optimizer(optimizer):
-            optimizer = optimizer._optimizer
-
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
             return self._optim_state_dict(optimizer)
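For reference (not part of the diff): with the is_lightning_optimizer unwrapping removed, _reinit_optimizers_with_oss operates on the raw optimizers directly. A minimal sketch of that re-wrapping, assuming fairscale is installed; the single-process gloo group here is only so the snippet runs standalone, whereas the plugin runs this inside already-initialized spawned workers.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process process group purely for illustration; OSS shards optimizer
# state across the default group, so one must exist before construction.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Rebuild the plain optimizer as a sharded (ZeRO) optimizer of the same class,
# mirroring the loop body of _reinit_optimizers_with_oss above.
if not isinstance(optimizer, OSS):
    optim_class = type(optimizer)
    optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)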