Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)


- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def on_train_start(self):
# provide rank to profiler
self.trainer.profile_connector.on_train_start(self.trainer)

def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)
Expand Down
13 changes: 11 additions & 2 deletions pytorch_lightning/tuner/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
self.trainer.auto_lr_find = auto_lr_find
self.trainer.auto_scale_batch_size = auto_scale_batch_size

def tune(self, model, train_dataloader, val_dataloaders, datamodule):
def setup_trainer(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: LightningDataModule = None,
):
self.trainer.model_connector.copy_trainer_model_properties(model)
# setup data, etc...
self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

# hook
self.trainer.data_connector.prepare_data(model)

def tune(self, model, train_dataloader, val_dataloaders, datamodule):
# Run auto batch size scaling
if self.trainer.auto_scale_batch_size:
if isinstance(self.trainer.auto_scale_batch_size, bool):
Expand Down Expand Up @@ -104,6 +111,7 @@ def scale_batch_size(
or datamodule.

"""
self.setup_trainer(model, **fit_kwargs)
return scale_batch_size(
self.trainer,
model,
Expand All @@ -128,6 +136,7 @@ def lr_find(
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
return lr_find(
self.trainer,
model,
Expand Down
65 changes: 65 additions & 0 deletions tests/tuner/test_scale_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class BatchSizeDataModule(BoringDataModule):

def __init__(self, batch_size=None):
super().__init__()
if batch_size is not None:
self.batch_size = batch_size

def train_dataloader(self):
return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))


class BatchSizeModel(BoringModel):

def __init__(self, batch_size=None):
super().__init__()
if batch_size is not None:
self.batch_size = batch_size


@pytest.mark.parametrize(
"model,datamodule", [
(BatchSizeModel(2), None),
(BatchSizeModel(2), BatchSizeDataModule(2)),
(BatchSizeModel(2), BatchSizeDataModule(None)),
(BatchSizeModel(None), BatchSizeDataModule(2)),
]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
""" Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=0,
max_epochs=1,
)
tuner = Tuner(trainer)
new_batch_size = tuner.scale_batch_size(
model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
)
assert new_batch_size == 16
if hasattr(model, "batch_size"):
assert model.batch_size == 16
if datamodule is not None and hasattr(datamodule, "batch_size"):
assert datamodule.batch_size == 16