
Commit 92c7eec

2/n inter batch parallelism (#9047)

Authored by tchaton and carmocca
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 9fdc0be, commit 92c7eec

File tree

10 files changed (+160 -70 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))


+- Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047))
+
+
 - Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 2 additions & 0 deletions
@@ -280,6 +280,8 @@ def _training_step(
         training_step_output = self.trainer.accelerator.training_step(step_kwargs)
         self.trainer.accelerator.post_training_step()

+        del step_kwargs
+
         training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

         _check_training_step_output(self.trainer.lightning_module, training_step_output)

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 7 additions & 5 deletions
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, Iterator, List, Optional, Sequence, Union

 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader

 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -98,10 +97,13 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
     def advance(self, *args: Any, **kwargs: Any) -> None:
         """Performs evaluation on one single dataloader"""
         void(*args, **kwargs)
+
         dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
-        data_fetcher = DataFetcher()
-        data_fetcher.setup(dataloader)
-        dataloader_iter = enumerate(data_fetcher)
+        dataloader = self.trainer.data_connector.get_profiled_dataloader(
+            dataloader, dataloader_idx=self.current_dataloader_idx
+        )
+        dataloader_iter = iter(dataloader)
+
         dl_max_batches = self._max_batches[self.current_dataloader_idx]

         dl_outputs = self.epoch_loop.run(

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 3 additions & 2 deletions
@@ -91,8 +91,9 @@ def advance(
         if batch is None:
             raise StopIteration

-        with self.trainer.profiler.profile("evaluation_batch_to_device"):
-            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+        if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
+            with self.trainer.profiler.profile("evaluation_batch_to_device"):
+                batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

         self.batch_progress.increment_ready()


pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 4 additions & 5 deletions
@@ -132,14 +132,13 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         else:
             _, (batch, is_last) = next(dataloader_iter)

-        # ------------------------------------
-        # TRAINING_STEP + TRAINING_STEP_END
-        # ------------------------------------
-        # FIXME: Remove with InterBatchProcessor.
-        if not self.trainer.data_connector.data_fetcher.store_on_device:
+        if not self.trainer.data_connector.train_data_fetcher.store_on_device:
             with self.trainer.profiler.profile("training_batch_to_device"):
                 batch = self.trainer.accelerator.batch_to_device(batch)

+        # ------------------------------------
+        # TRAINING_STEP + TRAINING_STEP_END
+        # ------------------------------------
         self.batch_progress.increment_ready()

         with self.trainer.profiler.profile("run_training_batch"):
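
Note on the `store_on_device` guards added here and in evaluation_epoch_loop.py: they let a fetcher that has already placed the batch on the GPU skip the synchronous `batch_to_device` call. For background, "inter batch parallelism" refers to the standard side-stream prefetch pattern sketched below; this sketch is illustrative only (plain PyTorch, not Lightning code, and not part of this diff):

    # Illustrative CUDA-stream prefetch: copy the next batch to the GPU on a
    # separate stream while the current batch is still being processed.
    import torch

    copy_stream = torch.cuda.Stream()

    def prefetch_to_gpu(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
        # non_blocking copies from pinned memory can overlap with compute
        with torch.cuda.stream(copy_stream):
            return batch.pin_memory().to(device, non_blocking=True)

    def wait_for_prefetch() -> None:
        # before consuming the prefetched batch, make the default stream wait
        # for the copy stream (roughly what the fetcher's `wait()` hook is for)
        torch.cuda.current_stream().wait_stream(copy_stream)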

pytorch_lightning/loops/fit_loop.py

Lines changed: 5 additions & 4 deletions
@@ -14,7 +14,7 @@

 import logging
 from contextlib import suppress
-from typing import Optional
+from typing import Iterator, Optional

 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop

@@ -192,12 +192,13 @@ def on_advance_start(self) -> None:

     def advance(self) -> None:
         """Runs one whole epoch."""
-        train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
-        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
+        dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
+        dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader)
+        dataloader_iter = iter(dataloader)

         with self.trainer.profiler.profile("run_training_epoch"):
             # run train epoch
-            epoch_output = self.epoch_loop.run(train_dataloader)
+            epoch_output = self.epoch_loop.run(dataloader_iter)

             if epoch_output is None:
                 return

pytorch_lightning/loops/processors/iterator_batch_processor.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]:
         Args:
             dataloader_iter: the iterator over the dataloader producing the new batch
         """
-        _, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter)
+        batch_idx, (dataloader_iter, is_last) = next(dataloader_iter)

         self.trainer.logger_connector.on_batch_start()
         response = self.trainer.call_hook("on_batch_start")

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 65 additions & 13 deletions
@@ -11,22 +11,47 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
+from functools import partial
 from typing import Callable, Iterable, Optional, Union

 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, InterBatchParallelDataFetcher
+from pytorch_lightning.utilities.fetching import (
+    AbstractDataFetcher,
+    DataFetcher,
+    DataLoaderIterDataFetcher,
+    InterBatchParallelDataFetcher,
+)
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from pytorch_lightning.utilities.warnings import rank_zero_warn


 class DataConnector:
-    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
+    def __init__(
+        self,
+        trainer: "pl.Trainer",
+        multiple_trainloader_mode: str = "max_size_cycle",
+        train_data_fetcher: Optional[AbstractDataFetcher] = None,
+        validate_data_fetcher: Optional[AbstractDataFetcher] = None,
+        test_data_fetcher: Optional[AbstractDataFetcher] = None,
+    ):
         self.trainer = trainer
         self.multiple_trainloader_mode = multiple_trainloader_mode
-        self.data_fetcher: AbstractDataFetcher = DataFetcher()
+
+        self.train_data_fetcher = train_data_fetcher
+        self.validate_data_fetcher = validate_data_fetcher
+        self.test_data_fetcher = test_data_fetcher
+        self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None
+
+    @property
+    def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
+        if self.trainer.sanity_checking:
+            return self.sanity_check_data_fetcher
+        return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher

     def on_trainer_init(
         self,

@@ -66,15 +91,42 @@ def on_trainer_init(
         self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
         self.trainer._is_data_prepared = False

-    def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
-        # FIXME: Temporary hack
-        if isinstance(self.data_fetcher, InterBatchParallelDataFetcher):
-            self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device)
-        else:
-            self.data_fetcher.setup(train_dataloader)
-        prefetcher_iter = iter(self.data_fetcher)
-        profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
-        return profiled_dl
+    def _check_training_step_requires_dataloader_iter(self) -> bool:
+        training_step_fx = getattr(self.trainer.lightning_module, "training_step")
+        contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True)
+        return contains_dataloader_iter
+
+    def _select_data_fetcher(self) -> AbstractDataFetcher:
+        if self.trainer.sanity_checking:
+            return DataFetcher()
+
+        if self.trainer.training and self._check_training_step_requires_dataloader_iter():
+            rank_zero_warn(
+                "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
+                "this signature is experimental and the behavior is subject to change."
+            )
+            return DataLoaderIterDataFetcher()
+        elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
+            # note: this is an experimental feature
+            if not self.trainer.training_type_plugin.on_gpu:
+                raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
+            return InterBatchParallelDataFetcher()
+
+        return DataFetcher()
+
+    def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable:
+        stage: str = self.trainer.state.stage.value
+        data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
+        data_fetcher.setup(
+            dataloader,
+            stage=stage,
+            batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx),
+            profiler=self.trainer.profiler,
+        )
+        setattr(self, f"{stage}_data_fetcher", data_fetcher)
+        if isinstance(data_fetcher, DataLoaderIterDataFetcher):
+            return data_fetcher
+        return enumerate(data_fetcher)

     def prepare_data(self) -> None:
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
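
Two experimental selection paths are visible in `_select_data_fetcher` above. A minimal sketch of how they are triggered from user code, based only on this diff (the module name and body below are illustrative placeholders, and both paths are marked experimental in the code itself):

    import os
    from typing import Iterator

    import pytorch_lightning as pl


    class IterModel(pl.LightningModule):
        # Declaring `dataloader_iter` explicitly is what
        # `_check_training_step_requires_dataloader_iter` looks for; the
        # DataConnector then selects a `DataLoaderIterDataFetcher`.
        def training_step(self, dataloader_iter: Iterator, batch_idx: int):
            # the step drives the iterator itself; what `next()` yields is
            # defined by `StepFuncDataLoaderIter` in fetching.py (below)
            batch = next(dataloader_iter)
            ...


    # Second path: opt in to `InterBatchParallelDataFetcher` through the
    # environment variable checked in `_select_data_fetcher` (GPU runs only).
    os.environ["PL_INTER_BATCH_PARALLELISM"] = "1"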

pytorch_lightning/utilities/fetching.py

Lines changed: 50 additions & 29 deletions
@@ -47,22 +47,27 @@ class SimpleDataFetcher(AbstractDataFetcher):
         def fetching_function(self):
             while True:
                 try:
-                    yield next(self.dataloader_iter), False
+                    return next(self.dataloader_iter), False
                 except StopIteration:
                     return None, True
     """

     @abstractmethod
-    def fetching_function(self) -> Generator:
+    def fetching_function(self) -> Any:
         """Override with your own fetching logic."""

+    @abstractmethod
+    def prefetching(self, prefetch_batches: int) -> None:
+        """Override with your own pre-fetching logic."""
+
     def __init__(
         self,
         prefetch_batches: int = 0,
     ) -> None:
         if prefetch_batches < 0:
             raise MisconfigurationException("`prefetch_batches` should at least be 0.")

+        self.store_on_device = False
         self.prefetch_batches = prefetch_batches + 1

         self.dataloader: Optional[Iterable] = None

@@ -192,6 +197,10 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
         self.reset()
         self.dataloader_iter = iter(self.dataloader)
         self._apply_patch()
+        self.prefetching(self.prefetch_batches)
+        return self
+
+    def __next__(self):
         return self.fetching_function()

     def reset(self) -> None:

@@ -241,34 +250,38 @@ def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None:
     def wait(self) -> None:
         """Hook to override to indicate the `DataFetcher` to wait for an event."""

-    def fetching_function(self) -> Generator:
-        self.done = False
-        while not self.done:
-            self._prefetching(self.prefetch_batches)
-
-            while self.batches:
-                try:
-                    yield_batch = self.pop_batch()
-                    self._fetch_next_batch()
-
-                    # wait for batch to be available.
-                    self.wait()
-
-                    # yield last and has next
-                    yield (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False)
-                except StopIteration:
-                    self.batches.insert(0, yield_batch)
-                    break
-
-            yield from self._consume_prefetched_batches()
-
-    def _prefetching(self, prefetch_batches: int) -> None:
+    def prefetching(self, prefetch_batches: int) -> None:
         for _ in range(prefetch_batches):
             try:
                 self._fetch_next_batch()
             except StopIteration:
                 break

+    def fetching_function(self) -> Optional[Tuple[Any, bool]]:
+        if self.done:
+            while self.batches:
+                return self._get_queued_batch()
+            raise StopIteration
+        else:
+            try:
+                yield_batch = self.pop_batch()
+                self._fetch_next_batch()
+
+                # wait for batch to be available.
+                self.wait()
+
+                # yield last and has next
+                return yield_batch, False
+                # FIXME: Why does this count as a python `referrers` ?
+                # return (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False)
+            except StopIteration:
+                self.batches.insert(0, yield_batch)
+                self.done = True
+                return self._get_queued_batch()
+
+            except IndexError:
+                raise StopIteration
+
     @contextmanager
     def apply_profiler(self, name: str) -> Generator:
         if self.profiler:

@@ -291,13 +304,13 @@ def _consume_prefetched_batches(self) -> Generator:
         while self.batches:
             yield from self._yield_batch()

-    def _yield_batch(self) -> Generator:
+    def _get_queued_batch(self) -> Tuple[Any, bool]:
         self.wait()
         batch = self.batches.pop(0)
         if not self.store_on_device:
             batch = self.move_data_to_device(batch)
         is_last = len(self.batches) == 0
-        yield batch, is_last
+        return batch, is_last

     def move_data_to_device(self, batch: Any) -> Any:
         if self.batch_to_device:

@@ -406,7 +419,15 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
             ...
     """

-    def fetching_function(self) -> Generator:
-        iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
+    def __init__(self):
+        super().__init__()
+        # prevent calling ``move_batch_to_device``
+        self.store_on_device = True
+
+    def prefetching(self, prefetch_batches: int) -> None:
+        self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
+
+    def fetching_function(self):
         while not self.done:
-            yield iterator, self.fetched, self.done
+            return self.fetched, (self.iterator, self.done)
+        raise StopIteration
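
The net effect of this refactor is that the fetchers stop being generators and become plain iterators: `__iter__` primes the prefetch queue via `prefetching` and returns `self`, and each `__next__` call runs `fetching_function` once, returning `(batch, is_last)`. A self-contained toy sketch of that protocol (illustrative only, not the Lightning classes; device movement, profiling and the `dataloader_iter` variant are omitted):

    from typing import Any, Iterable, Iterator, List, Tuple


    class ToyPrefetcher:
        """Plain-iterator prefetcher mirroring the new __iter__/__next__ flow."""

        def __init__(self, data: Iterable[Any], prefetch_batches: int = 1) -> None:
            # mirrors `prefetch_batches + 1` in the real class: always >= 1
            self.data = data
            self.prefetch_batches = max(1, prefetch_batches)
            self.batches: List[Any] = []
            self.done = False

        def __iter__(self) -> "ToyPrefetcher":
            self.iterator: Iterator[Any] = iter(self.data)
            self.batches.clear()
            self.done = False
            self.prefetching(self.prefetch_batches)
            return self

        def __next__(self) -> Tuple[Any, bool]:
            return self.fetching_function()

        def prefetching(self, prefetch_batches: int) -> None:
            # fill the queue up front, like AbstractDataFetcher.__iter__ now does
            for _ in range(prefetch_batches):
                try:
                    self.batches.append(next(self.iterator))
                except StopIteration:
                    self.done = True
                    break

        def fetching_function(self) -> Tuple[Any, bool]:
            # one call == one batch, instead of a long-lived generator
            if self.done:
                if self.batches:
                    batch = self.batches.pop(0)
                    return batch, len(self.batches) == 0
                raise StopIteration
            batch = self.batches.pop(0)
            try:
                self.batches.append(next(self.iterator))
            except StopIteration:
                self.done = True
            return batch, self.done and not self.batches


    print(list(ToyPrefetcher(range(3))))  # [(0, False), (1, False), (2, True)]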
