Skip to content

Commit 5d239cc

Browse files
justusschockawaelchliBorda
authored
Base classes for accelerator refactoring (#5715)
* add basic accelerator class. Co-Authored with @awaelchi * Add base plugin class. Co-authored with @awaelchi * add basic trainign type plugin. Co-Authored with @awaelchi * add basic precision plugin. Co-Authored with @awaelchi * Add missing inits. Co-authored with @awaelchi * pep8 Co-authored-by: @awaelchi * ignore flake8 * coverage omit * imports in init * lost * imports * flake8 * . * . * chlog * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent fca9272 commit 5d239cc

File tree

9 files changed

+678
-1
lines changed

9 files changed

+678
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
107107
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
108108

109109

110+
- Refactored Accelerators and Plugins (
111+
[#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715),
112+
)
113+
114+
110115
### Deprecated
111116

112117
- Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Callable, Iterable, Optional, Union
15+
16+
import torch
17+
from torch.optim import Optimizer
18+
19+
from pytorch_lightning.core import LightningModule
20+
from pytorch_lightning.plugins import TrainingTypePlugin
21+
from pytorch_lightning.utilities.apply_func import move_data_to_device
22+
from pytorch_lightning.utilities.enums import LightningEnum
23+
24+
25+
class Accelerator(object):
26+
"""
27+
The Accelerator Base Class.
28+
An Accelerator is meant to deal with one type of Hardware.
29+
30+
Currently there are accelerators for:
31+
- CPU
32+
- GPU
33+
- TPU
34+
35+
Each Accelerator gets two plugins upon initialization:
36+
One to handle differences from the training routine and one to handle different precisions.
37+
38+
"""
39+
40+
def __init__(
41+
self,
42+
precision_plugin, #: PrecisionPlugin # fixme
43+
training_type_plugin: TrainingTypePlugin,
44+
) -> None:
45+
"""
46+
47+
Args:
48+
precision_plugin: the plugin to handle precision-specific parts
49+
training_type_plugin: the plugin to handle different training routines
50+
"""
51+
self.precision_plugin = precision_plugin
52+
self.training_type_plugin = training_type_plugin
53+
54+
self.optimizers = None
55+
self.lr_schedulers = None
56+
self.optimizer_frequencies = None
57+
58+
def setup(self, trainer: "Trainer", model: LightningModule) -> None:
59+
"""
60+
Connects the plugins to the training process, creates optimizers
61+
62+
Args:
63+
trainer: the trainer instance to connect to
64+
model: the model to train
65+
"""
66+
self.connect_training_type_plugin(self.training_type_plugin, model)
67+
self.setup_optimizers(trainer, model)
68+
self.connect_precision_plugin(self.precision_plugin)
69+
self.optimizers = trainer.convert_to_lightning_optimizers(self.optimizers)
70+
71+
@property
72+
def model(self) -> torch.nn.Module:
73+
"""Returns the model. This can also be a wrapped LightningModule.
74+
For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`
75+
76+
"""
77+
return self.training_type_plugin.model
78+
79+
@model.setter
80+
def model(self, new_model: torch.nn.Module) -> None:
81+
self.training_type_plugin.model = new_model
82+
83+
@property
84+
def lightning_module(self) -> LightningModule:
85+
"""Returns the pure LightningModule.
86+
To get the potentially wrapped model use :attr:`Accelerator.model`
87+
88+
"""
89+
return self.training_type_plugin.lightning_module
90+
91+
@property
92+
def root_device(self) -> torch.device:
93+
return self.training_type_plugin.root_device
94+
95+
def teardown(self):
96+
"""This method is called to teardown the training process.
97+
It is the right place to release memory and free other ressources.
98+
"""
99+
pass
100+
101+
def batch_to_device(self, batch: Any, device: torch.device) -> Any:
102+
"""Moves the batch to the correct device.
103+
The returned batch is of the same type as the input batch, just having all tensors on the correct device.
104+
105+
Args:
106+
batch: The batch of samples to move to the correct device
107+
device: The target device
108+
"""
109+
model = self.lightning_module
110+
if model is not None:
111+
return model.transfer_batch_to_device(batch, device)
112+
return move_data_to_device(batch, device)
113+
114+
def on_train_start(self):
115+
"""Hook to do something upon the training start"""
116+
pass
117+
118+
def training_step(self, args):
119+
"""The actual training step.
120+
121+
Args:
122+
args: the arguments for the models training step. Can consist of the following:
123+
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
124+
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
125+
batch_idx (int): Integer displaying index of this batch
126+
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
127+
hiddens(:class:`~torch.Tensor`): Passed in if
128+
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
129+
130+
"""
131+
batch = self.to_device(args[0])
132+
133+
args[0] = batch
134+
135+
with self.precision_plugin.train_step_context():
136+
with self.training_type_plugin.train_step_context():
137+
return self.lightning_module.training_step(*args)
138+
139+
def validation_step(self, args):
140+
"""The actual validation step.
141+
142+
Args:
143+
args: the arguments for the models validation step. Can consist of the following:
144+
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
145+
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
146+
batch_idx (int): The index of this batch
147+
dataloader_idx (int): The index of the dataloader that produced this batch
148+
(only if multiple val dataloaders used)
149+
"""
150+
batch = self.to_device(args[0])
151+
152+
args[0] = batch
153+
154+
with self.precision_plugin.val_step_context():
155+
with self.training_type_plugin.val_step_context():
156+
return self.lightning_module.validation_step(*args)
157+
158+
def test_step(self, args):
159+
"""The actual test step.
160+
161+
Args:
162+
args: the arguments for the models test step. Can consist of the following:
163+
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
164+
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
165+
batch_idx (int): The index of this batch.
166+
dataloader_idx (int): The index of the dataloader that produced this batch
167+
(only if multiple test dataloaders used).
168+
"""
169+
batch = self.to_device(args[0])
170+
171+
args[0] = batch
172+
173+
with self.precision_plugin.test_step_context():
174+
with self.training_type_plugin.test_step_context():
175+
return self.lightning_module.test_step(*args)
176+
177+
def training_step_end(self, output):
178+
"""A hook to do something at the end of the training step
179+
180+
Args:
181+
output: the output of the training step
182+
"""
183+
return output
184+
185+
def test_step_end(self, output):
186+
"""A hook to do something at the end of the test step
187+
188+
Args:
189+
output: the output of the test step
190+
"""
191+
return output
192+
193+
def validation_step_end(self, output):
194+
"""A hook to do something at the end of the validation step
195+
196+
Args:
197+
output: the output of the validation step
198+
"""
199+
return output
200+
201+
def process_dataloader(
202+
self, dataloader: Union[Iterable, torch.utils.data.DataLoader]
203+
) -> Union[Iterable, torch.utils.data.DataLoader]:
204+
"""Wraps the dataloader if necessary
205+
206+
Args:
207+
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
208+
"""
209+
return dataloader
210+
211+
def backward(
212+
self,
213+
closure_loss: torch.Tensor,
214+
optimizer: torch.optim.Optimizer,
215+
opt_idx: int,
216+
should_accumulate: bool,
217+
*args,
218+
**kwargs,
219+
) -> torch.Tensor:
220+
"""Forwards backward-calls to the precision plugin.
221+
222+
Args:
223+
closure_loss: a tensor holding the loss value to backpropagate
224+
optimizer: the optimizer to do the step later on.
225+
opt_idx: the index of the optimizer
226+
should_accumulate: whether to accumulate gradients
227+
"""
228+
output = self.precision_plugin.backward(
229+
self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
230+
)
231+
232+
# TODO: this is a hack, find a better solution for this (hook?)
233+
# fixme: uncomment when this class is added
234+
# if isinstance(self.training_type_plugin, HorovodPlugin):
235+
# optimizer.synchronize()
236+
237+
return output
238+
239+
def optimizer_step(
240+
self,
241+
optimizer: torch.optim.Optimizer,
242+
current_epoch: int,
243+
batch_idx: int,
244+
opt_idx: int,
245+
lambda_closure: Callable,
246+
):
247+
"""performs the actual optimizer step.
248+
249+
Args:
250+
optimizer: the optimizer performing the step
251+
current_epoch: current training epoch
252+
batch_idx: index of the current batch
253+
opt_idx: index of the current optimizer
254+
lambda_closure: closure calculating the loss value
255+
256+
"""
257+
model_ref = self.lightning_module
258+
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
259+
# fixme: uncomment when this class is added
260+
# is_native_amp = (
261+
# isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
262+
# )
263+
is_native_amp = False
264+
265+
self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
266+
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)
267+
268+
# model hook
269+
res = model_ref.optimizer_step(
270+
epoch=current_epoch,
271+
batch_idx=batch_idx,
272+
optimizer=optimizer,
273+
optimizer_idx=opt_idx,
274+
optimizer_closure=lambda_closure,
275+
on_tpu=False, # TPUAccelerator class sets this as True
276+
using_native_amp=is_native_amp,
277+
using_lbfgs=is_lbfgs,
278+
)
279+
280+
self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
281+
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
282+
return res
283+
284+
def optimizer_zero_grad(
285+
self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
286+
) -> None:
287+
"""Zeros all model parameter's gradients"""
288+
model_ref = self.lightning_module
289+
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
290+
291+
def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None:
292+
"""clips all the optimizer parameters to the given value"""
293+
294+
self.precision_plugin.clip_gradients(optimizer, clip_val)
295+
296+
def on_train_epoch_end(self, outputs) -> None:
297+
"""Hook to do something on the end of an training epoch
298+
299+
Args:
300+
outputs: the outputs of the training steps
301+
"""
302+
pass
303+
304+
def on_train_end(self) -> None:
305+
"""Hook to do something at the end of the training"""
306+
pass
307+
308+
def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
309+
"""creates optimizers and schedulers
310+
311+
Args:
312+
trainer: the Trainer, these optimizers should be connected to
313+
model: the model to be optimized by the created optimizers
314+
"""
315+
if trainer.testing is True:
316+
return
317+
optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
318+
self.optimizers = optimizers
319+
self.lr_schedulers = lr_schedulers
320+
self.optimizer_frequencies = optimizer_frequencies
321+
322+
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
323+
"""Attaches the training type plugin to the accelerator.
324+
Also transfers ownership of the model to this plugin
325+
326+
"""
327+
plugin.connect(model)
328+
329+
def connect_precision_plugin(self, plugin): #: PrecisionPlugin # fixme
330+
"""Attaches the precision plugin to the accelerator"""
331+
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
332+
self.model = model
333+
self.optimizers = optimizers
334+
self.schedulers = schedulers
335+
336+
def to_device(self, batch: Any) -> Any:
337+
"""Pushes the batch to the root device"""
338+
return self.batch_to_device(batch, self.root_device)
339+
340+
@property
341+
def amp_backend(self) -> Optional[LightningEnum]:
342+
# fixme: uncomment when this class is added
343+
# if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
344+
# return AMPType.APEX
345+
# elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
346+
# return AMPType.NATIVE
347+
# return None
348+
pass
349+
350+
@property
351+
def precision(self) -> int:
352+
return self.precision_plugin.precision
353+
354+
@property
355+
def scaler(self):
356+
if hasattr(self.precision_plugin, "scaler"):
357+
return self.precision_plugin.scaler
358+
359+
return None
360+
361+
@property
362+
def rpc_enabled(self) -> bool:
363+
return self.training_type_plugin.rpc_enabled
364+
365+
def optimizer_state(self, optimizer: Optimizer) -> dict:
366+
"""
367+
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
368+
plugins.
369+
"""
370+
if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
371+
return self.training_type_plugin.optimizer_state(optimizer)
372+
return optimizer.state_dict()
373+
374+
def on_save(self, checkpoint):
375+
return checkpoint
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
2+
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
3+
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401

0 commit comments

Comments
 (0)