Merged
Changes from all commits (88 commits)
be4c24c
Encapsulate extracting reference model within the plugin to allow cus…
Nov 19, 2020
5101696
Add missing new lines
Nov 19, 2020
078a829
Fix call to accelerator
Nov 19, 2020
aeab93c
Removed double blank
Nov 19, 2020
95a1f19
Use accelerator backend
Nov 19, 2020
84ccdbf
Handle case where wrapper has not been initialized within the plugin
Nov 19, 2020
0864b1c
Added basic get model tests, add better typing
Nov 19, 2020
142a2d3
Change model name
Nov 19, 2020
6e548df
Split GPU/DDP test
Nov 19, 2020
aebb1a3
Add stronger typing, skip ddp test on windows
Nov 19, 2020
fa04807
Fix import
Nov 19, 2020
47e562e
Fix import in dp
Nov 19, 2020
15734e9
Fixed PEP8 definition
Nov 19, 2020
f29f7c5
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 19, 2020
3a7a848
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
10a3a1e
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
e3869c3
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
3a3eaa5
Merge branch 'master' into feature/817-fairscale-5n
tchaton Nov 20, 2020
b44dd75
Add ddp launcher for ddp testing
Nov 20, 2020
6786407
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
9a07f67
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
358f503
Modify accelerator reference model to property, change name to reflec…
Nov 22, 2020
4b16b47
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 22, 2020
977625c
Revert property as this is incorrect.=
Nov 22, 2020
250cd96
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 22, 2020
b506a7e
Revert across accelerators
Nov 22, 2020
2e8585f
Add base code
Nov 19, 2020
9c34589
Assert availability via imports
Nov 20, 2020
1e429ba
Unified API upstream with suggestion to ben
Nov 21, 2020
4ae6f09
Fixed reference
Nov 22, 2020
50ed083
Add module wrapper code
Nov 22, 2020
df416f6
Fix conversion in on_before_forward
Nov 22, 2020
c590e3a
Ensure we check if we should use sharded amp plugin
Nov 22, 2020
d953f2b
Merge branch 'master' into feature/fairscale-817-6n
Nov 23, 2020
08d37d9
Fixed name ref
Nov 23, 2020
f765364
Fixed configure_ddp, removed lr scheduler modification, added unit tests
Nov 24, 2020
6b12921
Add catches around fairscale installation
Nov 24, 2020
17f23e5
Ensure imports are not required explicitly for type casting
Nov 24, 2020
a52e6a4
Add additional checkpoint tests
Nov 24, 2020
bfe754d
Removed comments, skip test
Nov 25, 2020
b39f290
Merge branch 'master' into feature/plug
Nov 25, 2020
9932608
Add additional test cases
Nov 25, 2020
d822468
Move to percentage diff, increase diff
Nov 25, 2020
a311ee1
Add fairscale requirement as zip before release
Nov 25, 2020
ba31247
Add check to ensure 1.6
Nov 25, 2020
586f6c6
Attempt try catch to prevent errors
Nov 25, 2020
22b4d5e
Merge branch 'master' into feature/plug
SeanNaren Nov 25, 2020
888b12b
Add additional else check
Nov 25, 2020
cf7a7f7
Add additional else check
Nov 25, 2020
9215908
Removed line, dont abs
Nov 25, 2020
6b93987
Revert "Add check to ensure 1.6"
Nov 25, 2020
321e63a
Fixes to import
Nov 25, 2020
28afc46
Removed lines
Nov 25, 2020
6c8715e
Swap ordering of imports
Nov 25, 2020
5f2a64b
Add explicit checkpoints for tests
Nov 25, 2020
7952767
Remove amp check as guard now upstream
Nov 26, 2020
80e5329
Add check for windows to plugin
Nov 26, 2020
fa59344
Fixes
Nov 26, 2020
8f97631
Add check to fairscale override
Nov 26, 2020
091c236
Ensure we do windows check first
Nov 26, 2020
ff34a8f
Update tests/plugins/test_sharded_plugin.py
SeanNaren Nov 26, 2020
47c121e
Addressed code review points
Nov 26, 2020
8a0c8fe
Fixed imports, swap to relying on function for entire batch
Nov 26, 2020
29e3108
Fix import order
Nov 26, 2020
c0e148b
Fix formatting
Nov 26, 2020
2e50c2e
Remove else check
Nov 26, 2020
ab655e5
Removed old eval logic, added eval tests
Nov 26, 2020
a9c316b
Add additional check to ensure apex is not used with sharded
Nov 26, 2020
74afcf7
Merge branch 'master' into feature/plug
SeanNaren Nov 26, 2020
8dc857c
Ensure we add the condition to the case statement
Nov 26, 2020
fc9b2bf
Fix logic and add test for apex check, rename file, add DDP launcher …
Nov 26, 2020
508eaff
Fix name
Nov 26, 2020
737447f
Merge branch 'master' into feature/plug
Nov 26, 2020
04bb0ab
Merge branch 'master' into feature/plug
Nov 27, 2020
bde2a12
Fix var name
Nov 27, 2020
10d41fb
Moved common functions into utilities
Nov 27, 2020
e52386b
Combine utilities
Nov 27, 2020
00bd0d2
Merge branch 'master' into feature/plug
SeanNaren Nov 27, 2020
bd4223e
Fix imports
Nov 27, 2020
5598dce
Remove unneeded check
Nov 27, 2020
cdd2e12
Add none check for func
Nov 27, 2020
4f69376
Update pytorch_lightning/trainer/connectors/precision_connector.py
SeanNaren Nov 27, 2020
1704773
Address code review
Nov 27, 2020
bf9cf3d
Tighten up regression testing
Nov 27, 2020
b4e8071
Increase speed diff for drone
Nov 27, 2020
d12577d
Reduce speed diff further, lack of GPU saturation is causing regressi…
Nov 27, 2020
06a856e
Merge branch 'master' into feature/plug
SeanNaren Nov 27, 2020
1719b2d
Skip a few tests to reduce drone CI wait times
Nov 27, 2020
317 changes: 317 additions & 0 deletions benchmarks/test_sharded_parity.py
@@ -0,0 +1,317 @@
import os
import platform
import time

import pytest
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
from tests.backends.launcher import DDPLauncher
from tests.base.boring_model import BoringModel, RandomDataset


@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
    plugin_parity_test(
        accelerator='ddp_cpu',
        max_percent_speed_diff=0.15,  # slower speed due to one CPU doing additional sequential memory saving calls
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
    plugin_parity_test(
        gpus=1,
        accelerator='ddp_spawn',
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel
    )


@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
    plugin_parity_test(
        gpus=1,
        precision=16,
        accelerator='ddp_spawn',
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel
    )


@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
    plugin_parity_test(
        gpus=2,
        accelerator='ddp_spawn',
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel,
        max_percent_speed_diff=0.25
    )


@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
    plugin_parity_test(
        gpus=2,
        precision=16,
        accelerator='ddp_spawn',
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel,
        max_percent_speed_diff=0.25
    )


@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32")
def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
    plugin_parity_test(
        gpus=args.gpus,
        precision=args.precision,
        accelerator=args.distributed_backend,
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel
    )


@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
    plugin_parity_test(
        gpus=args.gpus,
        precision=args.precision,
        accelerator=args.distributed_backend,
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel
    )


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
    """
    Ensures same results using multiple optimizers across multiple GPUs.
    """
    plugin_parity_test(
        plugin=DDPShardedPlugin(),
        gpus=2,
        accelerator='ddp_spawn',
        model_cls=SeedTrainLoaderMultipleOptimizersModel,
        max_percent_speed_diff=0.2  # Increase speed diff since only 2 GPUs sharding 2 optimizers
    )


@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
    """
    Ensures same results using multiple optimizers across multiple GPUs with manual optimization.
    """
    plugin_parity_test(
        plugin=DDPShardedPlugin(),
        gpus=2,
        accelerator='ddp_spawn',
        model_cls=SeedTrainLoaderManualModel,
    )


class SeedTrainLoaderModel(BoringModel):
    """
    Overrides training loader to ensure we enforce the same seed for all DDP processes.
    """

    def train_dataloader(self):
        seed_everything(42)
        return torch.utils.data.DataLoader(RandomDataset(32, 64))


class SeedTrainLoaderManualModel(SeedTrainLoaderModel):
    def training_step(self, batch, batch_idx, optimizer_idx):
        # manual
        (opt_a, opt_b) = self.optimizers()
        loss_1 = self.step(batch)

        self.manual_backward(loss_1, opt_a)
        self.manual_optimizer_step(opt_a)

        # fake discriminator
        loss_2 = self.step(batch[0])

        # ensure we forward the correct params to the optimizer
        # without retain_graph we can't do multiple backward passes
        self.manual_backward(loss_2, opt_b, retain_graph=True)
        self.manual_backward(loss_2, opt_a, retain_graph=True)
        self.manual_optimizer_step(opt_b)

        assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0)

    def training_epoch_end(self, outputs) -> None:
        # outputs should be an array with an entry per optimizer
        assert len(outputs) == 2

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return optimizer, optimizer_2

    @property
    def automatic_optimization(self) -> bool:
        return False


class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel):
    def training_step(self, batch, batch_idx, optimizer_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_epoch_end(self, outputs) -> None:
        # outputs should be an array with an entry per optimizer
        assert len(outputs) == 2

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return optimizer, optimizer_2


def record_ddp_fit_model_stats(trainer, model, use_cuda):
    """
    Helper to calculate wall clock time for fit + max allocated memory.

    Args:
        trainer: The trainer object.
        model: The model to fit.
        use_cuda: Whether to sync CUDA kernels.

    Returns:
        Max memory if using GPUs, and total wall clock time.
    """
    max_memory = None

    time_start = time.perf_counter()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    trainer.fit(model)

    if use_cuda:
        torch.cuda.synchronize()
        max_memory = torch.cuda.max_memory_allocated() / 2 ** 20

    total_time = time.perf_counter() - time_start

    return max_memory, total_time


def plugin_parity_test(
        model_cls: SeedTrainLoaderModel,
        plugin: DDPPlugin,
        seed: int = 42,
        accelerator: str = 'ddp_spawn',
        gpus: int = 0,
        precision: int = 32,
        max_percent_speed_diff: float = 0.1):
    """
    Ensures that the trained model is identical to the standard DDP implementation.
    Also checks for speed/memory regressions; memory usage should always be lower, but performance may fluctuate.

    Args:
        model_cls: Model class to use for test.
        plugin: Plugin to parity test.
        seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process.
        accelerator: Accelerator type for test.
        gpus: Number of GPUs to enable.
        precision: Whether to use AMP or normal FP32 training.
        max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
            This is more a safety net for variability in CI, which can vary in speed, not for benchmarking.

    """

    # Train normal DDP
    seed_everything(seed)
    ddp_model = model_cls()
    use_cuda = gpus > 0

    trainer = Trainer(
        fast_dev_run=True,
        max_epochs=1,
        gpus=gpus,
        precision=precision,
        accelerator=accelerator,
    )

    max_memory_ddp, ddp_time = record_ddp_fit_model_stats(
        trainer=trainer,
        model=ddp_model,
        use_cuda=use_cuda
    )

    # Reset and train Custom DDP
    seed_everything(seed)
    custom_plugin_model = model_cls()

    trainer = Trainer(
        fast_dev_run=True,
        max_epochs=1,
        gpus=gpus,
        precision=precision,
        accelerator=accelerator,
        plugins=[plugin],
    )

    max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
        trainer=trainer,
        model=custom_plugin_model,
        use_cuda=use_cuda
    )

    # Assert model parameters are identical after fit
    for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()):
        assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin'

    # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold
    percent_diff = (custom_model_time - ddp_time) / custom_model_time

    assert percent_diff <= max_percent_speed_diff, \
        f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}'

    if use_cuda:
        # Assert CUDA memory parity
        assert max_memory_custom <= max_memory_ddp, \
            f'Custom plugin used too much memory compared to DDP, ' \
            f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}'
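
For orientation, here is a minimal sketch (not part of the diff) of how the plugin exercised by plugin_parity_test above would be enabled from user code. The Trainer arguments mirror the ones used in the benchmark; BoringModel stands in for any LightningModule, and precision=16 is optional, combining the plugin with native AMP as tested above.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from tests.base.boring_model import BoringModel

# Any LightningModule works here; BoringModel is the stand-in used by the benchmark.
model = BoringModel()

# Same wiring as in plugin_parity_test: the sharded plugin is passed via `plugins`
# and trained with a distributed accelerator.
trainer = Trainer(
    gpus=2,
    accelerator='ddp_spawn',
    precision=16,
    plugins=[DDPShardedPlugin()],
)
trainer.fit(model)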
32 changes: 32 additions & 0 deletions pytorch_lightning/overrides/fairscale.py
@@ -0,0 +1,32 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE

LightningShardedDataParallel = None
if FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

    class LightningShardedDataParallel(ShardedDataParallel):

        def forward(self, *inputs, **kwargs):
            if self.enable_broadcast_buffers:
                self.sync_buffers()

            if self.module.training:
                outputs = self.module.training_step(*inputs, **kwargs)
            elif self.module.testing:
                outputs = self.module.test_step(*inputs, **kwargs)
            else:
                outputs = self.module.validation_step(*inputs, **kwargs)
            return outputs
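
The LightningShardedDataParallel = None default above exists so that other modules can import and reference the name even when fairscale is not installed. A minimal sketch of the guarded check this enables follows; the unwrapping helper is illustrative and not part of this diff.

from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE


def unwrap_lightning_module(model):
    # Illustrative helper: only attempt the isinstance check when fairscale is
    # importable, since LightningShardedDataParallel is None otherwise.
    if FAIRSCALE_AVAILABLE and isinstance(model, LightningShardedDataParallel):
        # The wrapper keeps the LightningModule on `.module`, which is what the
        # forward override above dispatches to.
        return model.module
    return model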
33 changes: 33 additions & 0 deletions pytorch_lightning/plugins/sharded_native_amp_plugin.py
@@ -0,0 +1,33 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast

from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin

if NATIVE_AMP_AVAILABLE and FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


class ShardedNativeAMPPlugin(NativeAMPPlugin):
    @property
    def scaler(self):
        return ShardedGradScaler()

    def clip_gradients(self, grad_clip_val, model, optimizer):
        max_norm = grad_clip_val
        norm_type = float(2.0)
        optimizer = cast(OSS, optimizer)
        optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
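
A rough sketch of the pieces this plugin ties together, assuming fairscale's documented OSS wrapper API and an already-initialized torch.distributed process group (the model here is a placeholder). Gradient clipping is delegated to the sharded optimizer because each rank only holds a shard of the optimizer state, so the global norm has to be computed by OSS itself.

import torch
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# Placeholder model; inside Lightning this would be the LightningModule's parameters.
model = torch.nn.Linear(32, 2)

# OSS shards optimizer state across the DDP ranks (requires torch.distributed to be initialized).
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# The plugin swaps the regular torch.cuda.amp.GradScaler for the shard-aware one ...
scaler = ShardedGradScaler()

# ... and clips gradients through the optimizer, mirroring
# ShardedNativeAMPPlugin.clip_gradients above.
optimizer.clip_grad_norm(0.5, norm_type=2.0)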