Skip to content

Commit de8fe1b

Browse files
committed
fix imports
1 parent 94e0b28 commit de8fe1b

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
from torch.optim import Optimizer
1818

1919
from pytorch_lightning.core import LightningModule
20-
from pytorch_lightning.plugins.training_type import TrainingTypePlugin, HorovodPlugin
2120
from pytorch_lightning.plugins.precision import (
22-
PrecisionPlugin,
23-
MixedPrecisionPlugin,
2421
ApexMixedPrecisionPlugin,
22+
MixedPrecisionPlugin,
2523
NativeMixedPrecisionPlugin,
24+
PrecisionPlugin,
2625
)
26+
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
27+
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
2728
from pytorch_lightning.utilities.apply_func import move_data_to_device
2829
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
2930

@@ -374,4 +375,4 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
374375
return optimizer.state_dict()
375376

376377
def on_save(self, checkpoint):
377-
return checkpoint
378+
return checkpoint

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
# limitations under the License.
1414
import os
1515
from abc import ABC, abstractmethod
16-
from typing import Any, Optional, Sequence, Union
16+
from typing import Any, Optional, Sequence, TYPE_CHECKING, Union
1717

1818
import torch
1919

2020
from pytorch_lightning import _logger as log
2121
from pytorch_lightning.core.lightning import LightningModule
2222
from pytorch_lightning.plugins.base_plugin import Plugin
23-
from pytorch_lightning.trainer import Trainer
23+
24+
if TYPE_CHECKING:
25+
from pytorch_lightning.trainer.trainer import Trainer
2426

2527

2628
class TrainingTypePlugin(Plugin, ABC):
@@ -105,10 +107,10 @@ def results(self) -> Any:
105107
def rpc_enabled(self) -> bool:
106108
return False
107109

108-
def start_training(self, trainer: Trainer) -> None:
110+
def start_training(self, trainer: 'Trainer') -> None:
109111
# double dispatch to initiate the training loop
110112
self._results = trainer.train()
111113

112-
def start_testing(self, trainer: Trainer) -> None:
114+
def start_testing(self, trainer: 'Trainer') -> None:
113115
# double dispatch to initiate the test loop
114116
self._results = trainer.run_test()

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ ignore_errors = True
142142
ignore_errors = True
143143

144144
# todo: add proper typing to this module...
145-
[mypy-pytorch_lightning.accelerators.legacy.*]
145+
[mypy-pytorch_lightning.accelerators.*]
146146
ignore_errors = True
147147

148148
# todo: add proper typing to this module...

0 commit comments

Comments
 (0)