 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union

 import torch
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

+import pytorch_lightning as pl
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

-if TYPE_CHECKING:
-    from torch.cuda.amp import GradScaler
-
-    from pytorch_lightning.trainer.trainer import Trainer
-
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]

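The import change above swaps the `TYPE_CHECKING` block for a plain `import pytorch_lightning as pl`, so annotations can name `'pl.Trainer'` as a string without triggering the circular import between the accelerator and trainer modules. A minimal sketch of the pattern (the function name is illustrative, not from this diff):

```python
# Importing the package as a module and quoting the annotation avoids the
# circular import that `from pytorch_lightning.trainer.trainer import Trainer`
# would cause at module-load time.
import pytorch_lightning as pl


def run(trainer: 'pl.Trainer') -> None:
    # 'pl.Trainer' is resolved lazily by type checkers, not at definition time
    print(type(trainer).__name__)
```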
@@ -40,6 +36,7 @@ class Accelerator(object):
     An Accelerator is meant to deal with one type of Hardware.

     Currently there are accelerators for:
+
     - CPU
     - GPU
     - TPU
@@ -79,9 +76,10 @@ def setup_environment(self) -> None:
         """
         self.training_type_plugin.setup_environment()

-    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+    def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
         """
         Sets up plugins for the trainer fit and creates optimizers.
+
         Args:
             trainer: the trainer instance
             model: the LightningModule
@@ -91,23 +89,23 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin(self.precision_plugin)

-    def start_training(self, trainer: 'Trainer') -> None:
+    def start_training(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_training(trainer)

-    def start_evaluating(self, trainer: 'Trainer') -> None:
+    def start_evaluating(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_evaluating(trainer)

-    def start_predicting(self, trainer: 'Trainer') -> None:
+    def start_predicting(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)

-    def pre_dispatch(self, trainer: 'Trainer') -> None:
+    def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()

-    def post_dispatch(self, trainer: 'Trainer') -> None:
+    def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something after the training/evaluation/prediction finishes."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
@@ -169,12 +167,13 @@ def training_step(

         Args:
             args: the arguments for the model's training step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): Integer displaying index of this batch
-                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                hiddens (:class:`~torch.Tensor`): Passed in if
-                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): Integer displaying index of this batch
+                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
+                - hiddens (:class:`~torch.Tensor`): Passed in if
+                  :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

         """
         args[0] = self.to_device(args[0])
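For orientation, here is a hedged sketch of a `LightningModule` whose `training_step` matches the argument layout documented above; all names are illustrative. `optimizer_idx` and `hiddens` would appear as extra positional arguments only in the multi-optimizer and truncated-BPTT cases.

```python
import torch
import pytorch_lightning as pl


class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # batch is whatever the DataLoader yields; assumed (input, target) here
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```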
@@ -190,11 +189,12 @@ def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

         Args:
             args: the arguments for the model's validation step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple val dataloaders used)
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): The index of this batch
+                - dataloader_idx (int): The index of the dataloader that produced this batch
+                  (only if multiple val dataloaders used)
         """
         batch = self.to_device(args[0])

@@ -208,11 +208,12 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

         Args:
             args: the arguments for the model's test step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch.
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple test dataloaders used).
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): The index of this batch.
+                - dataloader_idx (int): The index of the dataloader that produced this batch
+                  (only if multiple test dataloaders used).
         """
         batch = self.to_device(args[0])

@@ -226,11 +227,13 @@ def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

         Args:
             args: the arguments for the model's predict step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch.
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple predict dataloaders used).
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): The index of this batch.
+                - dataloader_idx (int): The index of the dataloader that produced this batch
+                  (only if multiple predict dataloaders used).
+
         """
         batch = self.to_device(args[0])

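The `dataloader_idx` argument mentioned in the three docstrings above only shows up when multiple eval dataloaders are configured. A hedged sketch (class and method body are illustrative):

```python
import torch
import pytorch_lightning as pl


class MultiLoaderEval(pl.LightningModule):
    # Hypothetical: assumes val_dataloader() returns two DataLoaders,
    # so Lightning supplies dataloader_idx as documented above.
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # keep metrics from the two validation sets separable
        self.log(f'val_loss/dataloader_{dataloader_idx}', loss)
```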
@@ -336,7 +339,7 @@ def on_train_end(self) -> None:
         """Hook to do something at the end of the training"""
         pass

-    def setup_optimizers(self, trainer: 'Trainer') -> None:
+    def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         """Creates optimizers and schedulers.

         Args:
@@ -385,7 +388,7 @@ def precision(self) -> Union[str, int]:
         return self.precision_plugin.precision

     @property
-    def scaler(self) -> Optional['GradScaler']:
+    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:

         return getattr(self.precision_plugin, 'scaler', None)

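The `scaler` property now spells out `torch.cuda.amp.GradScaler` in the annotation. For context, this plain-PyTorch sketch shows what that scaler does in a native-AMP loop (Lightning drives these calls through the precision plugin; requires a CUDA device):

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        loss = model(torch.randn(4, 8, device='cuda')).sum()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adapt the scale factor for the next step
```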
@@ -423,6 +426,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
             sync_grads: flag that allows users to synchronize gradients for all_gather op
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
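As a usage sketch, this is typically reached through `LightningModule.all_gather`, which delegates to the accelerator method above; `compute_loss` is a hypothetical helper:

```python
import pytorch_lightning as pl


class DistributedEval(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        local_loss = self.compute_loss(batch)   # scalar tensor per process
        gathered = self.all_gather(local_loss)  # shape: (world_size,)
        # sync_grads=True would keep the op differentiable across processes
        return gathered.mean()
```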
@@ -451,7 +455,8 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         shard the model instantly - useful for extremely large models. Can save memory and
         initialization time.

-        Returns: Model parallel context.
+        Returns:
+            Model parallel context.
         """
         with self.training_type_plugin.model_sharded_context():
             yield
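Illustrative use of the context manager (a no-op under the default plugin; a sharding-capable plugin such as DeepSpeed makes it meaningful, and `VeryLargeModel` is hypothetical):

```python
with accelerator.model_sharded_context():
    # parameters can be created directly in sharded form instead of being
    # fully materialized on every rank first
    model = VeryLargeModel()
```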
@@ -498,7 +503,9 @@ def call_configure_sharded_model_hook(self) -> bool:
         """
         Allow the model parallel hook to be called in suitable environments determined by the training type plugin.
         This is useful when we want to shard the model once within fit.
-        Returns: True if we want to call the model parallel setup hook.
+
+        Returns:
+            True if we want to call the model parallel setup hook.
         """
         return self.training_type_plugin.call_configure_sharded_model_hook
@@ -512,7 +519,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         Override to delay setting up optimizers and schedulers until after dispatch.
         This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
         However, this may break certain precision plugins such as APEX, which require optimizers to be set.
-        Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+
+        Returns:
+            If True, delay setup optimizers until `pre_dispatch`, else call within `setup`.
         """
         return self.training_type_plugin.setup_optimizers_in_pre_dispatch

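A sketch of how a plugin would opt in, assuming a hypothetical `TrainingTypePlugin` subclass that wraps the model during dispatch (the remaining abstract methods are omitted for brevity):

```python
from pytorch_lightning.plugins.training_type import TrainingTypePlugin


class WrappedModelPlugin(TrainingTypePlugin):
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # the wrapped model only exists after dispatch, so optimizers must be
        # created late enough to reference the wrapped parameters
        return True
```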