Skip to content

Commit 6cce24f

Browse files
awaelchli and Borda
authored and committed
fix cyclic import
1 parent 829a822 commit 6cce24f

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from torch.utils.data import DataLoader
1919

2020
from pytorch_lightning.core import LightningModule
21-
from pytorch_lightning.trainer.trainer import Trainer
2221
from pytorch_lightning.plugins.precision import (
2322
ApexMixedPrecisionPlugin,
2423
NativeMixedPrecisionPlugin,
@@ -63,7 +62,7 @@ def __init__(
6362
self.lr_schedulers = None
6463
self.optimizer_frequencies = None
6564

66-
def setup(self, trainer: "Trainer", model: LightningModule) -> None:
65+
def setup(self, trainer, model: LightningModule) -> None:
6766
"""
6867
Connects the plugins to the training process, creates optimizers
6968
@@ -302,7 +301,7 @@ def on_train_end(self) -> None:
302301
"""Hook to do something at the end of the training"""
303302
pass
304303

305-
def setup_optimizers(self, trainer: "Trainer"):
304+
def setup_optimizers(self, trainer):
306305
"""creates optimizers and schedulers
307306
308307
Args:

pytorch_lightning/plugins/training_type/rpc_sequential.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from torch.nn.parallel import DistributedDataParallel
2222
from torch.optim import Optimizer
2323

24-
from pytorch_lightning.trainer.trainer import Trainer
2524
from pytorch_lightning.core.lightning import LightningModule
2625
from pytorch_lightning.overrides.distributed import LightningDistributedModule
2726
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
@@ -330,11 +329,11 @@ def post_training(self):
330329
if self.main_rpc_process:
331330
super().post_training()
332331

333-
def start_training(self, trainer: 'Trainer') -> None:
332+
def start_training(self, trainer) -> None:
334333
if self.main_rpc_process:
335334
super().start_training(trainer)
336335

337-
def start_testing(self, trainer: 'Trainer') -> None:
336+
def start_testing(self, trainer) -> None:
338337
if self.main_rpc_process:
339338
super().start_testing(trainer)
340339

0 commit comments

Comments (0)