Closed
Changes from all commits (70 commits)
568b35e
Added base fairscale accelerator and dependency. modified checkpoint …
SeanNaren Oct 11, 2020
630796b
Added wrapper class to ensure we only call state_dict on rank zero
SeanNaren Oct 15, 2020
133a250
Added additional comment from override, fixed over-identation
SeanNaren Oct 15, 2020
eb552ef
Added wrapper for sharded ddp
SeanNaren Oct 18, 2020
3439f47
Update state_dict call, allow every process to call function, only wa…
SeanNaren Oct 19, 2020
e15fc65
Update broadcast call in model dispatch
SeanNaren Oct 22, 2020
bd000dc
Swap name to sharded_ddp
SeanNaren Nov 1, 2020
16d679d
Update API based on fairscale oss_autograd changes
Nov 5, 2020
f4343ec
Add check for adding ddp sampler for ddp_sharded
Nov 5, 2020
bc5ab46
Merge branch 'master' into feature/817-fairscale
SeanNaren Nov 5, 2020
aca672b
Remove need for explicit require grad by ensuring we clear up handles…
Nov 6, 2020
39df0ff
Merge branch 'master' into feature/817-fairscale
SeanNaren Nov 6, 2020
e4abd88
Revert to normal DDP in testing mode
Nov 6, 2020
628ef0f
Added elastic sharded ddp, need to reduce duplication
Nov 8, 2020
0147d03
Added sharded ddp to check
Nov 8, 2020
6b17458
Added temporary broadcast to ensure we broadcast parameters regardles…
Nov 9, 2020
8c59f94
Added fix to torchelastic accelerator
Nov 9, 2020
5f81215
Merge branch 'master' into feature/817-fairscale
Nov 10, 2020
33363c4
Added initial changes to support ShardedDDPPlugin
Nov 10, 2020
8cdd66c
Removed more sharded refs
Nov 10, 2020
aeeca87
Removed ref
Nov 10, 2020
88e50d0
Fixed indent
Nov 10, 2020
135d2d6
Pass model reference
Nov 10, 2020
d86c92b
Better name
Nov 10, 2020
c709364
Added to device function to ensure we move to device when using shard…
Nov 10, 2020
1bbf0df
Pass model ref
Nov 10, 2020
aedeaed
Pass DDP model
Nov 10, 2020
5c18b18
Fix reduce
Nov 10, 2020
1c2316b
Fix logic
Nov 10, 2020
1dfbd95
Simplified gradscaler
Nov 10, 2020
6de6a4f
Add temporary grad scaler handling
Nov 10, 2020
27d2682
Swapped to encapsulating within the precision connector
Nov 11, 2020
770caa8
Update scaler if using amp
Nov 11, 2020
1ba7bcf
Refactor to select amp plugin correctly as sharded
Nov 12, 2020
4b3ebd6
Add additional plugin check to ensure it has been init
Nov 12, 2020
c9c7921
Added custom sharded clip gradients logic, abstracted out precision p…
Nov 12, 2020
b2068de
Merge branch 'master' into feature/817-fairscale
Nov 12, 2020
acda934
Add rank save check for state sharding, removing need to override OSS…
Nov 14, 2020
166237b
Allow ddp plugin to modify optimizer state saving
Nov 14, 2020
6d285a8
Merge branch 'master' into feature/817-fairscale-2n
tchaton Nov 14, 2020
acaa995
Rely on the accelerator for optimizer states
Nov 14, 2020
af65eff
Ensure we init the accelerator for the saving function
Nov 14, 2020
5c8a7b4
Better comment for optim state dump
Nov 14, 2020
9b82bfb
Revert "Ensure we init the accelerator for the saving function"
Nov 14, 2020
f9929c0
Added accelerator check to initialize tuner before saving model check…
Nov 14, 2020
e669930
Simplify comment
Nov 15, 2020
18b0d74
Revert "Added accelerator check to initialize tuner before saving mod…
Nov 15, 2020
09ef93e
Return single optimizer state to reduce duplication
Nov 15, 2020
14b54e2
Fixed docstring
Nov 15, 2020
6be7caf
Fixed typing
Nov 15, 2020
098fc64
Fixed comment
Nov 15, 2020
db91ec5
Added CHANGELOG.md
Nov 15, 2020
206a660
Allow ddp plugin to move the input to a different device if needed
Nov 15, 2020
68c8e55
Merge branch 'master' into feature/817-fairscale-3n-redo
SeanNaren Nov 15, 2020
1a25532
Merge branch 'master' into feature/817-fairscale-3n-redo
SeanNaren Nov 15, 2020
da4c022
Merge branch 'master' into feature/817-fairscale-3n-redo
williamFalcon Nov 15, 2020
ca6d536
Swapped name to on_before_forward to align with hooks in the future
Nov 15, 2020
bbe5760
Merge branch 'master' into feature/817-fairscale-3n-redo
SeanNaren Nov 18, 2020
3b90d6b
Merge branch 'master' into feature/817-fairscale-2n
SeanNaren Nov 18, 2020
369d8c7
Merge branch 'master' into feature/817-fairscale-2n
SeanNaren Nov 18, 2020
1125ae3
Merge branch 'master' into feature/817-fairscale-3n-redo
SeanNaren Nov 18, 2020
61fc39c
Expose scaler in amp plugin
Nov 18, 2020
681217b
Merge branch 'master' into feature/817-fairscale-4n
SeanNaren Nov 18, 2020
9e5a4d8
Merge branch 'master' into feature/817-fairscale
Nov 18, 2020
a3d9680
Merge branch 'feature/817-fairscale-2n' into feature/817-fairscale
Nov 18, 2020
30aaa27
Merge branch 'feature/817-fairscale-3n-redo' into feature/817-fairscale
Nov 18, 2020
ff49a59
Merge branch 'feature/817-fairscale-4n' into feature/817-fairscale
Nov 18, 2020
8aa9bd2
Merged pending PRs to unify API, updated to use latest sharded DDP
Nov 18, 2020
115d498
Fixed var call
Nov 18, 2020
f461094
temp
Nov 24, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))


- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))


### Changed

- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
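Note: the entry above is easier to read next to a concrete override. Below is a minimal sketch of a plugin that customises optimizer state saving through the new hook; it assumes the base class in pytorch_lightning/plugins/ddp_plugin.py is named DDPPlugin, and `consolidate_state_dict` stands in for whatever collation a sharded optimizer actually exposes.

from torch.optim import Optimizer

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class ConsolidatingDDPPlugin(DDPPlugin):
    """Hypothetical plugin: collate sharded optimizer state before it is checkpointed."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        # If the optimizer shards its state across ranks, gather it first so the
        # checkpoint holds the full state. `consolidate_state_dict` is illustrative;
        # a real sharded optimizer may expose a different API.
        if hasattr(optimizer, "consolidate_state_dict"):
            optimizer.consolidate_state_dict()
        return optimizer.state_dict()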
13 changes: 12 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from enum import Enum
from typing import Any, Optional, Union
from typing import Any, Optional, Union, List

import torch
from torch.optim import Optimizer
@@ -202,6 +202,17 @@ def sync_tensor(self,
"""
raise NotImplementedError()

def optimizer_state(self, optimizer: Optimizer) -> dict:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
plugins.
Return:
Optimizer state dict
"""
if self.ddp_plugin:
return self.ddp_plugin.optimizer_state(optimizer)
return optimizer.state_dict()

def __getstate__(self):
return {
'trainer': self.trainer,
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -195,8 +195,7 @@ def select_accelerator(self):
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

# torchelastic or general non_slurm ddp
te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed
use_torchelastic_ddp = self.trainer.use_ddp and self._is_using_torchelastic()

use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"
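Note: the body of `_is_using_torchelastic()` is not part of this hunk, but the removed inline check suggests it wraps the same environment test. A sketch under that assumption:

import os


def _is_using_torchelastic(self) -> bool:
    # Presumably equivalent to the removed inline check: torchelastic sets
    # WORLD_SIZE plus either GROUP_RANK or NODE_RANK in the environment.
    has_world_size = 'WORLD_SIZE' in os.environ
    has_rank = 'GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ
    return has_world_size and has_rank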
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -61,21 +61,23 @@ def train(self):
return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

def training_step(self, args):
return self._step(args)

def validation_step(self, args):
return self._step(args)

def test_step(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -151,21 +151,23 @@ def train(self):
return results

def training_step(self, args):
return self._step(args)

def validation_step(self, args):
return self._step(args)

def test_step(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -156,21 +156,23 @@ def ddp_train(self, process_idx, mp_queue, model):
torch.cuda.empty_cache()

def training_step(self, args):
return self._step(args)

def validation_step(self, args):
return self._step(args)

def test_step(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -77,21 +77,23 @@ def get_device_ids(self):
return device_ids

def training_step(self, args):
return self._step(args)

def validation_step(self, args):
return self._step(args)

def test_step(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
25 changes: 17 additions & 8 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -182,21 +182,23 @@ def get_device_ids(self):
return device_ids

def training_step(self, args):
return self._step(args)

def validation_step(self, args):
return self._step(args)

def test_step(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
@@ -270,3 +272,10 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def sync_optim_state(self):
self.ddp_plugin.sync_optim_state(self.trainer.get_model())

@property
def rank_should_save_optim_state(self):
return self.ddp_plugin.rank_should_save_optim_state(self.trainer.global_rank)
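Note: `sync_optim_state` and `rank_should_save_optim_state` delegate to the DDP plugin, but the plugin-side counterparts are not shown in this diff. A sketch of plausible defaults, purely as an assumption (base class name DDPPlugin assumed); a sharded plugin would likely consolidate state and restrict saving to rank zero instead.

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class DefaultStatePlugin(DDPPlugin):
    """Hypothetical defaults for the plugin hooks the accelerator calls above."""

    def sync_optim_state(self, model):
        # Plain DDP: every rank already holds the full optimizer state,
        # so there is nothing to synchronise.
        pass

    def rank_should_save_optim_state(self, global_rank: int) -> bool:
        # Plain DDP: any rank may save; a sharded plugin could return
        # `global_rank == 0` after consolidating shards onto rank zero.
        return True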
16 changes: 16 additions & 0 deletions pytorch_lightning/plugins/ddp_plugin.py
@@ -1,5 +1,7 @@
from typing import List, Dict, Any

from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

@@ -62,3 +64,17 @@ def configure_ddp(self, model, device_ids):
**self._ddp_kwargs,
)
return model

def on_before_forward(self, args: Any, model: LightningModule):
"""
Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
within the DDP wrapper.
Args:
args: Inputs to the model.
model: Model to train.
Returns: args moved to correct device if needed.
"""
return args

def optimizer_state(self, optimizer: Optimizer) -> dict:
return optimizer.state_dict()
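Note: the base `on_before_forward` is a pass-through because DDP scatters inputs to the right device itself. For a wrapper that does not, an override might move the batch explicitly; a minimal sketch (base class name assumed to be DDPPlugin, device handling only illustrative):

from typing import Any

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device


class DeviceMovingDDPPlugin(DDPPlugin):
    """Hypothetical plugin that moves step inputs onto the model's device itself."""

    def on_before_forward(self, args: Any, model: LightningModule):
        # Move every tensor in the step inputs to the model's device before
        # the wrapped forward runs.
        return move_data_to_device(args, model.device)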
8 changes: 8 additions & 0 deletions pytorch_lightning/plugins/native_amp.py
@@ -56,6 +56,14 @@ def training_step(self, fx, args):
output = fx(*args)
return output


def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)

@property
def scaler(self):
return torch.cuda.amp.GradScaler()
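Note: the precision connector below switches to a `ShardedNativeAMPPlugin` when a sharded DDP plugin is configured. That class is not part of this hunk; given the two members exposed here, a plausible shape, with the fairscale scaler and optimizer-side clipping named only as assumptions, would be:

from typing import Union

from torch.optim import Optimizer

from pytorch_lightning.plugins.native_amp import NativeAMPPlugin


class ShardedNativeAMPPluginSketch(NativeAMPPlugin):
    """Sketch only; the real ShardedNativeAMPPlugin may differ."""

    @property
    def scaler(self):
        # A sharded optimizer needs a scaler that understands sharded state;
        # fairscale ships one, though the import path can vary by version.
        from fairscale.optim.grad_scaler import ShardedGradScaler
        return ShardedGradScaler()

    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
        # With sharded state, clipping is coordinated by the sharded optimizer
        # rather than by iterating local parameters.
        optimizer.clip_grad_norm(grad_clip_val, norm_type=norm_type)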
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision_plugin.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Union

from torch.optim import Optimizer

import abc


class PrecisionPlugin(abc.ABC):
"""
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -298,10 +298,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
callback_states = self.trainer.on_save_checkpoint()
checkpoint['callbacks'] = callback_states

# dump optimizers
optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
optimizer_states.append(optimizer.state_dict())
# Rely on accelerator to dump optimizer state
optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint['optimizer_states'] = optimizer_states

# dump lr schedulers
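Note: the list built above is stored under `optimizer_states`, one entry per configured optimizer, and restoring is the usual `load_state_dict` round-trip. A minimal inspection sketch, with the checkpoint path purely illustrative:

import torch

# Hypothetical path; any checkpoint produced after this change applies.
checkpoint = torch.load('example.ckpt', map_location='cpu')

# One state dict per optimizer, in the order they were configured on the Trainer.
for i, opt_state in enumerate(checkpoint['optimizer_states']):
    print(i, sorted(opt_state.keys()))  # typically 'state' and 'param_groups'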
10 changes: 8 additions & 2 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -22,7 +22,6 @@
LightningDataParallel,
)


class ModelConnector:
def __init__(self, trainer):
self.trainer = trainer
@@ -55,6 +54,13 @@ def copy_trainer_model_properties(self, model):
m.local_rank = self.trainer.local_rank

def get_model(self):
is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
is_dp_module = isinstance(
self.trainer.model,
(
LightningShardedDataParallel,
LightningDistributedDataParallel,
LightningDataParallel
)
)
model = self.trainer.model.module if is_dp_module else self.trainer.model
return model
33 changes: 26 additions & 7 deletions pytorch_lightning/trainer/connectors/precision_connector.py
@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn


@@ -24,7 +27,7 @@ def __init__(self, trainer):
self.trainer = trainer
self.backend = None

def on_trainer_init(self, precision, amp_level, amp_backend):
def on_trainer_init(self, precision, amp_level, amp_backend, plugins):
# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
@@ -33,14 +36,14 @@ def on_trainer_init(self, precision, amp_level, amp_backend):
self.trainer.scaler = None

self.trainer.amp_level = amp_level
self.init_amp(amp_backend)
self.init_amp(amp_backend, plugins)

def init_amp(self, amp_type: str):
def init_amp(self, amp_type: str, plugins: Optional[list]):
assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
self.trainer.amp_backend = None
self._setup_amp_backend(amp_type)
self._setup_amp_backend(amp_type, plugins)

def _setup_amp_backend(self, amp_type: str):
def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]):
if self.trainer.precision != 16:
# no AMP requested, so we can leave now
return
@@ -54,9 +57,14 @@ def _setup_amp_backend(self, amp_type: str):
' We will attempt to use NVIDIA Apex for this session.')
amp_type = 'apex'
else:
log.info('Using native 16bit precision.')
self.trainer.amp_backend = AMPType.NATIVE
self.backend = NativeAMPPlugin(self.trainer)
log.info('Using native 16bit precision.')

if plugins and self._sharded_in_plugins(plugins):
log.info('Using Sharded 16bit plugin.')
self.backend = ShardedNativeAMPPlugin(self.trainer)
else:
self.backend = NativeAMPPlugin(self.trainer)

if amp_type == 'apex':
if not APEX_AVAILABLE:
@@ -79,3 +87,14 @@ def connect(self, model):
self.trainer.optimizers = optimizers

return model

@property
def scaler(self):
if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu:
return self.backend.scaler

def _sharded_in_plugins(self, plugins):
for plugin in plugins:
if isinstance(plugin, DDPShardedPlugin):
return True
return False
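Note: taken together, the selection logic above means sharded native AMP is opted into purely through the Trainer's `plugins` argument. A plausible end-to-end invocation, assuming `DDPShardedPlugin` needs no constructor arguments and that the usual distributed environment is available:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin

# precision=16 routes through _setup_amp_backend; with a DDPShardedPlugin in
# `plugins`, the connector picks ShardedNativeAMPPlugin instead of NativeAMPPlugin.
trainer = Trainer(
    gpus=2,
    distributed_backend='ddp',
    precision=16,
    plugins=[DDPShardedPlugin()],
)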
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -400,7 +400,7 @@ def __init__(
)

# set precision
self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)
self.precision_connector.on_trainer_init(precision, amp_level, amp_backend, plugins)

# last thing are the plugins which override whatever the trainer used by default
self.plugin_connector.on_trainer_init(plugins)