 # limitations under the License.
 import contextlib
 from abc import abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, Optional, Union

 import torch
-from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn import Module
-from torch.optim import Optimizer

 import pytorch_lightning as pl
-from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
-from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 from pytorch_lightning.utilities.types import STEP_OUTPUT


@@ -62,10 +55,6 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl
         if precision_plugin is not None:
             self.training_type_plugin._precision_plugin = precision_plugin

-        self.optimizers: List = []
-        self.lr_schedulers: List = []
-        self.optimizer_frequencies: List = []
-
     def setup_environment(self) -> None:
         """Setup any processes or distributed connections.

@@ -80,28 +69,18 @@ def setup(self, trainer: "pl.Trainer") -> None:
         Args:
             trainer: the trainer instance
         """
-        self.setup_training_type_plugin()
-        if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
-            self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
+        self.training_type_plugin.setup(trainer)

     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
-        self._move_optimizer_state()
+        self.training_type_plugin._move_optimizer_state()

         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
-            self.setup_optimizers(trainer)
+            self.training_type_plugin.setup_optimizers(trainer)

         self.training_type_plugin.precision_plugin.pre_dispatch()

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the GPU if needed."""
-        device = device or self.root_device
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
-
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.dispatch(trainer)
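The optimizer-state handling removed from the accelerator above is now reached through the training type plugin (`self.training_type_plugin._move_optimizer_state()` in `pre_dispatch`). For orientation only, here is a minimal sketch of how the relocated helper could look on the plugin side, assuming the deleted body is carried over unchanged; the `root_device` and `optimizers` attributes and the `apply_to_collection`/`move_data_to_device` utilities are taken directly from the removed lines, while the class name below is a stand-in:

```python
from typing import Optional

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


class TrainingTypePluginSketch:
    """Sketch only: stands in for pytorch_lightning.plugins.training_type.TrainingTypePlugin."""

    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
        # Move every optimizer state tensor to the plugin's root device,
        # mirroring the helper deleted from the Accelerator above.
        device = device or self.root_device
        for opt in self.optimizers:
            for p, v in opt.state.items():
                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
```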
@@ -177,115 +156,12 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         with self.training_type_plugin.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())

-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
-        """Forwards backward-calls to the precision plugin.
-
-        Args:
-            closure_loss: a tensor holding the loss value to backpropagate
-        """
-        self.training_type_plugin.pre_backward(closure_loss)
-        closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss)
-
-        self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
-
-        closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss)
-        self.training_type_plugin.post_backward(closure_loss)
-
-        return closure_loss
-
-    def optimizer_step(
-        self,
-        optimizer: Optimizer,
-        opt_idx: int,
-        closure: Callable[[], Any],
-        model: Optional[Union["pl.LightningModule", Module]] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Performs the actual optimizer step.
-
-        Args:
-            optimizer: the optimizer performing the step
-            opt_idx: index of the current optimizer
-            closure: closure calculating the loss value
-            model: reference to the model, optionally defining optimizer step related hooks
-            **kwargs: Any extra arguments to ``optimizer.step``
-        """
-        model = model or self.lightning_module
-        self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
-
-    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
-        """Zeros all model parameter's gradients."""
-        model_ref = self.lightning_module
-        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
-
-    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
-        """Creates optimizers and schedulers.
-
-        Args:
-            trainer: the Trainer these optimizers should be connected to
-        """
-        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
-            return
-        optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
-            trainer=trainer, model=self.lightning_module
-        )
-        self.optimizers = optimizers
-        self.lr_schedulers = lr_schedulers
-        self.optimizer_frequencies = optimizer_frequencies
-
-    def setup_training_type_plugin(self) -> None:
-        """Attaches the training type plugin to the accelerator."""
-        self.training_type_plugin.setup()
-
-    def setup_precision_plugin(self) -> None:
-        """Attaches the precision plugin to the accelerator."""
-        model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect(
-            self.model, self.optimizers, self.lr_schedulers
-        )
-        self.model = model
-        self.optimizers = optimizers
-        self.lr_schedulers = schedulers
-
-    @property
-    def amp_backend(self) -> Optional[LightningEnum]:
-        if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin):
-            return AMPType.APEX
-        if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin):
-            return AMPType.NATIVE
-        return None
-
-    @property
-    def precision(self) -> Union[str, int]:
-        """The type of precision being used with this accelerator.
-
-        .. deprecated::
-            This property has been deprecated and will be removed soon.
-            Use ``training_type_plugin.precision_plugin.precision`` instead.
-        """
-        rank_zero_deprecation(
-            f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon."
-            f" Use `training_type_plugin.precision_plugin.precision` instead."
-        )
-        return self.training_type_plugin.precision_plugin.precision
-
-    @property
-    def scaler(self) -> Optional["GradScaler"]:
-        return getattr(self.training_type_plugin.precision_plugin, "scaler", None)
-
-    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
-        """Returns the state of an optimizer.
-
-        Allows for syncing/collating optimizer state from processes in custom plugins.
-        """
-        return getattr(self.training_type_plugin, "optimizer_state", lambda x: x.state_dict())(optimizer)
-
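With the `amp_backend`, `precision`, `scaler`, and `optimizer_state` accessors removed here, the equivalent information is reached through the training type plugin and its precision plugin, as the deprecation message above already points out for `precision`. A hedged sketch of the replacement lookups, assuming the attribute names on the plugins match those used in the deleted code:

```python
# Sketch of post-refactor access paths; `trainer` is an already-configured pl.Trainer.
ttp = trainer.training_type_plugin

precision = ttp.precision_plugin.precision               # replaces the deprecated Accelerator.precision
scaler = getattr(ttp.precision_plugin, "scaler", None)   # replaces Accelerator.scaler
# Assumption: the optimizers list and optimizer_state now live on the training type
# plugin, mirroring the setup_optimizers delegation shown earlier in this diff.
opt_state = ttp.optimizer_state(ttp.optimizers[0])
```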
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator[None, None, None]:
         """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to.

         shard the model instantly - useful for extremely large models. Can save memory and
         initialization time.
-
         Returns:
             Model parallel context.
         """