# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Optional, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins import TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import LightningEnum


class Accelerator(object):
    """
    The Accelerator base class.
    An Accelerator is meant to deal with one type of hardware.

    Currently there are accelerators for:

    - CPU
    - GPU
    - TPU

    Each Accelerator gets two plugins upon initialization:
    one to handle the differences in the training routine and one to handle the different precisions.
    """

    def __init__(
        self,
        precision_plugin,  #: PrecisionPlugin  # fixme
        training_type_plugin: TrainingTypePlugin,
    ) -> None:
        """
        Args:
            precision_plugin: the plugin to handle precision-specific parts
            training_type_plugin: the plugin to handle different training routines
        """
        self.precision_plugin = precision_plugin
        self.training_type_plugin = training_type_plugin

        self.optimizers = None
        self.lr_schedulers = None
        self.optimizer_frequencies = None

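    # A minimal construction sketch (not part of the public API): ``my_precision_plugin``
    # and ``my_training_type_plugin`` below are hypothetical placeholders for a concrete
    # precision plugin and ``TrainingTypePlugin`` instance.
    #
    #   accelerator = Accelerator(
    #       precision_plugin=my_precision_plugin,
    #       training_type_plugin=my_training_type_plugin,
    #   )
    #   accelerator.setup(trainer, model)  # connects the plugins and creates the optimizers
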
    def setup(self, trainer: "Trainer", model: LightningModule) -> None:
        """
        Connects the plugins to the training process and creates the optimizers.

        Args:
            trainer: the trainer instance to connect to
            model: the model to train
        """
        self.connect_training_type_plugin(self.training_type_plugin, model)
        self.setup_optimizers(trainer, model)
        self.connect_precision_plugin(self.precision_plugin)
        self.optimizers = trainer.convert_to_lightning_optimizers(self.optimizers)

    @property
    def model(self) -> torch.nn.Module:
        """Returns the model. This can also be a wrapped LightningModule.
        For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`.
        """
        return self.training_type_plugin.model

    @model.setter
    def model(self, new_model: torch.nn.Module) -> None:
        self.training_type_plugin.model = new_model

    @property
    def lightning_module(self) -> LightningModule:
        """Returns the pure LightningModule.
        To get the potentially wrapped model use :attr:`Accelerator.model`.
        """
        return self.training_type_plugin.lightning_module

    @property
    def root_device(self) -> torch.device:
        return self.training_type_plugin.root_device

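    # Illustrative note: depending on the training type plugin, ``model`` may return a
    # wrapped module (e.g. a distributed wrapper around the user's module), while
    # ``lightning_module`` always refers to the unwrapped LightningModule. A hedged
    # sketch, assuming an already set up ``accelerator``:
    #
    #   wrapped = accelerator.model              # possibly a wrapper module
    #   pure = accelerator.lightning_module      # always the raw LightningModule
    #   assert isinstance(pure, LightningModule)
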
    def teardown(self):
        """This method is called to tear down the training process.
        It is the right place to release memory and free other resources.
        """
        pass

    def batch_to_device(self, batch: Any, device: torch.device) -> Any:
        """Moves the batch to the correct device.
        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
        """
        model = self.lightning_module
        if model is not None:
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)

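    # Usage sketch for ``batch_to_device`` (hypothetical batch and device): if the
    # LightningModule is available its ``transfer_batch_to_device`` hook is used,
    # otherwise the generic ``move_data_to_device`` utility moves all tensors.
    #
    #   batch = {"x": torch.randn(4, 3), "y": torch.zeros(4)}
    #   batch = accelerator.batch_to_device(batch, torch.device("cuda:0"))
    #   # all tensors inside ``batch`` now live on ``cuda:0``
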
    def on_train_start(self):
        """Hook to do something at the start of training"""
        pass

    def training_step(self, args):
        """The actual training step.

        Args:
            args: the arguments for the model's training step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): Integer displaying index of this batch
                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
                hiddens (:class:`~torch.Tensor`): Passed in if
                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
        """
        batch = self.to_device(args[0])

        args[0] = batch

        with self.precision_plugin.train_step_context():
            with self.training_type_plugin.train_step_context():
                return self.lightning_module.training_step(*args)

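    # Call sketch for ``training_step`` (hypothetical values): the training loop passes
    # the positional arguments as a mutable sequence so the batch can be replaced by its
    # on-device copy before being forwarded to the LightningModule.
    #
    #   args = [batch, batch_idx]          # plus optimizer_idx / hiddens when applicable
    #   output = accelerator.training_step(args)
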
    def validation_step(self, args):
        """The actual validation step.

        Args:
            args: the arguments for the model's validation step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): The index of this batch
                dataloader_idx (int): The index of the dataloader that produced this batch
                    (only if multiple val dataloaders used)
        """
        batch = self.to_device(args[0])

        args[0] = batch

        with self.precision_plugin.val_step_context():
            with self.training_type_plugin.val_step_context():
                return self.lightning_module.validation_step(*args)

    def test_step(self, args):
        """The actual test step.

        Args:
            args: the arguments for the model's test step. Can consist of the following:
                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
                batch_idx (int): The index of this batch.
                dataloader_idx (int): The index of the dataloader that produced this batch
                    (only if multiple test dataloaders used).
        """
        batch = self.to_device(args[0])

        args[0] = batch

        with self.precision_plugin.test_step_context():
            with self.training_type_plugin.test_step_context():
                return self.lightning_module.test_step(*args)

    def training_step_end(self, output):
        """A hook to do something at the end of the training step

        Args:
            output: the output of the training step
        """
        return output

    def test_step_end(self, output):
        """A hook to do something at the end of the test step

        Args:
            output: the output of the test step
        """
        return output

    def validation_step_end(self, output):
        """A hook to do something at the end of the validation step

        Args:
            output: the output of the validation step
        """
        return output

    def process_dataloader(
        self, dataloader: Union[Iterable, torch.utils.data.DataLoader]
    ) -> Union[Iterable, torch.utils.data.DataLoader]:
        """Wraps the dataloader if necessary.

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
        """
        return dataloader

    def backward(
        self,
        closure_loss: torch.Tensor,
        optimizer: torch.optim.Optimizer,
        opt_idx: int,
        should_accumulate: bool,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Forwards backward-calls to the precision plugin.

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
            optimizer: the optimizer to do the step later on
            opt_idx: the index of the optimizer
            should_accumulate: whether to accumulate gradients
        """
        output = self.precision_plugin.backward(
            self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
        )

        # TODO: this is a hack, find a better solution for this (hook?)
        # fixme: uncomment when this class is added
        # if isinstance(self.training_type_plugin, HorovodPlugin):
        #     optimizer.synchronize()

        return output

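    # Call sketch for ``backward`` (hypothetical loss and optimizer): the loss tensor is
    # handed to the precision plugin, which performs the actual backward pass
    # (e.g. with loss scaling for mixed precision).
    #
    #   loss = accelerator.backward(loss, optimizer, opt_idx=0, should_accumulate=False)
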
    def optimizer_step(
        self,
        optimizer: torch.optim.Optimizer,
        current_epoch: int,
        batch_idx: int,
        opt_idx: int,
        lambda_closure: Callable,
    ):
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            current_epoch: current training epoch
            batch_idx: index of the current batch
            opt_idx: index of the current optimizer
            lambda_closure: closure calculating the loss value
        """
        model_ref = self.lightning_module
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        # fixme: uncomment when this class is added
        # is_native_amp = (
        #     isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
        # )
        is_native_amp = False

        self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
        self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

        # model hook
        res = model_ref.optimizer_step(
            epoch=current_epoch,
            batch_idx=batch_idx,
            optimizer=optimizer,
            optimizer_idx=opt_idx,
            optimizer_closure=lambda_closure,
            on_tpu=False,  # TPUAccelerator class sets this as True
            using_native_amp=is_native_amp,
            using_lbfgs=is_lbfgs,
        )

        self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
        self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
        return res

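    # Call sketch for ``optimizer_step`` (hypothetical closure): ``lambda_closure`` is
    # supplied by the training loop and typically re-evaluates the loss (and its backward
    # pass) so that optimizers such as LBFGS can invoke it multiple times.
    #
    #   accelerator.optimizer_step(
    #       optimizer,
    #       current_epoch=trainer.current_epoch,
    #       batch_idx=batch_idx,
    #       opt_idx=0,
    #       lambda_closure=train_step_and_backward_closure,  # hypothetical closure
    #   )
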
    def optimizer_zero_grad(
        self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
    ) -> None:
        """Zeros the gradients of all model parameters"""
        model_ref = self.lightning_module
        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

    def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None:
        """Clips the gradients of the optimizer's parameters to the given value"""
        self.precision_plugin.clip_gradients(optimizer, clip_val)

    def on_train_epoch_end(self, outputs) -> None:
        """Hook to do something at the end of a training epoch

        Args:
            outputs: the outputs of the training steps
        """
        pass

    def on_train_end(self) -> None:
        """Hook to do something at the end of training"""
        pass

    def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
        """Creates optimizers and schedulers.

        Args:
            trainer: the Trainer these optimizers should be connected to
            model: the model to be optimized by the created optimizers
        """
        if trainer.testing is True:
            return
        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers
        self.optimizer_frequencies = optimizer_frequencies

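    # Sketch of where the three lists come from: ``trainer.init_optimizers(model)`` builds
    # them from the user's ``LightningModule.configure_optimizers``. A hedged example of a
    # ``configure_optimizers`` returning one optimizer and one scheduler:
    #
    #   def configure_optimizers(self):
    #       optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    #       scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    #       return [optimizer], [scheduler]
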
    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
        """Attaches the training type plugin to the accelerator.
        Also transfers ownership of the model to this plugin.
        """
        plugin.connect(model)

    def connect_precision_plugin(self, plugin):  #: PrecisionPlugin  # fixme
        """Attaches the precision plugin to the accelerator"""
        model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
        self.model = model
        self.optimizers = optimizers
        self.lr_schedulers = schedulers

    def to_device(self, batch: Any) -> Any:
        """Pushes the batch to the root device"""
        return self.batch_to_device(batch, self.root_device)

    @property
    def amp_backend(self) -> Optional[LightningEnum]:
        # fixme: uncomment when this class is added
        # if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
        #     return AMPType.APEX
        # elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
        #     return AMPType.NATIVE
        # return None
        pass

    @property
    def precision(self) -> int:
        return self.precision_plugin.precision

    @property
    def scaler(self):
        if hasattr(self.precision_plugin, "scaler"):
            return self.precision_plugin.scaler

        return None

    @property
    def rpc_enabled(self) -> bool:
        return self.training_type_plugin.rpc_enabled

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        """
        Returns the state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
        plugins.
        """
        if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
            return self.training_type_plugin.optimizer_state(optimizer)
        return optimizer.state_dict()

    def on_save(self, checkpoint):
        """Hook to modify the checkpoint before it is saved. Returns the checkpoint unchanged by default."""
        return checkpoint