Merged
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.3.2] - 2021-05-18

### Changed

- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))

### Fixed

- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
- Fixed recursive passing of `wrong_type` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
- Fixed setting correct `DistribType` for `ddp_cpu` (spawn) backend ([#7492](https://github.com/PyTorchLightning/pytorch-lightning/pull/7492))
- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))


## [1.3.1] - 2021-05-11

### Fixed
28 changes: 16 additions & 12 deletions docs/source/extensions/datamodules.rst
@@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)


.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.


---------------

LightningDataModule API
@@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
def setup(self, stage: Optional[str] = None):

# Assign Train/val split(s) for use in Dataloaders
if stage == 'fit' or stage is None:
if stage in (None, 'fit'):
mnist_full = MNIST(
self.data_dir,
train=True,
@@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = self.mnist_train[0][0].shape

# Assign Test split(s) for use in Dataloaders
if stage == 'test' or stage is None:
if stage in (None, 'test'):
self.mnist_test = MNIST(
self.data_dir,
train=False,
@@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)


.. warning:: ``setup`` is called from every process. Setting state here is okay.

:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects a ``stage: Optional[str]`` argument.
It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
we assume all stages have been set up.

.. note:: ``setup`` is called from every process. Setting state here is okay.
.. note:: ``teardown`` can be used to clean up the state. It is also called from every process.
.. note::
    A ``{setup,teardown,prepare_data}`` call will only run once for a given stage.
    If the stage is ``None``, we assume all of ``{fit,validate,test}`` have been called. For example, this means that
    any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite
    ``dm._has_setup_fit = False``.
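
A minimal sketch of this behaviour, reusing the ``MNISTDataModule`` defined above:

.. code-block:: python

    dm = MNISTDataModule()
    dm.setup('fit')    # runs the 'fit' setup logic
    dm.setup('fit')    # no-op: 'fit' was already set up on this instance
    dm._has_setup_fit = False
    dm.setup('fit')    # runs the 'fit' setup logic again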


train_dataloader
@@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer.fit(model, dm)

trainer.test(datamodule=dm)

If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
still ensures the method runs on the correct devices)
If you need information from the dataset to build your model, then run
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
these methods run on the correct devices).

.. code-block:: python

@@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)

----------------

Datamodules without Lightning
DataModules without Lightning
-----------------------------
You can of course use DataModules in plain PyTorch code as well.

2 changes: 0 additions & 2 deletions docs/source/starter/introduction_guide.rst
@@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
1. use ``prepare_data()`` to download and process the dataset.
2. use ``setup()`` to do splits, and build your model internals

|

An alternative to using a DataModule is to defer initialization of the model's modules to the ``setup`` method of your LightningModule as follows:

.. testcode::
4 changes: 2 additions & 2 deletions docs/source/starter/new-project.rst
@@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
transforms.Normalize((0.1307,), (0.3081,))
])
# split dataset
if stage == 'fit':
if stage in (None, 'fit'):
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
if stage == 'test':
if stage in (None, 'test'):
self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

# return the dataloader for each split
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = '1.3.1'
__version__ = '1.3.2'
__author__ = 'William Falcon et al.'
__author_email__ = '[email protected]'
__license__ = 'Apache-2.0'
14 changes: 11 additions & 3 deletions pytorch_lightning/core/datamodule.py
@@ -355,6 +355,7 @@ def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable
@functools.wraps(fn)
def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
name = fn.__name__
has_run = False

# If calling setup, we check the stage and assign stage-specific bool args
if name in ("setup", "teardown"):
@@ -366,15 +367,22 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
stage = args[0] if len(args) else kwargs.get("stage", None)

if stage is None:
has_run = True
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_{name}_{s}", True)
attr = f"_has_{name}_{s}"
has_run &= getattr(obj, attr)
setattr(obj, attr, True)
else:
setattr(obj, f"_has_{name}_{stage}", True)
attr = f"_has_{name}_{stage}"
has_run = getattr(obj, attr)
setattr(obj, attr, True)

elif name == "prepare_data":
has_run = obj._has_prepared_data
obj._has_prepared_data = True

return fn(*args, **kwargs)
if not has_run:
return fn(*args, **kwargs)

return wrapped_fn
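
A small illustration of the flag semantics this wrapper implements (a sketch, not part of the patch; DemoDataModule below is a made-up minimal DataModule):

from pytorch_lightning import LightningDataModule

class DemoDataModule(LightningDataModule):
    # made-up minimal DataModule used only to show the call-tracking behaviour
    def setup(self, stage=None):
        print(f"setup ran with stage={stage}")

dm = DemoDataModule()
dm.setup()              # stage=None: runs once, then marks fit/validate/test as set up
dm.setup('validate')    # no-op: the 'validate' flag was already set by the call above
dm.setup()              # no-op: all three stage flags are already True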

2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -394,7 +394,7 @@ def prepare_data(self):

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
Called at the beginning of fit (train + validate), validate, test, and predict.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -522,7 +522,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):

# special case with DDP on CPUs
if self.distributed_backend == "ddp_cpu":
self._distrib_type = DistributedType.DDP
self._distrib_type = DistributedType.DDP_SPAWN
if self.num_gpus > 0:
rank_zero_warn(
'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
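
For context, a brief sketch of the configuration this one-line change affects; ddp_cpu is spawn-based, so it should be tracked as DDP_SPAWN:

from pytorch_lightning import Trainer

# ddp_cpu launches num_processes workers on CPU via spawn; with this fix the
# connector records the distributed type as DDP_SPAWN rather than DDP
trainer = Trainer(accelerator="ddp_cpu", num_processes=2)
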
65 changes: 36 additions & 29 deletions pytorch_lightning/trainer/supporters.py
@@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]:
def min_len(self) -> Union[int, float]:
return self._calc_num_data(self.datasets, 'min_size')

@staticmethod
def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
"""
Compute the length of `CombinedDataset` according to the `mode`.

@@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
raise MisconfigurationException(f"Invalid Mode: {mode}")

# extract the lengths
all_lengths = apply_to_collection(
datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping)
)
all_lengths = self._get_len_recursive(datasets)

compute_func = CombinedDataset.COMPUTE_FUNCS[mode]

@@ -294,6 +291,30 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,

return length

def _get_len_recursive(self, data) -> int:
if isinstance(data, Dataset):
return len(data)

elif isinstance(data, (float, int)):
return data

elif isinstance(data, Mapping):
if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()):
return {k: self._get_len_recursive(v) for k, v in data.items()}
elif isinstance(data, Sequence):
data = list(data)
if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data):
return [self._get_len_recursive(v) for v in data]

return self._get_len(data)

@staticmethod
def _get_len(dataset) -> int:
try:
return len(dataset)
except (TypeError, NotImplementedError):
return float('inf')

def __len__(self) -> int:
"""Return the minimum length of the datasets."""
return self._calc_num_data(self.datasets, self.mode)
@@ -335,6 +356,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.

"""
if mode not in self.SUPPORTED_MODES:
raise MisconfigurationException(f"Invalid Mode: {mode}")

self.loaders = loaders

datasets = apply_to_collection(
@@ -343,9 +367,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
# could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
self.dataset = CombinedDataset(datasets, mode)

if mode not in self.SUPPORTED_MODES:
raise MisconfigurationException(f"Invalid Mode: {mode}")

self.mode = mode

if self.mode == 'max_size_cycle':
@@ -366,27 +387,13 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
"""
all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

if isinstance(all_lengths, (int, float)):
length = all_lengths

elif isinstance(all_lengths, Mapping):
length = max(all_lengths.values())
length = _nested_calc_num_data(all_lengths, max)

elif isinstance(all_lengths, Sequence):
length = max(all_lengths)

if isinstance(self.loaders, Mapping):
self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()})

elif isinstance(self.loaders, Sequence):
self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders])

# dataloaders are iterable but not sequence
elif isinstance(self.loaders, Iterable):
# only one dataloader, just keep it the same.
pass
else:
raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}')
# multiple loaders
if isinstance(self.loaders, (Sequence, Mapping)):
self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)
)

def __iter__(self) -> Any:
"""
@@ -490,7 +497,7 @@ def create_loader_iters(

def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):

if isinstance(data, int):
if isinstance(data, (float, int)):
return data

if isinstance(data, Mapping):
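
To illustrate what these helpers implement, a minimal sketch of CombinedLoader in 'max_size_cycle' mode (the tensor sizes are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader

# two loaders of different lengths; in 'max_size_cycle' mode the shorter one is
# wrapped in a CycleIterator so both keep yielding until the longest is exhausted
loaders = {
    "a": DataLoader(TensorDataset(torch.randn(100, 3)), batch_size=10),  # 10 batches
    "b": DataLoader(TensorDataset(torch.randn(40, 3)), batch_size=10),   # 4 batches, cycled
}
combined = CombinedLoader(loaders, "max_size_cycle")
for batch in combined:
    # each batch is a dict with the same keys as `loaders`, e.g. {"a": ..., "b": ...}
    ...
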
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -1156,10 +1156,7 @@ def call_setup_hook(self, model: LightningModule) -> None:
self.accelerator.barrier("pre_setup")

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_setup_{fn}')
if not called:
self.datamodule.setup(stage=fn)

self.datamodule.setup(stage=fn)
self.setup(model, stage=fn)
model.setup(stage=fn)

@@ -1182,10 +1179,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
fn = self.state.fn._setup_fn

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_teardown_{fn}')
if not called:
self.datamodule.teardown(stage=fn)

self.datamodule.teardown(stage=fn)
self.profiler.teardown(stage=fn)
self.teardown(stage=fn)
model.teardown(stage=fn)
4 changes: 1 addition & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -474,7 +474,6 @@ def run_training_epoch(self):

train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
val_loop_called = False

batch_idx = None
is_last_batch = None
@@ -516,7 +515,6 @@ def run_training_epoch(self):
self.trainer.validating = True
self.trainer.run_evaluation()
self.trainer.training = True
val_loop_called = True

# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -565,7 +563,7 @@ def run_training_epoch(self):
should_train_only = self.trainer.disable_validation or should_skip_eval

# update epoch level lr_schedulers if no val loop outside train loop is triggered
if (val_loop_called and not should_check_val) or should_train_only:
if not should_check_val or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

if should_train_only:
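
A sketch of the setup the changelog entry #7032 refers to (the configure_optimizers below belongs to a hypothetical LightningModule):

import torch
from pytorch_lightning import Trainer

def configure_optimizers(self):
    # epoch-interval scheduler inside a hypothetical LightningModule
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}

# with check_val_every_n_epoch=2, validation is skipped on every other epoch; this fix
# ensures the epoch-interval scheduler is still stepped on the epochs without validation
trainer = Trainer(max_epochs=4, check_val_every_n_epoch=2)
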
13 changes: 10 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
@@ -85,13 +85,20 @@ def apply_to_collection(

# Recursively apply to collection items
if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
return elem_type({
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
for k, v in data.items()
})

if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
return elem_type(
*(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data)
)

if isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
return elem_type([
apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data
])

# data is neither of dtype, nor a collection
return data
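
A minimal sketch of what propagating wrong_dtype changes at nested levels (the data below is made up):

from collections.abc import Sequence
from pytorch_lightning.utilities.apply_func import apply_to_collection

# apply len() to every Sequence except strings; before this fix, wrong_dtype was
# dropped when recursing, so the nested "name" string would be passed to len()
data = {"outer": [1, 2, 3], "nested": {"inner": [4, 5], "name": "hello"}}
result = apply_to_collection(data, Sequence, len, wrong_dtype=str)
# with the fix: {"outer": 3, "nested": {"inner": 2, "name": "hello"}}
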
8 changes: 5 additions & 3 deletions tests/accelerators/test_accelerator_connector.py
@@ -437,13 +437,15 @@ def test_ipython_incompatible_backend_error(*_):
with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
Trainer(accelerator="ddp", gpus=2)

with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
Trainer(accelerator="ddp_cpu", num_processes=2)

with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"):
Trainer(accelerator="ddp2", gpus=2)


@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
def test_ipython_compatible_backend(*_):
Trainer(accelerator="ddp_cpu", num_processes=2)


@pytest.mark.parametrize(
["accelerator", "plugin"],
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],