
Commit d0a29b5

Fix dataloaders are not reset when tuning the model (#7566)

awaelchli and carmocca committed
Co-authored-by: Carlos Mocholi <[email protected]>

1 parent 84cfcbf

3 files changed: +37 -21 lines


CHANGELOG.md (1 addition, 0 deletions)

```diff
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
 
 
 ## [1.3.2] - 2021-05-18
```
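For context, the behavior this changelog entry describes is user-visible through the batch-size tuner: the tuner rewrites the model's `batch_size` attribute, but before this commit the trainer could keep serving batches of the old size. Below is a minimal sketch assuming the 1.3-era `Tuner` API; `TunableModel` and `RandomDataset` are illustrative stand-ins modeled on the test helpers, not part of the commit:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.tuner.tuning import Tuner


class RandomDataset(Dataset):
    """64 random 32-dim samples, mirroring the test helpers."""

    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class TunableModel(LightningModule):

    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner rewrites
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # rebuilt on every reset, so it picks up the current batch_size
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


trainer = Trainer(limit_train_batches=1, limit_val_batches=0, max_epochs=1)
new_size = Tuner(trainer).scale_batch_size(TunableModel(), mode="binsearch", init_val=4, max_trials=2)
# With this commit, the trainer's train dataloader is reset whenever the
# tuned size changes, so it reflects the final value.
assert trainer.train_dataloader.loaders.batch_size == new_size
```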

pytorch_lightning/tuner/batch_size_scaling.py (8 additions, 2 deletions)

```diff
@@ -160,7 +160,10 @@ def _run_power_scaling(
             else:
                 raise  # some other error not memory related
 
-        if not changed:
+        if changed:
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+        else:
             break
     return new_size
 
@@ -192,7 +195,10 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
 
-            if not changed:
+            if changed:
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+            else:
                 break
 
         except RuntimeError as exception:
```
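To see where the new reset call sits in the control flow, here is a condensed paraphrase of `_run_power_scaling` after this commit. It is a sketch reconstructed from the hunk and its surrounding context; the helpers (`_adjust_batch_size`, `is_oom_error`, `garbage_collection_cuda`) are real module-level names in this file's imports, but the exact body may differ from the shipped source:

```python
from pytorch_lightning.tuner.batch_size_scaling import _adjust_batch_size
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error


def _run_power_scaling_sketch(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
    """Double the batch size until OOM, resetting the dataloader after each change."""
    for _ in range(max_trials):
        garbage_collection_cuda()
        try:
            trainer.fit(model, **fit_kwargs)
            # fit succeeded: double the batch size and try again
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            if not is_oom_error(exception):
                raise  # some other error not memory related
            # ran out of memory: back off to the last working size and stop
            garbage_collection_cuda()
            new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
            break
        if changed:
            # Force the train dataloader to reset as the batch size has changed
            trainer.reset_train_dataloader(model)
        else:
            break
    return new_size
```

The binsearch variant receives the same treatment: previously both loops only checked `if not changed: break`, so a changed batch size silently left the existing dataloader in place.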

tests/tuner/test_scale_batch_size.py (28 additions, 19 deletions)

```diff
@@ -24,14 +24,14 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf
 
 
 class BatchSizeDataModule(BoringDataModule):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
@@ -42,21 +42,23 @@ def train_dataloader(self):
 
 class BatchSizeModel(BoringModel):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
 
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+        if dm_bs == -1:
+            # datamodule batch size takes precedence
+            assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
 
 
 def test_model_reset_correctly(tmpdir):
```
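A side note on the `.loaders` indirection in the new assertions: in this era of Lightning, `trainer.train_dataloader` is a `CombinedLoader` wrapping the user's `DataLoader`, so the raw loader and its `batch_size` sit one attribute deeper. A minimal illustration, assuming the 1.3-era `CombinedLoader`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader

loader = DataLoader(TensorDataset(torch.zeros(8, 2)), batch_size=4)
combined = CombinedLoader(loader)
# same access path the test asserts through
assert combined.loaders.batch_size == 4
```

The five parametrized cases use `-1` as a sentinel for "no datamodule at all", so the test covers batch size living on the model only, on the datamodule only, on both, and on a datamodule without the attribute.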
