|
16 | 16 |
|
17 | 17 | import torch |
18 | 18 | from torch import Tensor |
| 19 | +from torch.cuda.amp import GradScaler |
19 | 20 | from torch.nn import Module |
20 | 21 | from torch.optim import Optimizer |
21 | 22 | from torch.utils.data import DataLoader |
|
24 | 25 | from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin |
25 | 26 | from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin |
26 | 27 | from pytorch_lightning.trainer.states import TrainerFn |
27 | | -from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_deprecation |
| 28 | +from pytorch_lightning.utilities import rank_zero_deprecation |
28 | 29 | from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device |
29 | 30 | from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum |
30 | 31 | from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT |
31 | 32 |
|
32 | | -if _NATIVE_AMP_AVAILABLE: |
33 | | - from torch.cuda.amp import GradScaler |
34 | | - |
35 | 33 |
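Note on the import hunk above: `GradScaler` moves to an unconditional top-level import and the `_NATIVE_AMP_AVAILABLE` guard is dropped, presumably because every supported PyTorch version now ships native AMP (`torch.cuda.amp.GradScaler` has been importable since torch 1.6, even on CPU-only builds), making the guard redundant. For reference, a minimal sketch of the standard native-AMP loop that `GradScaler` enables; this is plain `torch.cuda.amp` usage, not code from this PR:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for batch in [torch.randn(8, 4, device="cuda")]:
    optimizer.zero_grad()
    with autocast():                  # run the forward pass in mixed precision
        loss = model(batch).sum()
    scaler.scale(loss).backward()     # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)            # unscales grads; skips the step on inf/nan
    scaler.update()                   # adjust the scale factor for the next step
```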
|
36 | 34 | class Accelerator: |
37 | 35 | """The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. |
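For orientation, an `Accelerator` composes the two plugin types imported above: a `PrecisionPlugin` (fp32 / native AMP / Apex handling) and a `TrainingTypePlugin` (single device, DP, DDP, ...). A simplified sketch of that composition, using a hypothetical class name; not the library's exact constructor:

```python
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin

class AcceleratorSketch:
    # Simplified sketch: the real Accelerator delegates precision-specific
    # work (scaling, autocast) to precision_plugin and hardware/distribution
    # work to training_type_plugin.
    def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None:
        self.precision_plugin = precision_plugin
        self.training_type_plugin = training_type_plugin
```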
@@ -258,8 +256,6 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable
258 | 256 | ) |
259 | 257 | if make_optimizer_step: |
260 | 258 | self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs) |
261 | | - self.precision_plugin.post_optimizer_step(optimizer, opt_idx) |
262 | | - self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs) |
263 | 259 |
|
264 | 260 | def run_optimizer_step( |
265 | 261 | self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any |
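On the `optimizer_step` hunk above: the explicit `post_optimizer_step` calls on both plugins are removed, so the method now ends once the step has been dispatched via `run_optimizer_step`. A sketch of the resulting control flow, using the names visible in the diff context (simplified; the pre-step call and its signature are assumptions based on that context, and where the removed post-step logic now lives is not shown in this hunk):

```python
def optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs):
    # The precision plugin decides whether the optimizer should actually
    # step (e.g. native AMP may skip a step when gradients overflow);
    # 'pre_optimizer_step' and its return value are inferred from the
    # surrounding diff context, not confirmed by this hunk.
    make_optimizer_step = self.precision_plugin.pre_optimizer_step(
        self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
    )
    if make_optimizer_step:
        self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    # Previously, post_optimizer_step was invoked on both plugins here;
    # after this change those calls are gone from the accelerator.
```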
|