@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """LightningDataModule for loading DataLoaders with ease."""
+import inspect
 from argparse import ArgumentParser, Namespace
 from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union

@@ -109,19 +110,22 @@ def from_datasets(
         predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
         batch_size: int = 1,
         num_workers: int = 0,
+        **datamodule_kwargs: Any,
     ):
         r"""
         Create an instance from torch.utils.data.Dataset.

         Args:
-            train_dataset: (optional) Dataset to be used for train_dataloader()
-            val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
-            test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
-            predict_dataset: (optional) Dataset or list of Dataset to be used for predict_dataloader()
-            batch_size: Batch size to use for each dataloader. Default is 1.
+            train_dataset: Optional dataset to be used for train_dataloader()
+            val_dataset: Optional dataset or list of Dataset to be used for val_dataloader()
+            test_dataset: Optional dataset or list of Dataset to be used for test_dataloader()
+            predict_dataset: Optional dataset or list of Dataset to be used for predict_dataloader()
+            batch_size: Batch size to use for each dataloader. Default is 1. This parameter is forwarded to
+                the datamodule's ``__init__`` if its signature accepts a parameter of this name.
             num_workers: Number of subprocesses to use for data loading. 0 means that the
-                data will be loaded in the main process. Number of CPUs available.
-
+                data will be loaded in the main process. This parameter is forwarded to
+                the datamodule's ``__init__`` if its signature accepts a parameter of this name.
+            **datamodule_kwargs: Additional parameters that get passed down to the datamodule's ``__init__``.
         """

         def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
@@ -150,7 +154,17 @@ def predict_dataloader():
                 return [dataloader(ds) for ds in predict_dataset]
             return dataloader(predict_dataset)

-        datamodule = cls()
+        candidate_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
+        accepted_params = inspect.signature(cls.__init__).parameters
+        accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
+        if accepts_kwargs:
+            special_kwargs = candidate_kwargs
+        else:
+            accepted_params = set(accepted_params)
+            accepted_params.discard("self")
+            special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_params}
+
+        datamodule = cls(**datamodule_kwargs, **special_kwargs)
         if train_dataset is not None:
             datamodule.train_dataloader = train_dataloader
         if val_dataset is not None:
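
A minimal usage sketch of the behaviour this change introduces (the names `RandomDataset`, `MyDataModule`, and `augment` are hypothetical, not part of Lightning): `from_datasets` now forwards `batch_size` and `num_workers` to the datamodule's `__init__` only when its signature accepts them, and passes any extra keyword arguments through `**datamodule_kwargs`.

```python
import pytorch_lightning as pl
from torch.utils.data import Dataset


class RandomDataset(Dataset):
    """Toy dataset so from_datasets() has something to wrap."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


class MyDataModule(pl.LightningDataModule):
    # `batch_size` appears in the signature, so from_datasets() forwards it;
    # `num_workers` does not, so the signature check filters it out.
    def __init__(self, batch_size: int = 1, augment: bool = False):
        super().__init__()
        self.batch_size = batch_size
        self.augment = augment


dm = MyDataModule.from_datasets(
    train_dataset=RandomDataset(),
    batch_size=4,    # forwarded to __init__ (name matches the signature)
    num_workers=2,   # not forwarded to __init__, still used for the DataLoaders
    augment=True,    # collected by **datamodule_kwargs and passed through
)
assert dm.batch_size == 4 and dm.augment is True
```

Note the `VAR_KEYWORD` branch: a datamodule whose `__init__` declares `**kwargs` receives both `batch_size` and `num_workers` unfiltered, since `accepts_kwargs` skips the signature check entirely.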
|