Skip to content

Commit dd65024

Browse files
committed
tests: add default_root_dir=tmpdir
1 parent 1d565e1 commit dd65024

File tree

13 files changed

+52
-32
lines changed

13 files changed

+52
-32
lines changed

tests/callbacks/test_callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def on_test_end(self, trainer, pl_module):
166166
limit_val_batches=0.1,
167167
limit_train_batches=0.2,
168168
progress_bar_refresh_rate=0,
169+
default_root_dir=tmpdir,
169170
)
170171

171172
assert not test_callback.setup_called

tests/callbacks/test_progress_bar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_progress_bar_totals(tmpdir):
6666
progress_bar_refresh_rate=1,
6767
limit_val_batches=1.0,
6868
max_epochs=1,
69+
default_root_dir=tmpdir,
6970
)
7071
bar = trainer.progress_bar_callback
7172
assert 0 == bar.total_train_batches
@@ -182,6 +183,7 @@ def on_test_batch_end(self, trainer, pl_module):
182183
limit_train_batches=1.0,
183184
num_sanity_val_steps=2,
184185
max_epochs=3,
186+
default_root_dir=tmpdir,
185187
)
186188
assert trainer.progress_bar_callback.refresh_rate == refresh_rate
187189

tests/loggers/test_all.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def log_metrics(self, metrics, step):
7272
limit_train_batches=0.2,
7373
limit_val_batches=0.5,
7474
fast_dev_run=True,
75+
default_root_dir=tmpdir,
7576
)
7677
trainer.fit(model)
7778
trainer.test()

tests/loggers/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_multiple_loggers(tmpdir):
111111
assert logger2.finalized_status == "success"
112112

113113

114-
def test_multiple_loggers_pickle(tmpdir):
114+
def test_multiple_loggers_pickle():
115115
"""Verify that pickling trainer with multiple loggers works."""
116116

117117
logger1 = CustomLogger()

tests/loggers/test_tensorboard.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
def test_tensorboard_hparams_reload(tmpdir):
1818
model = EvalModelTemplate()
1919

20-
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
20+
trainer = Trainer(
21+
max_epochs=1,
22+
default_root_dir=tmpdir,
23+
)
2124
trainer.fit(model)
2225

2326
folder_path = trainer.logger.log_dir

tests/models/test_cpu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_cpu_slurm_save_load(tmpdir):
3030
limit_train_batches=0.2,
3131
limit_val_batches=0.2,
3232
checkpoint_callback=ModelCheckpoint(tmpdir),
33+
default_root_dir=tmpdir,
3334
)
3435
result = trainer.fit(model)
3536
real_global_step = trainer.global_step
@@ -66,6 +67,7 @@ def test_cpu_slurm_save_load(tmpdir):
6667
max_epochs=1,
6768
logger=logger,
6869
checkpoint_callback=ModelCheckpoint(tmpdir),
70+
default_root_dir=tmpdir,
6971
)
7072
model = EvalModelTemplate(**hparams)
7173

@@ -222,6 +224,7 @@ def test_running_test_no_val(tmpdir):
222224
checkpoint_callback=checkpoint,
223225
logger=logger,
224226
early_stop_callback=False,
227+
default_root_dir=tmpdir,
225228
)
226229
result = trainer.fit(model)
227230

tests/models/test_gpu.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
136136
"""Make sure DDP works with dataloaders passed to fit()"""
137137
tutils.set_random_master_port()
138138

139-
trainer_options = dict(
139+
model = EvalModelTemplate()
140+
fit_options = dict(train_dataloader=model.train_dataloader(),
141+
val_dataloaders=model.val_dataloader())
142+
143+
trainer = Trainer(
140144
default_root_dir=tmpdir,
141145
progress_bar_refresh_rate=0,
142146
max_epochs=1,
@@ -145,12 +149,6 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
145149
gpus=[0, 1],
146150
distributed_backend='ddp_spawn'
147151
)
148-
149-
model = EvalModelTemplate()
150-
fit_options = dict(train_dataloader=model.train_dataloader(),
151-
val_dataloaders=model.val_dataloader())
152-
153-
trainer = Trainer(**trainer_options)
154152
result = trainer.fit(model, **fit_options)
155153
assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
156154

tests/models/test_grad_norm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
8989
logger=logger,
9090
track_grad_norm=norm_type,
9191
row_log_interval=1, # request grad_norms every batch
92+
default_root_dir=tmpdir,
9293
)
9394
result = trainer.fit(model)
9495

tests/models/test_hooks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def on_before_zero_grad(self, optimizer):
2323
max_steps=max_steps,
2424
max_epochs=2,
2525
num_sanity_val_steps=5,
26+
default_root_dir=tmpdir,
2627
)
2728
assert 0 == model.on_before_zero_grad_called
2829
trainer.fit(model)

tests/models/test_horovod.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ def validation_step(self, batch, *args, **kwargs):
147147
def test_horovod_multi_optimizer(tmpdir):
148148
model = TestGAN(**EvalModelTemplate.get_default_hparams())
149149

150-
trainer_options = dict(
150+
# fit model
151+
trainer = Trainer(
151152
default_root_dir=str(tmpdir),
152153
progress_bar_refresh_rate=0,
153154
max_epochs=1,
@@ -156,9 +157,6 @@ def test_horovod_multi_optimizer(tmpdir):
156157
deterministic=True,
157158
distributed_backend='horovod',
158159
)
159-
160-
# fit model
161-
trainer = Trainer(**trainer_options)
162160
result = trainer.fit(model)
163161
assert result == 1, 'model failed to complete'
164162

0 commit comments

Comments
 (0)