Skip to content

Commit b3ad05c

Browse files
awaelchliotajcarmocca
authored andcommitted
Forward extra keyword arguments in LightningDataModule.from_datasets (#14185)
Co-authored-by: otaj <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent cec79c8 commit b3ad05c

File tree

3 files changed

+72
-9
lines changed

3 files changed

+72
-9
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212
- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
1313

1414

15-
-
15+
- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))
16+
1617

1718

1819
### Changed

src/pytorch_lightning/core/datamodule.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""LightningDataModule for loading DataLoaders with ease."""
15+
import inspect
1516
from argparse import ArgumentParser, Namespace
1617
from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
1718

@@ -109,19 +110,22 @@ def from_datasets(
109110
predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
110111
batch_size: int = 1,
111112
num_workers: int = 0,
113+
**datamodule_kwargs: Any,
112114
):
113115
r"""
114116
Create an instance from torch.utils.data.Dataset.
115117
116118
Args:
117-
train_dataset: (optional) Dataset to be used for train_dataloader()
118-
val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
119-
test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
120-
predict_dataset: (optional) Dataset or list of Dataset to be used for predict_dataloader()
121-
batch_size: Batch size to use for each dataloader. Default is 1.
119+
train_dataset: Optional dataset to be used for train_dataloader()
120+
val_dataset: Optional dataset or list of Dataset to be used for val_dataloader()
121+
test_dataset: Optional dataset or list of Dataset to be used for test_dataloader()
122+
predict_dataset: Optional dataset or list of Dataset to be used for predict_dataloader()
123+
batch_size: Batch size to use for each dataloader. Default is 1. This parameter gets forwarded to the
124+
``__init__`` if the datamodule has such a name defined in its signature.
122125
num_workers: Number of subprocesses to use for data loading. 0 means that the
123-
data will be loaded in the main process. Number of CPUs available.
124-
126+
data will be loaded in the main process. Number of CPUs available. This parameter gets forwarded to the
127+
``__init__`` if the datamodule has such a name defined in its signature.
128+
**datamodule_kwargs: Additional parameters that get passed down to the datamodule's ``__init__``.
125129
"""
126130

127131
def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
@@ -150,7 +154,17 @@ def predict_dataloader():
150154
return [dataloader(ds) for ds in predict_dataset]
151155
return dataloader(predict_dataset)
152156

153-
datamodule = cls()
157+
candidate_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
158+
accepted_params = inspect.signature(cls.__init__).parameters
159+
accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
160+
if accepts_kwargs:
161+
special_kwargs = candidate_kwargs
162+
else:
163+
accepted_params = set(accepted_params)
164+
accepted_params.discard("self")
165+
special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_params}
166+
167+
datamodule = cls(**datamodule_kwargs, **special_kwargs)
154168
if train_dataset is not None:
155169
datamodule.train_dataloader = train_dataloader
156170
if val_dataset is not None:

tests/tests_pytorch/core/test_datamodules.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,54 @@ def test_dm_init_from_datasets_dataloaders(iterable):
366366
)
367367

368368

369+
def test_dm_init_from_datasets_with_init_params():
370+
"""Test that extra kwargs can be passed down to the init via the ``LightningDataModule.from_datasets`` method.
371+
372+
The two special arguments batch_size and num_workers get passed down depending on whether the __init__ accepts them.
373+
"""
374+
# No additional parameters
375+
LightningDataModule.from_datasets(DummyDS(), batch_size=4, num_workers=2)
376+
377+
class KnownExtraParametersDataModule(LightningDataModule):
378+
def __init__(self, batch_size=1, num_workers=0):
379+
super().__init__()
380+
self.batch_size = batch_size
381+
self.num_workers = num_workers
382+
383+
# batch_size and num_workers get special treatment - they are part of the `from_datasets` signature
384+
dm = KnownExtraParametersDataModule.from_datasets(DummyDS(), batch_size=4, num_workers=2)
385+
assert dm.batch_size == 4
386+
assert dm.num_workers == 2
387+
388+
class UnknownExtraParametersDataModule(LightningDataModule):
389+
def __init__(self, other, batch_size=1):
390+
super().__init__()
391+
self.other = other
392+
self.batch_size = batch_size
393+
394+
# additional parameter `other` gets forwarded, alongside the special `batch_size` parameter
395+
dm = UnknownExtraParametersDataModule.from_datasets(DummyDS(), batch_size=4, num_workers=2, other=5)
396+
assert dm.batch_size == 4
397+
assert dm.other == 5
398+
399+
# positional arguments raise an error as they would when instantiating the datamodule normally
400+
with pytest.raises(TypeError, match="missing 1 required positional argument: 'other'"):
401+
UnknownExtraParametersDataModule.from_datasets(DummyDS(), batch_size=4, num_workers=2)
402+
403+
class KwargsParametersDataModule(LightningDataModule):
404+
def __init__(self, num_workers, **kwargs):
405+
super().__init__()
406+
self.num_workers = num_workers
407+
for key, value in kwargs.items():
408+
setattr(self, key, value)
409+
410+
# everything gets forwarded, because there is `**kwargs` present
411+
dm = KwargsParametersDataModule.from_datasets(DummyDS(), batch_size=10, num_workers=100, another=None)
412+
assert dm.batch_size == 10
413+
assert dm.num_workers == 100
414+
assert dm.another is None
415+
416+
369417
# all args
370418
class DataModuleWithHparams_0(LightningDataModule):
371419
def __init__(self, arg0, arg1, kwarg0=None):

0 commit comments

Comments
 (0)