Commit cea000d
Ref/accelerator connector (#5742)
* final cleanup
* connector cleanup
* trainer cleanup
* accelerator cleanup + missing logic in accelerator connector
* add missing changes to callbacks
* reflect accelerator changes to lightning module
* clean cluster envs
* cleanup plugins
* add broadcasting
* yapf
* remove plugin connector

Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 81001e3 commit cea000d

File tree
14 files changed: +141, -50 lines


pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 0 deletions
@@ -374,3 +374,6 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
 
     def on_save(self, checkpoint):
         return checkpoint
+
+    def barrier(self, name: Optional[str] = None) -> None:
+        self.training_type_plugin.barrier(name=name)

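The new `barrier` hook on `Accelerator` simply forwards to the training-type plugin, so accelerator-level code can synchronize ranks without knowing the backend. A minimal standalone sketch of that delegation pattern, using stand-in classes rather than the real pytorch_lightning ones:

# Sketch of the delegation added above: the accelerator exposes barrier()
# and forwards to whatever training-type plugin it wraps.
# These classes are illustrative stand-ins, not the real pytorch_lightning API.
from typing import Optional


class NoOpTrainingTypePlugin:
    def barrier(self, name: Optional[str] = None) -> None:
        # a real plugin would call e.g. torch.distributed.barrier() here
        print(f"barrier reached: {name}")


class SketchAccelerator:
    def __init__(self, training_type_plugin: NoOpTrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def barrier(self, name: Optional[str] = None) -> None:
        self.training_type_plugin.barrier(name=name)


SketchAccelerator(NoOpTrainingTypePlugin()).barrier(name="after_checkpoint")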
pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 77 additions & 5 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+from typing import Optional, Sequence
 
 import torch
 
@@ -26,15 +27,21 @@
     DataParallelPlugin,
     DDP2Plugin,
     DDPPlugin,
+    DDPShardedPlugin,
     DDPSpawnPlugin,
+    DDPSpawnShardedPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
     PrecisionPlugin,
+    RPCPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
-    TPUSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin,
+    TPUSpawnPlugin,
+    TrainingTypePlugin,
+    DDPShardedPlugin,
+    DDPSpawnShardedPlugin,
 )
 from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
@@ -74,6 +81,7 @@ def __init__(
         amp_type,
         amp_level,
         cluster_environment,
+        plugins,
     ):
         # initialization
         self._device_type = DeviceType.CPU
@@ -95,6 +103,11 @@ def __init__(
         self.cluster_environment = cluster_environment
         self.is_slurm_managing_tasks = False
 
+        self._precision_plugin: Optional[PrecisionPlugin] = None
+        self._training_type_plugin: Optional[TrainingTypePlugin] = None
+
+        self.handle_given_plugins(plugins)
+
         # init the default rank if exists
         # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
         # this way we only show it on rank 0
@@ -136,6 +149,56 @@ def __init__(
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
+    def handle_given_plugins(self, plugins: Optional[Sequence]):
+        if plugins is None:
+            return
+
+        if not isinstance(plugins, Sequence):
+            plugins = [plugins]
+
+        training_type = None
+        precision = None
+
+        for plug in plugins:
+            if isinstance(plug, TrainingTypePlugin):
+                if training_type is None:
+                    training_type = plug
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin. '
+                        'Found more than 1 training type plugin'
+                    )
+            elif isinstance(plug, PrecisionPlugin):
+                if precision is None:
+                    precision = plug
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin. '
+                        'Found more than 1 precision plugin'
+                    )
+            else:
+                raise MisconfigurationException(
+                    f'Found invalid type for plugin {plug}. '
+                    'Expected a precision or training type plugin.'
+                )
+
+        self._training_type_plugin = training_type
+        self._precision_plugin = precision
+
+    @property
+    def precision_plugin(self) -> PrecisionPlugin:
+        if self._precision_plugin is None:
+            self._precision_plugin = self.select_precision_plugin()
+
+        return self._precision_plugin
+
+    @property
+    def training_type_plugin(self) -> TrainingTypePlugin:
+        if self._training_type_plugin is None:
+            self._training_type_plugin = self.select_training_type_plugin()
+
+        return self._training_type_plugin
+
     @property
     def on_cpu(self):
         return self._device_type == DeviceType.CPU
@@ -205,6 +268,9 @@ def select_precision_plugin(self):
         if self.on_tpu:
             return TPUHalfPrecisionPlugin()
 
+        if isinstance(self.training_type_plugin, RPCPlugin):
+            raise MisconfigurationException
+
         if self.amp_type == "native":
             if not _NATIVE_AMP_AVAILABLE:
                 rank_zero_warn(
@@ -215,7 +281,7 @@ def select_precision_plugin(self):
                 self.amp_type = "apex"
             else:
                 log.info("Using native 16bit precision.")
-                if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
+                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                     return ShardedNativeMixedPrecisionPlugin()
                 self.amp_type = AMPType.NATIVE
                 return NativeMixedPrecisionPlugin()
@@ -227,7 +293,7 @@ def select_precision_plugin(self):
                     " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                 )
             else:
-                if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
+                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                     raise MisconfigurationException(
                         "Sharded Plugin is not supported with Apex AMP, "
                         "please using native AMP for 16-bit precision."
@@ -289,6 +355,12 @@ def select_training_type_plugin(self):
     def select_accelerator(self):
         if isinstance(self.distributed_backend, Accelerator):
            # custom accelerator from user
+            if self._precision_plugin is not None or self._training_type_plugin is not None:
+                # plugins also specified by user
+                rank_zero_warn(
+                    'Specified Precision and TrainingType Plugins will be ignored, '
+                    'since an Accelerator instance was provided'
+                )
             return self.distributed_backend
 
         if self.on_gpu:
@@ -299,8 +371,8 @@ def select_accelerator(self):
             acc_cls = CPUAccelerator
 
         return acc_cls(
-            precision_plugin=self.select_precision_plugin(),
-            training_type_plugin=self.select_training_type_plugin(),
+            precision_plugin=self.precision_plugin,
+            training_type_plugin=self.training_type_plugin,
         )
 
     def select_cluster_environment(self):

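The core of this file's change is `handle_given_plugins`, which sorts user-supplied plugins into at most one training-type and one precision plugin and rejects anything else; the lazy `precision_plugin` and `training_type_plugin` properties then fall back to auto-selection. A standalone sketch of that sorting logic, with stand-in base classes and `ValueError` in place of `MisconfigurationException` so it runs without pytorch_lightning:

# Standalone sketch of the plugin-sorting logic in handle_given_plugins.
# TrainingTypePlugin / PrecisionPlugin here are placeholders for the real
# pytorch_lightning base classes; errors use ValueError instead of
# MisconfigurationException.
from typing import Optional, Sequence, Tuple


class TrainingTypePlugin:  # stand-in
    pass


class PrecisionPlugin:  # stand-in
    pass


def sort_plugins(plugins: Optional[Sequence]) -> Tuple[Optional[TrainingTypePlugin], Optional[PrecisionPlugin]]:
    # At most one plugin of each kind is accepted; anything else is rejected.
    if plugins is None:
        return None, None
    if not isinstance(plugins, Sequence):
        plugins = [plugins]
    training_type, precision = None, None
    for plug in plugins:
        if isinstance(plug, TrainingTypePlugin):
            if training_type is not None:
                raise ValueError("Found more than 1 training type plugin")
            training_type = plug
        elif isinstance(plug, PrecisionPlugin):
            if precision is not None:
                raise ValueError("Found more than 1 precision plugin")
            precision = plug
        else:
            raise ValueError(f"Found invalid type for plugin {plug}")
    return training_type, precision


print(sort_plugins([TrainingTypePlugin(), PrecisionPlugin()]))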
pytorch_lightning/accelerators/tpu.py

Lines changed: 21 additions & 0 deletions
@@ -1,9 +1,17 @@
+from typing import Callable
+
+import torch
+
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
+from pytorch_lightning.utilities import _XLA_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+if _XLA_AVAILABLE:
+    import torch_xla.core.xla_model as xm
+
 
 class TPUAccelerator(Accelerator):
 
@@ -17,3 +25,16 @@ def setup(self, trainer, model):
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
+
+    def optimizer_step(
+        self, optimizer: torch.optim.Optimizer, current_epoch: int, batch_idx: int, opt_idx: int,
+        lambda_closure: Callable
+    ):
+
+        self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
+        self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)
+
+        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure})
+
+        self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
+        self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)

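The new `TPUAccelerator.optimizer_step` wraps the XLA step in precision and training-type hooks. A standalone sketch of that pre/step/post ordering, with no-op stand-in plugins and a plain `optimizer.step()` in place of `xm.optimizer_step` so it runs without torch_xla or a TPU:

# Illustration of the hook sandwich used above:
# pre_optimizer_step -> backend-specific step -> post_optimizer_step.
# The plugins are no-op stand-ins; optimizer.step() replaces xm.optimizer_step.
import torch


class NoOpPlugin:
    def pre_optimizer_step(self, optimizer, opt_idx):
        print(f"pre step (opt_idx={opt_idx})")

    def post_optimizer_step(self, optimizer, opt_idx):
        print(f"post step (opt_idx={opt_idx})")


model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
precision_plugin, training_type_plugin = NoOpPlugin(), NoOpPlugin()


def closure():
    optimizer.zero_grad()
    loss = model(torch.ones(1, 2)).sum()
    loss.backward()
    return loss


precision_plugin.pre_optimizer_step(optimizer, 0)
training_type_plugin.pre_optimizer_step(optimizer, 0)
optimizer.step(closure=closure)  # on TPU: xm.optimizer_step(optimizer, optimizer_args={'closure': closure})
precision_plugin.post_optimizer_step(optimizer, 0)
training_type_plugin.post_optimizer_step(optimizer, 0)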
pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 0 deletions
@@ -196,6 +196,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
+            should_stop = False
         else:
             self.wait_count += 1
             should_stop = self.wait_count >= self.patience

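The one-line early-stopping fix makes an improving metric explicitly reset `should_stop`, so the improvement branch no longer leaves the flag untouched. A condensed standalone sketch of the check, simplified to a lower-is-better metric with a stand-in state dict:

# Condensed, standalone version of the patched check: an improvement resets
# both wait_count and should_stop, otherwise patience is counted down.
# min_delta, patience, and the state dict are simplified stand-ins.
def early_stopping_check(current, state, min_delta=0.0, patience=3):
    if current < state["best_score"] - min_delta:  # lower-is-better, e.g. val_loss
        state["best_score"] = current
        state["wait_count"] = 0
        should_stop = False
    else:
        state["wait_count"] += 1
        should_stop = state["wait_count"] >= patience
    return should_stop


state = {"best_score": float("inf"), "wait_count": 0}
for loss in [0.9, 0.8, 0.81, 0.82, 0.83, 0.84]:
    print(loss, early_stopping_check(loss, state))  # turns True once patience is exhausted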
pytorch_lightning/core/lightning.py

Lines changed: 0 additions & 1 deletion
@@ -275,7 +275,6 @@ def log(
             raise MisconfigurationException(
                 f"Logged key: {name} should not contain information about dataloader_idx.")
 
-        accelerator = self.trainer.accelerator_backend
         training_type_plugin = self.trainer.training_type_plugin
 
         self._results.log(

pytorch_lightning/core/optimizer.py

Lines changed: 4 additions & 15 deletions
@@ -20,9 +20,6 @@
 from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
 
 def is_lightning_optimizer(optimizer):
     return isinstance(optimizer, LightningOptimizer)
@@ -133,18 +130,10 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
         optimizer = self._optimizer
         model = trainer.get_model()
 
-        if trainer._device_type == DeviceType.TPU:
-            with trainer.profiler.profile(profiler_name):
-                xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})
-
-        # elif trainer.amp_backend is not None:
-        #     # TODO: Adapt for new optimizer structure
-        #     trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
-
-        else:
-            with trainer.profiler.profile(profiler_name):
-                optimizer.step(closure=closure, *args, **kwargs)
-
+        with trainer.profiler.profile(profiler_name):
+            trainer.accelerator_backend.optimizer_step(*args, lambda_closure=closure, **kwargs)
+
+        # TODO: Do we need this?
         accelerator_backend = trainer.accelerator_backend
         if accelerator_backend is not None and accelerator_backend.rpc_enabled:
             if accelerator_backend.ddp_plugin.is_main_rpc_process:

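In `LightningOptimizer.__optimizer_step`, the TPU-specific branch is gone: the step is now always routed through `trainer.accelerator_backend.optimizer_step`, which owns the backend-specific behaviour. A toy illustration of that delegation, using stand-in classes rather than the real Trainer or LightningOptimizer API:

# Toy illustration of the refactor above: instead of branching on device type,
# the optimizer wrapper hands the closure to an accelerator object that knows
# how to step on its own backend. All names are stand-ins.
import torch


class CPUBackend:
    """Stand-in accelerator backend: performs the actual optimizer step."""

    def optimizer_step(self, optimizer, lambda_closure):
        optimizer.step(closure=lambda_closure)


class WrappedOptimizer:
    """Stand-in for the optimizer wrapper after the change: no device branching."""

    def __init__(self, optimizer, accelerator_backend):
        self._optimizer = optimizer
        self._accelerator_backend = accelerator_backend

    def step(self, closure):
        self._accelerator_backend.optimizer_step(self._optimizer, lambda_closure=closure)


model = torch.nn.Linear(2, 1)
opt = WrappedOptimizer(torch.optim.SGD(model.parameters(), lr=0.1), CPUBackend())


def closure():
    model.zero_grad()
    loss = model(torch.ones(1, 2)).sum()
    loss.backward()
    return loss


opt.step(closure)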
pytorch_lightning/plugins/__init__.py

Lines changed: 8 additions & 13 deletions
@@ -11,6 +11,10 @@
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin  # noqa: F401
@@ -19,17 +23,8 @@
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin  # noqa: F401
 
 __all__ = [
-    "ApexMixedPrecisionPlugin",
-    "DataParallelPlugin",
-    "DDP2Plugin",
-    "DDPPlugin",
-    "DDPSpawnPlugin",
-    "HorovodPlugin",
-    "NativeMixedPrecisionPlugin",
-    "PrecisionPlugin",
-    "ShardedNativeMixedPrecisionPlugin",
-    "SingleDevicePlugin",
-    "SingleTPUPlugin",
-    "TPUHalfPrecisionPlugin",
-    "TPUSpawnPlugin",
+    "ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", "DDPPlugin", "DDPSpawnPlugin", "HorovodPlugin",
+    "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", "SingleDevicePlugin",
+    "SingleTPUPlugin", "TPUHalfPrecisionPlugin", "TPUSpawnPlugin", 'RPCPlugin', 'RPCSequentialPlugin'
+    'TrainingTypePlugin', 'ParallelPlugin', 'Plugin', 'DDPShardedPlugin', 'DDPSpawnShardedPlugin'
 ]

pytorch_lightning/plugins/training_type/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
 from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 12 additions & 1 deletion
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import List, Optional
@@ -22,7 +23,7 @@
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
-from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp
 
 
 class ParallelPlugin(TrainingTypePlugin, ABC):
@@ -102,3 +103,13 @@ def block_backward_sync(self):
             yield self.model.no_sync()
         else:
             yield None
+
+    def broadcast(self, obj: object, src: int) -> object:
+        buffer = io.BytesIO()
+        torch.save(obj, buffer)
+        data = bytearray(buffer.getbuffer())
+        data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float)
+        data = all_gather_ddp_if_available(data_tensor)
+        buffer = io.BytesIO(data.cpu().byte().numpy())
+        obj = torch.load(buffer)
+        return obj

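The new `ParallelPlugin.broadcast` moves arbitrary Python objects between ranks by serializing with `torch.save`, shipping the bytes as a tensor through `all_gather_ddp_if_available`, and deserializing on the other side. A single-process sketch of that tensor round-trip, with the actual all-gather skipped so it runs without a process group:

# Sketch of the serialize -> byte tensor -> deserialize round-trip used above.
# In the real plugin, data_tensor is what gets all-gathered across ranks;
# that step is omitted here so the example runs in a single process.
import io

import torch

obj = {"epoch": 3, "score": 0.91}

buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
data_tensor = torch.tensor(data, dtype=torch.float)  # would be all-gathered across ranks here

restored = torch.load(io.BytesIO(data_tensor.byte().numpy()))
print(restored)  # {'epoch': 3, 'score': 0.91}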
pytorch_lightning/trainer/data_loading.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
             return dataloader
 
         is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
+
         need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
         if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
             if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):