11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | - |
| 14 | +import os |
| 15 | +from functools import partial |
15 | 16 | from typing import Callable, Iterable, Optional, Union |
16 | 17 |
17 | 18 | import pytorch_lightning as pl |
18 | 19 | from pytorch_lightning.utilities import rank_zero_deprecation |
19 | 20 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
20 | | -from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, InterBatchParallelDataFetcher |
| 21 | +from pytorch_lightning.utilities.fetching import ( |
| 22 | + AbstractDataFetcher, |
| 23 | + DataFetcher, |
| 24 | + DataLoaderIterDataFetcher, |
| 25 | + InterBatchParallelDataFetcher, |
| 26 | +) |
21 | 27 | from pytorch_lightning.utilities.model_helpers import is_overridden |
| 28 | +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature |
22 | 29 | from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS |
| 30 | +from pytorch_lightning.utilities.warnings import rank_zero_warn |
23 | 31 |
24 | 32 |
25 | 33 | class DataConnector: |
26 | | - def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + trainer: "pl.Trainer", |
| 37 | + multiple_trainloader_mode: str = "max_size_cycle", |
| 38 | + train_data_fetcher: Optional[AbstractDataFetcher] = None, |
| 39 | + validate_data_fetcher: Optional[AbstractDataFetcher] = None, |
| 40 | + test_data_fetcher: Optional[AbstractDataFetcher] = None, |
| 41 | + ): |
27 | 42 | self.trainer = trainer |
28 | 43 | self.multiple_trainloader_mode = multiple_trainloader_mode |
29 | | - self.data_fetcher: AbstractDataFetcher = DataFetcher() |
| 44 | + |
| 45 | + self.train_data_fetcher = train_data_fetcher |
| 46 | + self.validate_data_fetcher = validate_data_fetcher |
| 47 | + self.test_data_fetcher = test_data_fetcher |
| 48 | + self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None |
| 49 | + |
| 50 | + @property |
| 51 | + def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: |
| 52 | + if self.trainer.sanity_checking: |
| 53 | + return self.sanity_check_data_fetcher |
| 54 | + return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher |
30 | 55 |
31 | 56 | def on_trainer_init( |
32 | 57 | self, |
@@ -66,15 +91,42 @@ def on_trainer_init( |
66 | 91 | self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs |
67 | 92 | self.trainer._is_data_prepared = False |
68 | 93 |
69 | | - def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: |
70 | | - # FIXME: Temporary hack |
71 | | - if isinstance(self.data_fetcher, InterBatchParallelDataFetcher): |
72 | | - self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) |
73 | | - else: |
74 | | - self.data_fetcher.setup(train_dataloader) |
75 | | - prefetcher_iter = iter(self.data_fetcher) |
76 | | - profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch") |
77 | | - return profiled_dl |
| 94 | + def _check_training_step_requires_dataloader_iter(self) -> bool: |
| 95 | + training_step_fx = getattr(self.trainer.lightning_module, "training_step") |
| 96 | + contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) |
| 97 | + return contains_dataloader_iter |
| 98 | + |
| 99 | + def _select_data_fetcher(self) -> AbstractDataFetcher: |
| 100 | + if self.trainer.sanity_checking: |
| 101 | + return DataFetcher() |
| 102 | + |
| 103 | + if self.trainer.training and self._check_training_step_requires_dataloader_iter(): |
| 104 | + rank_zero_warn( |
| 105 | + "Found `dataloader_iter` argument in the `training_step`. Note that the support for " |
| 106 | + "this signature is experimental and the behavior is subject to change." |
| 107 | + ) |
| 108 | + return DataLoaderIterDataFetcher() |
| 109 | + elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": |
| 110 | + # note: this is an experimental feature |
| 111 | + if not self.trainer.training_type_plugin.on_gpu: |
| 112 | + raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") |
| 113 | + return InterBatchParallelDataFetcher() |
| 114 | + |
| 115 | + return DataFetcher() |
| 116 | + |
| 117 | + def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: |
| 118 | + stage: str = self.trainer.state.stage.value |
| 119 | +        data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
| 120 | + data_fetcher.setup( |
| 121 | + dataloader, |
| 122 | + stage=stage, |
| 123 | + batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), |
| 124 | + profiler=self.trainer.profiler, |
| 125 | + ) |
| 126 | + setattr(self, f"{stage}_data_fetcher", data_fetcher) |
| 127 | + if isinstance(data_fetcher, DataLoaderIterDataFetcher): |
| 128 | + return data_fetcher |
| 129 | + return enumerate(data_fetcher) |
78 | 130 |
79 | 131 | def prepare_data(self) -> None: |
80 | 132 | # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 |
|