Skip to content

Commit febd7fa

Browse files
committed
yapf
1 parent 7b9fccb commit febd7fa

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

tests/helpers/pipelines.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@ def run_model_test_without_loggers(
3333
# correct result and ok accuracy
3434
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
3535

36-
pretrained_model = load_model_from_checkpoint(
36+
model2 = load_model_from_checkpoint(
3737
trainer.logger, trainer.checkpoint_callback.best_model_path, type(model)
3838
)
3939

4040
# test new model accuracy
41-
test_loaders = model.test_dataloader() if not data else data.test_dataloader()
41+
test_loaders = model2.test_dataloader() if not data else data.test_dataloader()
4242
if not isinstance(test_loaders, list):
4343
test_loaders = [test_loaders]
4444

45-
if not isinstance(model, BoringModel):
45+
if not isinstance(model2, BoringModel):
4646
for dataloader in test_loaders:
47-
run_prediction_eval_model_template(model, dataloader, min_acc=min_acc)
47+
run_prediction_eval_model_template(model2, dataloader, min_acc=min_acc)
4848

4949

5050
def run_model_test(

tests/models/test_horovod.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE
3333
from tests.helpers import BoringModel
3434
from tests.helpers.advanced_models import BasicGAN
35-
from tests.helpers.datamodules import ClassifDataModule
36-
from tests.helpers.simple_models import ClassificationModel
3735

3836
if _HOROVOD_AVAILABLE:
3937
import horovod

tests/models/test_tpu.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,16 @@ def test_dataloaders_passed_to_fit(tmpdir):
226226
tutils.reset_seed()
227227
model = BoringModel()
228228

229-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8,)
230-
trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader(),)
229+
trainer = Trainer(
230+
default_root_dir=tmpdir,
231+
max_epochs=1,
232+
tpu_cores=8,
233+
)
234+
trainer.fit(
235+
model,
236+
train_dataloader=model.train_dataloader(),
237+
val_dataloaders=model.val_dataloader(),
238+
)
231239
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
232240

233241

0 commit comments

Comments
 (0)