
Commit ef8ef12

Authored by tchaton, SeanNaren, Borda and carmocca
[feat] pp 2/n (#5026)
* Added changes for RPC plugin
* Add missing kwargs
* Fix code format
* Loading refactors by introducing is_distributed var, fix optimizer step flow
* Add rpc guard
* Added docstrings and typing
* resolve comments
* Add additional rpc hook, refactor name of exit process hook for clarity
* remove annotation
* Modify behaviour to allow optional return, add test for rpc plugin
* resolve tests
* rename is_ddp_based
* update
* update for windows
* update
* resolve test
* code smell
* Added sequential plugin
* resolve bug
* update
* cleanup
* add Exception
* resolve docs
* Remove ddp support
* Revert distributed -> ddp
* Update pl_examples/basic_examples/conv_sequential_example.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Update pl_examples/basic_examples/conv_sequential_example.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Address code review points
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Add missing return
* Fix formatting, add datamodule args
* add small comment
* resolve comments
* resolve comments
* update source for fairscale
* update extras
* remove staticmethod
* resolve flake8
* Skip tests that are failing due to bug upstream with multiple optimizers and shard
* update
* update on comments
* clean test
* latest comments
* remove old comments
* add todo
* Update version
* update
* resolve bugs
* resolve bugs
* update test
* remove hanging test
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* resolve on comments
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* resolve on comments
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* Update pytorch_lightning/plugins/ddp_sequential_plugin.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* remove ImportError

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7d9784e commit ef8ef12

File tree

13 files changed: +881 / -27 lines

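Before the per-file diffs, here is a minimal, hedged sketch of how the new plugin is attached to a Trainer, distilled from the example script added in this commit; the gpus and max_epochs values simply mirror the run command in that script's docstring, and a multi-GPU machine with FairScale and PyTorch 1.6 is assumed.

    # Sketch only: mirrors the wiring in the example script added below.
    import pytorch_lightning as pl
    from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin

    def build_trainer() -> pl.Trainer:
        # The LightningModule to be pipelined exposes its layers as
        # `self.sequential_module` (an nn.Sequential), as LitResnet does below.
        return pl.Trainer(
            accelerator='ddp',
            gpus=4,
            max_epochs=1,
            plugins=[DDPSequentialPlugin()],  # splits sequential_module across the GPUs
        )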

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -32,5 +32,6 @@ repos:
         types: [python]

   - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: master
     hooks:
       - id: mypy
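To try the new hook locally, the usual invocation is pre-commit run mypy --all-files (assuming pre-commit is installed in your environment).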

benchmarks/test_sharded_parity.py

Lines changed: 3 additions & 1 deletion
@@ -131,6 +131,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
     )


+@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(platform.system() == "Windows",
                     reason="Distributed training is not supported on Windows")
@@ -148,6 +149,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     )


+@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(platform.system() == "Windows",
                     reason="Distributed training is not supported on Windows")
@@ -189,7 +191,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):

         # ensure we forward the correct params to the optimizer
         # without retain_graph we can't do multiple backward passes
-        self.manual_backward(loss_2, opt_b, retain_graph=True)
+        self.manual_backward(loss_2, opt_b)
         # todo: understand why synchronization breaks there.
         # self.manual_backward(loss_2, opt_a, retain_graph=True)
         opt_b.step()
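The last hunk drops retain_graph=True from the second manual_backward call. For context, a hedged sketch of the manual-optimization pattern this test family exercises, with two optimizers each doing its own backward and step; the module and layer names below are illustrative, not taken from the test file.

    # Illustrative module, not from the test file: two optimizers, manual backward/step.
    import torch
    import pytorch_lightning as pl

    class TwoOptimizerModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_a = torch.nn.Linear(32, 2)
            self.layer_b = torch.nn.Linear(32, 2)

        @property
        def automatic_optimization(self) -> bool:
            # Disable Lightning's automatic loop; backward/step are called by hand.
            return False

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            opt_a, opt_b = self.optimizers()

            # Optimizer A: its own forward, backward and step.
            loss_a = self.layer_a(batch).sum()
            self.manual_backward(loss_a, opt_a)
            opt_a.step()
            opt_a.zero_grad()

            # Optimizer B: a separate graph, so retain_graph is not needed,
            # matching the change in the hunk above.
            loss_b = self.layer_b(batch).sum()
            self.manual_backward(loss_b, opt_b)
            opt_b.step()
            opt_b.zero_grad()

        def configure_optimizers(self):
            return (
                torch.optim.SGD(self.layer_a.parameters(), lr=0.1),
                torch.optim.SGD(self.layer_b.parameters(), lr=0.1),
            )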

pl_examples/basic_examples/conv_sequential_example.py

Lines changed: 216 additions & 0 deletions

@@ -0,0 +1,216 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+Example script of running the experimental DDP Sequential Plugin.
+This script splits a convolutional model onto multiple GPUs, whilst using the internal built in balancer
+to balance across your GPUs.
+
+To run:
+python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential
+"""
+import math
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+from pytorch_lightning.metrics.functional import accuracy
+from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
+from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE
+
+if BOLTS_AVAILABLE:
+    import pl_bolts
+    from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
+
+
+#####################
+#      Modules      #
+#####################
+
+
+class Flatten(nn.Module):
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+###############################
+#       LightningModule       #
+###############################
+
+
+class LitResnet(pl.LightningModule):
+    def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
+        super().__init__()
+
+        self.save_hyperparameters()
+        self.sequential_module = nn.Sequential(
+            # Conv Layer block 1
+            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
+            nn.ReLU(inplace=False),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+
+            # Conv Layer block 2
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=False),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Dropout2d(p=0.05),
+
+            # Conv Layer block 3
+            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=False),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+
+            Flatten(),
+
+            nn.Dropout(p=0.1),
+            nn.Linear(4096, 1024),
+            nn.ReLU(inplace=False),
+            nn.Linear(1024, 512),
+            nn.ReLU(inplace=False),
+            nn.Dropout(p=0.1),
+            nn.Linear(512, 10)
+        )
+        self._example_input_array = torch.randn((1, 3, 32, 32))
+        self._manual_optimization = manual_optimization
+        if self._manual_optimization:
+            self.training_step = self.training_step_manual
+
+    def forward(self, x):
+        out = self.sequential_module(x)
+        return F.log_softmax(out, dim=-1)
+
+    def training_step_manual(self, batch, batch_idx):
+        opt = self.optimizers()
+
+        def closure():
+            x, y = batch
+            logits = self.forward(x)
+            loss = F.nll_loss(logits, y)
+            self.manual_backward(loss, opt)
+            self.log('train_loss', loss, prog_bar=True)
+
+        opt.step(closure=closure)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y)
+        self.log('Training Loss', loss)
+        return loss
+
+    def _evaluate(self, batch, batch_idx, stage=None):
+        x, y = batch
+        out = self.forward(x)
+        logits = F.log_softmax(out, dim=-1)
+        loss = F.nll_loss(logits, y)
+        preds = torch.argmax(logits, dim=-1)
+        acc = accuracy(preds, y)
+
+        if stage:
+            self.log(f'{stage}_loss', loss, prog_bar=True)
+            self.log(f'{stage}_acc', acc, prog_bar=True)
+
+        return loss, acc
+
+    def validation_step(self, batch, batch_idx):
+        return self._evaluate(batch, batch_idx, 'val')[0]
+
+    def test_step(self, batch, batch_idx):
+        loss, acc = self._evaluate(batch, batch_idx, 'test')
+        self.log_dict({'test_loss': loss, 'test_acc': acc})
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                'scheduler': torch.optim.lr_scheduler.OneCycleLR(
+                    optimizer,
+                    0.1,
+                    epochs=self.trainer.max_epochs,
+                    steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
+                'interval': 'step',
+            }
+        }
+
+    @property
+    def automatic_optimization(self) -> bool:
+        return not self._manual_optimization
+
+
+#################################
+#    Instantiate Data Module    #
+#################################
+
+def instantiate_datamodule(args):
+    train_transforms = torchvision.transforms.Compose([
+        torchvision.transforms.RandomCrop(32, padding=4),
+        torchvision.transforms.RandomHorizontalFlip(),
+        torchvision.transforms.ToTensor(),
+        cifar10_normalization(),
+    ])
+
+    test_transforms = torchvision.transforms.Compose([
+        torchvision.transforms.ToTensor(),
+        cifar10_normalization(),
+    ])
+
+    cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule(
+        batch_size=args.batch_size,
+        train_transforms=train_transforms,
+        test_transforms=test_transforms,
+        val_transforms=test_transforms,
+    )
+
+    return cifar10_dm
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Pipe Example")
+    parser.add_argument("--use_ddp_sequential", action="store_true")
+    parser = Trainer.add_argparse_args(parser)
+    parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    assert BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
+    assert FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."
+
+    cifar10_dm = instantiate_datamodule(args)
+
+    plugins = None
+    if args.use_ddp_sequential:
+        plugins = DDPSequentialPlugin()
+
+    model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization)
+
+    trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None)
+    trainer.fit(model, cifar10_dm)
+    trainer.test(model, datamodule=cifar10_dm)
+
+    if trainer.accelerator_backend.rpc_enabled:
+        # Called at the end of trainer to ensure all processes are killed
+        trainer.accelerator_backend.ddp_plugin.exit_rpc_process()

pytorch_lightning/overrides/data_parallel.py

Lines changed: 3 additions & 1 deletion
@@ -155,6 +155,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
     """
     Override the forward call in lightning so it goes to training and validation step respectively
     """
+    PREPARE_FOR_BACKWARDS = True

     def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
@@ -165,6 +166,7 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
         fx_called: str = ''

         if self.device_ids:
+
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
                 # --------------
@@ -195,7 +197,7 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
         else:
             output = self.module.validation_step(*inputs, **kwargs)

-        if not self._reducer_prepared_for_backwards:
+        if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
             self.reducer_prepare_for_backwards(output)

         if output is None:
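What the new class attribute buys, as a hedged sketch: a subclass (the name below is hypothetical, not part of this commit) can switch off the eager reducer preparation so that backward bookkeeping is deferred to whatever drives the pipeline instead of happening inside forward().

    # Hypothetical subclass, not part of this commit: opts out of eager reducer preparation.
    from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

    class PipelineFriendlyDDP(LightningDistributedDataParallel):
        # With the flag off, forward() skips reducer_prepare_for_backwards(output).
        PREPARE_FOR_BACKWARDS = False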
