 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union

 import torch
+from torch import Tensor
+from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

+if _NATIVE_AMP_AVAILABLE:
+    from torch.cuda.amp import GradScaler
+
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]


-class Accelerator(object):
+class Accelerator:
     """
     The Accelerator Base Class.
     An Accelerator is meant to deal with one type of Hardware.
@@ -52,7 +56,6 @@ def __init__(
         training_type_plugin: TrainingTypePlugin,
     ) -> None:
         """
-
         Args:
             precision_plugin: the plugin to handle precision-specific parts
             training_type_plugin: the plugin to handle different training routines
@@ -64,7 +67,7 @@ def __init__(
         self.lr_schedulers: Sequence = []
         self.optimizer_frequencies: Sequence = []

-    def connect(self, model: LightningModule) -> None:
+    def connect(self, model: 'pl.LightningModule') -> None:
         """Transfers ownership of the model to this plugin"""
         self.training_type_plugin.connect(model)

@@ -76,7 +79,7 @@ def setup_environment(self) -> None:
         """
         self.training_type_plugin.setup_environment()

-    def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Sets up plugins for the trainer fit and creates optimizers.

@@ -111,22 +114,22 @@ def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         self.precision_plugin.post_dispatch()

     @property
-    def model(self) -> torch.nn.Module:
-        """Returns the model. This can also be a wrapped LightningModule.
+    def model(self) -> Module:
+        """
+        Returns the model. This can also be a wrapped LightningModule.
         For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`
-
         """
         return self.training_type_plugin.model

     @model.setter
-    def model(self, new_model: torch.nn.Module) -> None:
+    def model(self, new_model: Module) -> None:
         self.training_type_plugin.model = new_model

     @property
-    def lightning_module(self) -> LightningModule:
-        """Returns the pure LightningModule.
+    def lightning_module(self) -> 'pl.LightningModule':
+        """
+        Returns the pure LightningModule.
         To get the potentially wrapped model use :attr:`Accelerator.model`
-
         """
         return self.training_type_plugin.lightning_module

@@ -135,7 +138,8 @@ def root_device(self) -> torch.device:
         return self.training_type_plugin.root_device

     def teardown(self) -> None:
-        """This method is called to teardown the training process.
+        """
+        This method is called to teardown the training process.
         It is the right place to release memory and free other resources.
         """
         pass
@@ -268,13 +272,13 @@ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:

     def backward(
         self,
-        closure_loss: torch.Tensor,
+        closure_loss: Tensor,
         optimizer: Optimizer,
         optimizer_idx: int,
         should_accumulate: bool,
         *args: Any,
         **kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """Forwards backward-calls to the precision plugin.

         Args:
@@ -325,9 +329,7 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(
-            self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
-        )
+        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something at the end of a training epoch
@@ -342,11 +344,11 @@ def on_train_end(self) -> None:
         pass

     def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
-        """creates optimizers and schedulers
+        """
+        Creates optimizers and schedulers

         Args:
             trainer: the Trainer, these optimizers should be connected to
-            model: the model to be optimized by the created optimizers
         """
         if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
             return
@@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies

-    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
         """Attaches the training type plugin to the accelerator."""
         plugin.setup(model)

@@ -390,22 +392,21 @@ def precision(self) -> Union[str, int]:
         return self.precision_plugin.precision

     @property
-    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
-
+    def scaler(self) -> Optional['GradScaler']:
         return getattr(self.precision_plugin, 'scaler', None)

     @property
     def rpc_enabled(self) -> bool:
         return self.training_type_plugin.rpc_enabled

-    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
         plugins.
         """
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

-    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
+    def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.on_save(checkpoint)

     def barrier(self, name: Optional[str] = None) -> None:
@@ -420,7 +421,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """
         return self.training_type_plugin.broadcast(obj, src)

-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """
         Function to gather a tensor from several distributed processes.

@@ -464,7 +465,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         yield

     # todo: remove in v1.5
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
         """
         Attaches the training type plugin to the accelerator.
         Also transfers ownership of the model to this plugin
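
For reference, a minimal sketch of how the refactored Accelerator composes with its two plugins. SingleDevicePlugin and the TinyModel class below are illustrative assumptions, not part of this diff, and in normal use the Trainer constructs and owns the Accelerator itself:

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin


class TinyModel(pl.LightningModule):
    # stand-in LightningModule used only for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)


model = TinyModel()
accelerator = Accelerator(
    precision_plugin=PrecisionPlugin(),
    training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
)
accelerator.connect(model)                             # hands the model over to the training type plugin
assert accelerator.lightning_module is model           # the pure, unwrapped LightningModule
assert isinstance(accelerator.model, torch.nn.Module)  # the possibly wrapped model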