
Commit 7fbfd0a

tests: add default_root_dir=tmpdir

1 parent 90f641a

15 files changed: 76 additions, 66 deletions
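Every hunk below applies the same pattern: route the Trainer's default output location to pytest's per-test temporary directory, so test runs stop littering the working tree with lightning_logs/ and checkpoints. A minimal sketch of that pattern (assuming the repo's tests.base import path for EvalModelTemplate; any LightningModule would do):

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    def test_something(tmpdir):  # pytest injects a fresh temp dir per test
        model = EvalModelTemplate()
        trainer = Trainer(
            max_epochs=1,
            default_root_dir=tmpdir,  # logs/checkpoints land under tmpdir
        )
        trainer.fit(model)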

tests/callbacks/test_callbacks.py

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ def on_test_end(self, trainer, pl_module):
         limit_val_batches=0.1,
         limit_train_batches=0.2,
         progress_bar_refresh_rate=0,
+        default_root_dir=tmpdir,
     )
 
     assert not test_callback.setup_called

tests/callbacks/test_progress_bar.py

Lines changed: 4 additions & 2 deletions
@@ -54,7 +54,7 @@ def test_progress_bar_misconfiguration():
         Trainer(callbacks=callbacks)
 
 
-def test_progress_bar_totals():
+def test_progress_bar_totals(tmpdir):
     """Test that the progress finishes with the correct total steps processed."""
 
     model = EvalModelTemplate()
@@ -63,6 +63,7 @@ def test_progress_bar_totals():
         progress_bar_refresh_rate=1,
         limit_val_batches=1.0,
         max_epochs=1,
+        default_root_dir=tmpdir,
     )
     bar = trainer.progress_bar_callback
     assert 0 == bar.total_train_batches
@@ -136,7 +137,7 @@ def test_progress_bar_fast_dev_run():
 
 
 @pytest.mark.parametrize('refresh_rate', [0, 1, 50])
-def test_progress_bar_progress_refresh(refresh_rate):
+def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
     """Test that the three progress bars get correctly updated when using different refresh rates."""
 
     model = EvalModelTemplate()
@@ -177,6 +178,7 @@ def on_test_batch_end(self, trainer, pl_module):
         limit_train_batches=1.0,
         num_sanity_val_steps=2,
         max_epochs=3,
+        default_root_dir=tmpdir,
     )
     assert trainer.progress_bar_callback.refresh_rate == refresh_rate
tests/loggers/test_all.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def log_metrics(self, metrics, step):
         limit_train_batches=0.2,
         limit_val_batches=0.5,
         fast_dev_run=True,
+        default_root_dir=tmpdir,
     )
     trainer.fit(model)
tests/loggers/test_base.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def test_multiple_loggers(tmpdir):
     assert logger2.finalized_status == "success"
 
 
-def test_multiple_loggers_pickle(tmpdir):
+def test_multiple_loggers_pickle():
     """Verify that pickling trainer with multiple loggers works."""
 
     logger1 = CustomLogger()

tests/loggers/test_tensorboard.py

Lines changed: 4 additions & 1 deletion
@@ -17,7 +17,10 @@
 def test_tensorboard_hparams_reload(tmpdir):
     model = EvalModelTemplate()
 
-    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+    )
     trainer.fit(model)
 
     folder_path = trainer.logger.log_dir

tests/models/test_amp.py

Lines changed: 6 additions & 10 deletions
@@ -21,7 +21,7 @@ def test_amp_single_gpu(tmpdir, backend):
         max_epochs=1,
         gpus=1,
         distributed_backend=backend,
-        precision=16
+        precision=16,
     )
 
     model = EvalModelTemplate()
@@ -39,18 +39,15 @@ def test_amp_multi_gpu(tmpdir, backend):
     tutils.set_random_master_port()
 
     model = EvalModelTemplate()
-
-    trainer_options = dict(
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         # gpus=2,
         gpus='0, 1',  # test init with gpu string
         distributed_backend=backend,
         precision=16,
     )
-
-    # tutils.run_model_test(trainer_options, model)
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
     assert result
 
@@ -66,17 +63,15 @@ def test_multi_gpu_wandb(tmpdir, backend):
     model = EvalModelTemplate()
     logger = WandbLogger(name='utest')
 
-    trainer_options = dict(
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         gpus=2,
         distributed_backend=backend,
         precision=16,
         logger=logger,
-
     )
-    # tutils.run_model_test(trainer_options, model)
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
     assert result
     trainer.test(model)
@@ -106,6 +101,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
         precision=16,
         checkpoint_callback=checkpoint,
         logger=logger,
+        default_root_dir=tmpdir,
     )
     trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)

tests/models/test_cpu.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@ def test_cpu_slurm_save_load(tmpdir):
         limit_train_batches=0.2,
         limit_val_batches=0.2,
         checkpoint_callback=ModelCheckpoint(tmpdir),
+        default_root_dir=tmpdir,
     )
     result = trainer.fit(model)
     real_global_step = trainer.global_step
@@ -64,6 +65,7 @@ def test_cpu_slurm_save_load(tmpdir):
         max_epochs=1,
         logger=logger,
         checkpoint_callback=ModelCheckpoint(tmpdir),
+        default_root_dir=tmpdir,
     )
     model = EvalModelTemplate(**hparams)
 
@@ -220,6 +222,7 @@ def test_running_test_no_val(tmpdir):
         checkpoint_callback=checkpoint,
         logger=logger,
         early_stop_callback=False,
+        default_root_dir=tmpdir,
     )
     result = trainer.fit(model)

tests/models/test_gpu.py

Lines changed: 8 additions & 12 deletions
@@ -38,18 +38,16 @@ def test_multi_gpu_model(tmpdir, backend):
     """Make sure DDP works."""
     tutils.set_random_master_port()
 
-    trainer_options = dict(
+    model = EvalModelTemplate()
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=[0, 1],
         distributed_backend=backend,
     )
-
-    model = EvalModelTemplate()
-    # tutils.run_model_test(trainer_options, model)
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
     assert result
 
@@ -63,7 +61,11 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
     """Make sure DDP works with dataloaders passed to fit()"""
     tutils.set_random_master_port()
 
-    trainer_options = dict(
+    model = EvalModelTemplate()
+    fit_options = dict(train_dataloader=model.train_dataloader(),
+                       val_dataloaders=model.val_dataloader())
+
+    trainer = Trainer(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
@@ -72,12 +74,6 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
         gpus=[0, 1],
         distributed_backend='ddp'
     )
-
-    model = EvalModelTemplate()
-    fit_options = dict(train_dataloader=model.train_dataloader(),
-                       val_dataloaders=model.val_dataloader())
-
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model, **fit_options)
     assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
 

tests/models/test_grad_norm.py

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
         logger=logger,
         track_grad_norm=norm_type,
         row_log_interval=1,  # request grad_norms every batch
+        default_root_dir=tmpdir,
     )
     result = trainer.fit(model)

tests/models/test_hooks.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@
 
 
 @pytest.mark.parametrize('max_steps', [1, 2, 3])
-def test_on_before_zero_grad_called(max_steps):
+def test_on_before_zero_grad_called(tmpdir, max_steps):
 
     class CurrentTestModel(EvalModelTemplate):
         on_before_zero_grad_called = 0
@@ -21,6 +21,7 @@ def on_before_zero_grad(self, optimizer):
     trainer = Trainer(
         max_steps=max_steps,
         num_sanity_val_steps=5,
+        default_root_dir=tmpdir,
     )
     assert 0 == model.on_before_zero_grad_called
     trainer.fit(model)
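As a quick check that the redirect takes effect, the Trainer keeps the value it was given, so a test can assert nothing points at the current working directory. A minimal sketch, assuming the Trainer of this era exposes the argument as a default_root_dir attribute (the test name is hypothetical):

    from pytorch_lightning import Trainer

    def test_root_dir_points_at_tmpdir(tmpdir):
        trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
        # artifacts resolve relative to default_root_dir, not the CWD
        assert str(trainer.default_root_dir) == str(tmpdir)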
