From 7bb88147a4c5823bfed45d93f2e59b4130351059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 12:15:37 +0100 Subject: [PATCH 01/18] fix retrieval of batch indices when dataloader num_workers > 0 --- pl_examples/bug_report/bug_report_model.py | 36 ++++++++----------- .../loops/epoch/prediction_epoch_loop.py | 28 +++++++-------- pytorch_lightning/overrides/distributed.py | 5 +-- 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 7739630237d32..9701682ac16b3 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -4,6 +4,7 @@ from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import BasePredictionWriter class RandomDataset(Dataset): @@ -26,40 +27,33 @@ def __init__(self): def forward(self, x): return self.layer(x) - def training_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("train_loss", loss) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("valid_loss", loss) - - def test_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("test_loss", loss) + def predict_step(self, batch, batch_idx): + return self(batch) def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) +class Writer(BasePredictionWriter): + def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx): + pass + + def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): + print(batch_indices[0]) + + def run(): - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - test_data = DataLoader(RandomDataset(32, 64), batch_size=2) + predict_data = 
DataLoader(RandomDataset(32, 16), batch_size=4, num_workers=2) model = BoringModel() trainer = Trainer( default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, + enable_progress_bar=False, + callbacks=[Writer(write_interval="epoch")], ) - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - trainer.test(model, dataloaders=test_data) + trainer.predict(model, predict_data) if __name__ == "__main__": diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index a5e885efc4b29..9a5965b8f92cb 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -23,10 +23,10 @@ def __init__(self) -> None: self.current_batch_indices: List[int] = [] self.batch_progress = Progress() + self._dataloader_idx = 0 self._dl_max_batches = 0 self._num_dataloaders = 0 self._warning_cache = WarningCache() - self._all_batch_indices: List[int] = [] @property def done(self) -> bool: @@ -44,7 +44,6 @@ def connect(self, **kwargs: "Loop") -> None: def reset(self) -> None: """Resets the loops internal state.""" - self._all_batch_indices = [] self.predictions = [] self.batch_progress.reset_on_run() @@ -66,6 +65,7 @@ def on_run_start( # type: ignore[override] return_predictions: whether to return the obtained predictions """ void(dataloader_iter, dataloader_idx) + self._dataloader_idx = dataloader_idx self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders self.return_predictions = return_predictions @@ -101,11 +101,8 @@ def advance( # type: ignore[override] def on_run_end(self) -> Tuple[List[Any], List[int]]: """Returns the predictions and the corresponding batch indices.""" - predictions = self.predictions - all_batch_indices = self._all_batch_indices - # free memory - self.predictions = [] - 
self._all_batch_indices = [] + predictions, self.predictions = self.predictions, [] # free memory + all_batch_indices = self._get_batch_indices(self._dataloader_idx) return predictions, all_batch_indices def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: @@ -121,7 +118,6 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) # extract batch_indices and store them - self._store_batch_indices(dataloader_idx) model_ref = self.trainer.lightning_module @@ -160,12 +156,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict step_kwargs["dataloader_idx"] = dataloader_idx return step_kwargs - def _store_batch_indices(self, dataloader_idx: int) -> None: - """Stores the batch indices if the predictions should be stored.""" + def _get_batch_indices(self, dataloader_idx: int) -> List[int]: + """Returns all the seen batch indices if the dataloader has a batch sampler wrapped by our + :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`.""" batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler - if isinstance(batch_sampler, IndexBatchSamplerWrapper): - self.current_batch_indices = batch_sampler.batch_indices - if self.should_store_predictions: - self._all_batch_indices.append(batch_sampler.batch_indices) - else: - warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") + if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions: + return batch_sampler.batch_indices + + warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") + return [] diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index f7c2a71b4978d..c2dc359e25ab3 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -124,11 
+124,12 @@ class IndexBatchSamplerWrapper: def __init__(self, sampler: BatchSampler) -> None: self._sampler = sampler - self.batch_indices: Optional[List[int]] = None + self.batch_indices: List[int] = [] def __iter__(self) -> Iterator[List[int]]: + self.batch_indices = [] for batch in self._sampler: - self.batch_indices = batch + self.batch_indices.append(batch) yield batch def __len__(self) -> int: From dd8b084ee9a33b90cc95df0dd5b90dcda393491c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 12:39:11 +0100 Subject: [PATCH 02/18] update --- .../loops/epoch/prediction_epoch_loop.py | 19 +++++++------- pytorch_lightning/overrides/distributed.py | 25 ++++++++++++++++--- tests/overrides/test_distributed.py | 4 +-- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 9a5965b8f92cb..6649ee59d5bc7 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -23,10 +23,10 @@ def __init__(self) -> None: self.current_batch_indices: List[int] = [] self.batch_progress = Progress() - self._dataloader_idx = 0 self._dl_max_batches = 0 self._num_dataloaders = 0 self._warning_cache = WarningCache() + self._all_batch_indices: List[List[int]] = [] @property def done(self) -> bool: @@ -65,9 +65,9 @@ def on_run_start( # type: ignore[override] return_predictions: whether to return the obtained predictions """ void(dataloader_iter, dataloader_idx) - self._dataloader_idx = dataloader_idx self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders + self._all_batch_indices = self._get_batch_indices(dataloader_idx) self.return_predictions = return_predictions def advance( # type: ignore[override] @@ -99,10 +99,10 @@ def advance( # type: ignore[override] with self.trainer.profiler.profile("predict_step"): self._predict_step(batch, 
batch_idx, dataloader_idx) - def on_run_end(self) -> Tuple[List[Any], List[int]]: + def on_run_end(self) -> Tuple[List[Any], List[List[int]]]: """Returns the predictions and the corresponding batch indices.""" - predictions, self.predictions = self.predictions, [] # free memory - all_batch_indices = self._get_batch_indices(self._dataloader_idx) + predictions, all_batch_indices = self.predictions, self._all_batch_indices + self.predictions, self._all_batch_indices = [], [] # free memory return predictions, all_batch_indices def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: @@ -116,11 +116,10 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None """ # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - - # extract batch_indices and store them - model_ref = self.trainer.lightning_module + self.current_batch_indices = self._all_batch_indices[batch_idx] if self._all_batch_indices else [] + self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) self.batch_progress.increment_started() @@ -156,12 +155,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict step_kwargs["dataloader_idx"] = dataloader_idx return step_kwargs - def _get_batch_indices(self, dataloader_idx: int) -> List[int]: + def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]: """Returns all the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`.""" batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions: - return batch_sampler.batch_indices + return batch_sampler.all_batch_indices warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") return [] diff --git a/pytorch_lightning/overrides/distributed.py 
b/pytorch_lightning/overrides/distributed.py index c2dc359e25ab3..7bf526eb284e6 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -21,6 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.utilities import rank_zero_deprecation class LightningDistributedModule(_LightningModuleWrapperBase): @@ -123,13 +124,31 @@ class IndexBatchSamplerWrapper: """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices.""" def __init__(self, sampler: BatchSampler) -> None: + self.all_batch_indices: List[List[int]] = [] self._sampler = sampler - self.batch_indices: List[int] = [] + self._batch_indices: List[int] = [] + + @property + def batch_indices(self) -> List[int]: + rank_zero_deprecation( + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.6 and will be removed in v1.8." + " Access the full list `all_batch_indices` instead." + ) + return self._batch_indices + + @batch_indices.setter + def batch_indices(self, indices: List[int]) -> None: + rank_zero_deprecation( + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.6 and will be removed in v1.8." + " Access the full list `all_batch_indices` instead." 
+ ) + self._batch_indices = indices def __iter__(self) -> Iterator[List[int]]: - self.batch_indices = [] + self.all_batch_indices = [] for batch in self._sampler: - self.batch_indices.append(batch) + self._batch_indices = batch + self.all_batch_indices.append(batch) yield batch def __len__(self) -> int: diff --git a/tests/overrides/test_distributed.py b/tests/overrides/test_distributed.py index c8d982bd733fe..43cf244b66d1e 100644 --- a/tests/overrides/test_distributed.py +++ b/tests/overrides/test_distributed.py @@ -54,9 +54,7 @@ def test_index_batch_sampler(tmpdir): assert batch_sampler.batch_size == index_batch_sampler.batch_size assert batch_sampler.drop_last == index_batch_sampler.drop_last assert batch_sampler.sampler is sampler - - for batch in index_batch_sampler: - assert index_batch_sampler.batch_indices == batch + assert list(index_batch_sampler) == index_batch_sampler.all_batch_indices def test_index_batch_sampler_methods(): From 51f99506610e6169426241cb4b62ea7439141a55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 12:42:34 +0100 Subject: [PATCH 03/18] update names --- .../loops/epoch/prediction_epoch_loop.py | 15 ++++++++------- pytorch_lightning/overrides/distributed.py | 10 +++++----- tests/overrides/test_distributed.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 6649ee59d5bc7..52b8e4f703371 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -26,7 +26,7 @@ def __init__(self) -> None: self._dl_max_batches = 0 self._num_dataloaders = 0 self._warning_cache = WarningCache() - self._all_batch_indices: List[List[int]] = [] + self._seen_batch_indices: List[List[int]] = [] @property def done(self) -> bool: @@ -44,6 +44,7 @@ def connect(self, **kwargs: "Loop") -> None: def reset(self) -> None: 
"""Resets the loops internal state.""" + self._seen_batch_indices = [] self.predictions = [] self.batch_progress.reset_on_run() @@ -67,7 +68,7 @@ def on_run_start( # type: ignore[override] void(dataloader_iter, dataloader_idx) self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - self._all_batch_indices = self._get_batch_indices(dataloader_idx) + self._seen_batch_indices = self._get_batch_indices(dataloader_idx) self.return_predictions = return_predictions def advance( # type: ignore[override] @@ -101,8 +102,8 @@ def advance( # type: ignore[override] def on_run_end(self) -> Tuple[List[Any], List[List[int]]]: """Returns the predictions and the corresponding batch indices.""" - predictions, all_batch_indices = self.predictions, self._all_batch_indices - self.predictions, self._all_batch_indices = [], [] # free memory + predictions, all_batch_indices = self.predictions, self._seen_batch_indices + self.predictions, self._seen_batch_indices = [], [] # free memory return predictions, all_batch_indices def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: @@ -118,7 +119,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) model_ref = self.trainer.lightning_module - self.current_batch_indices = self._all_batch_indices[batch_idx] if self._all_batch_indices else [] + self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else [] self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) @@ -156,11 +157,11 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict return step_kwargs def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]: - """Returns all the seen batch indices if the dataloader has a batch sampler wrapped by our + """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped 
by our :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`.""" batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions: - return batch_sampler.all_batch_indices + return batch_sampler.seen_batch_indices warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") return [] diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 7bf526eb284e6..0b2a952fa8b39 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -124,7 +124,7 @@ class IndexBatchSamplerWrapper: """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices.""" def __init__(self, sampler: BatchSampler) -> None: - self.all_batch_indices: List[List[int]] = [] + self.seen_batch_indices: List[List[int]] = [] self._sampler = sampler self._batch_indices: List[int] = [] @@ -132,7 +132,7 @@ def __init__(self, sampler: BatchSampler) -> None: def batch_indices(self) -> List[int]: rank_zero_deprecation( "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.6 and will be removed in v1.8." - " Access the full list `all_batch_indices` instead." + " Access the full list `seen_batch_indices` instead." ) return self._batch_indices @@ -140,15 +140,15 @@ def batch_indices(self) -> List[int]: def batch_indices(self, indices: List[int]) -> None: rank_zero_deprecation( "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.6 and will be removed in v1.8." - " Access the full list `all_batch_indices` instead." + " Access the full list `seen_batch_indices` instead." 
) self._batch_indices = indices def __iter__(self) -> Iterator[List[int]]: - self.all_batch_indices = [] + self.seen_batch_indices = [] for batch in self._sampler: self._batch_indices = batch - self.all_batch_indices.append(batch) + self.seen_batch_indices.append(batch) yield batch def __len__(self) -> int: diff --git a/tests/overrides/test_distributed.py b/tests/overrides/test_distributed.py index 43cf244b66d1e..e425859fe34df 100644 --- a/tests/overrides/test_distributed.py +++ b/tests/overrides/test_distributed.py @@ -54,7 +54,7 @@ def test_index_batch_sampler(tmpdir): assert batch_sampler.batch_size == index_batch_sampler.batch_size assert batch_sampler.drop_last == index_batch_sampler.drop_last assert batch_sampler.sampler is sampler - assert list(index_batch_sampler) == index_batch_sampler.all_batch_indices + assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices def test_index_batch_sampler_methods(): From 1a47bd2bec876185f5228157a29a0acfd4887d89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 12:49:12 +0100 Subject: [PATCH 04/18] add deprecation test --- tests/deprecated_api/test_remove_1-8.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index d109b5dbfdcaa..11b2c5e7a3b28 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test deprecated functionality which will be removed in v1.8.0.""" +from unittest.mock import Mock + import pytest import torch +from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import DeviceType, DistributedType from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY @@ -40,3 +43,12 @@ def test_v1_8_0_deprecated_torchtext_batch(): data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3) batch = next(iter(data_iterator)) _ = move_data_to_device(batch=batch, device=torch.device("cpu")) + + +def test_v1_8_0_index_batch_sampler_wrapper_batch_indices(): + sampler = IndexBatchSamplerWrapper(Mock()) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + _ = sampler.batch_indices + + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + sampler.batch_indices = [] From da9dbfbab56950049cc33699b594bef15c615155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 12:52:47 +0100 Subject: [PATCH 05/18] update changelog --- CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 406310a5341e4..201526529e384 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,7 +93,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622)) -- +- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870)) + ### Removed @@ -193,6 +194,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839)) +- Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870)) + + + ## [1.5.4] - 2021-11-30 ### Fixed From aa598798018697caa91d438a1d4085cce51fba98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 12:57:25 +0100 Subject: [PATCH 06/18] reset bug report example --- pl_examples/bug_report/bug_report_model.py | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 9701682ac16b3..7739630237d32 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -4,7 +4,6 @@ from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import BasePredictionWriter class RandomDataset(Dataset): @@ -27,33 +26,40 @@ def __init__(self): def forward(self, x): return self.layer(x) - def predict_step(self, batch, batch_idx): - return self(batch) - - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) + def training_step(self, batch, batch_idx): + loss = 
self(batch).sum() + self.log("train_loss", loss) + return {"loss": loss} + def validation_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("valid_loss", loss) -class Writer(BasePredictionWriter): - def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx): - pass + def test_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("test_loss", loss) - def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): - print(batch_indices[0]) + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) def run(): - predict_data = DataLoader(RandomDataset(32, 16), batch_size=4, num_workers=2) + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + val_data = DataLoader(RandomDataset(32, 64), batch_size=2) + test_data = DataLoader(RandomDataset(32, 64), batch_size=2) model = BoringModel() trainer = Trainer( default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, - enable_progress_bar=False, - callbacks=[Writer(write_interval="epoch")], ) - trainer.predict(model, predict_data) + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + trainer.test(model, dataloaders=test_data) if __name__ == "__main__": From 837186c6adaf51dc362cdb47dd39434a6baf01ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 13:21:59 +0100 Subject: [PATCH 07/18] update tests --- tests/callbacks/test_prediction_writer.py | 71 ++++++++++++++--------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 75e0dbd31ec79..07dfa086a4ea3 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -11,54 +11,67 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 
either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import Mock import pytest +from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.callbacks import BasePredictionWriter from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset -def test_prediction_writer(tmpdir): - class CustomPredictionWriter(BasePredictionWriter): - def __init__(self, writer_interval: str): - super().__init__(writer_interval) +class DummyPredictionWriter(BasePredictionWriter): + def write_on_batch_end(self, *args, **kwargs): + pass - self.write_on_batch_end_called = False - self.write_on_epoch_end_called = False + def write_on_epoch_end(self, *args, **kwargs): + pass - def write_on_batch_end(self, *args, **kwargs): - self.write_on_batch_end_called = True - - def write_on_epoch_end(self, *args, **kwargs): - self.write_on_epoch_end_called = True +def test_prediction_writer_invalid_write_interval(): with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"): - CustomPredictionWriter("something") + DummyPredictionWriter("something") + + +def test_prediction_writer_hook_call_intervals(tmpdir): + DummyPredictionWriter.write_on_batch_end = Mock() + DummyPredictionWriter.write_on_epoch_end = Mock() + + dataloader = DataLoader(RandomDataset(32, 64)) model = BoringModel() - cb = CustomPredictionWriter("batch_and_epoch") + cb = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader()) + results = trainer.predict(model, dataloaders=dataloader) assert len(results) == 4 - assert cb.write_on_batch_end_called - assert cb.write_on_epoch_end_called + assert cb.write_on_batch_end.call_count == 4 + assert 
cb.write_on_epoch_end.call_count == 1 - cb = CustomPredictionWriter("batch_and_epoch") + DummyPredictionWriter.write_on_batch_end.reset_mock() + DummyPredictionWriter.write_on_epoch_end.reset_mock() + + cb = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) - assert cb.write_on_batch_end_called - assert cb.write_on_epoch_end_called + trainer.predict(model, dataloaders=dataloader, return_predictions=False) + assert cb.write_on_batch_end.call_count == 4 + assert cb.write_on_epoch_end.call_count == 1 + + DummyPredictionWriter.write_on_batch_end.reset_mock() + DummyPredictionWriter.write_on_epoch_end.reset_mock() - cb = CustomPredictionWriter("batch") + cb = DummyPredictionWriter("batch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) - assert cb.write_on_batch_end_called - assert not cb.write_on_epoch_end_called + trainer.predict(model, dataloaders=dataloader, return_predictions=False) + assert cb.write_on_batch_end.call_count == 4 + assert cb.write_on_epoch_end.call_count == 0 + + DummyPredictionWriter.write_on_batch_end.reset_mock() + DummyPredictionWriter.write_on_epoch_end.reset_mock() - cb = CustomPredictionWriter("epoch") + cb = DummyPredictionWriter("epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) - assert not cb.write_on_batch_end_called - assert cb.write_on_epoch_end_called + trainer.predict(model, dataloaders=dataloader, return_predictions=False) + assert cb.write_on_batch_end.call_count == 0 + assert cb.write_on_epoch_end.call_count == 1 From 098b62ffc7a48044ef09d6afd1ea7a3b590519e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 13:48:26 +0100 Subject: [PATCH 08/18] add new test --- 
tests/callbacks/test_prediction_writer.py | 30 +++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 07dfa086a4ea3..92826676939cf 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import Mock, call, ANY import pytest from torch.utils.data import DataLoader @@ -27,7 +27,7 @@ def write_on_batch_end(self, *args, **kwargs): pass def write_on_epoch_end(self, *args, **kwargs): - pass + print(*args, **kwargs) def test_prediction_writer_invalid_write_interval(): @@ -75,3 +75,29 @@ def test_prediction_writer_hook_call_intervals(tmpdir): trainer.predict(model, dataloaders=dataloader, return_predictions=False) assert cb.write_on_batch_end.call_count == 0 assert cb.write_on_epoch_end.call_count == 1 + + +def test_prediction_writer_batch_indices(tmpdir): + DummyPredictionWriter.write_on_batch_end = Mock() + DummyPredictionWriter.write_on_epoch_end = Mock() + + dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=2) + model = BoringModel() + writer = DummyPredictionWriter("batch_and_epoch") + trainer = Trainer(limit_predict_batches=4, callbacks=writer) + trainer.predict(model, dataloaders=dataloader) + + writer.write_on_batch_end.assert_has_calls( + [ + call(trainer, model, ANY, [[0, 1, 2, 3]], ANY, 0, 0), + call(trainer, model, ANY, [[4, 5, 6, 7]], ANY, 1, 0), + call(trainer, model, ANY, [[8, 9, 10, 11]], ANY, 2, 0), + call(trainer, model, ANY, [[12, 13, 14, 15]], ANY, 3, 0), + ] + ) + + writer.write_on_epoch_end.assert_has_calls( + [ + call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), + ] + ) From 
6cce70035a3ade27de374fc0db6b57e7f6548669 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Dec 2021 12:49:49 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_prediction_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 92826676939cf..557d4dfc9fefc 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock, call, ANY +from unittest.mock import ANY, call, Mock import pytest from torch.utils.data import DataLoader From 1faa9e9f1a56aa97bb16711033e6adeed1280a68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 14:23:29 +0100 Subject: [PATCH 10/18] update seen indices in advance --- pytorch_lightning/loops/epoch/prediction_epoch_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 52b8e4f703371..4b2a698da19e7 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -89,6 +89,7 @@ def advance( # type: ignore[override] return_predictions: whether to return the obtained predictions """ batch_idx, batch = next(dataloader_iter) + self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if batch is None: raise StopIteration From 9515f376f4632c525f0a4da1c0491c6df0b768ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 14:30:42 +0100 Subject: 
[PATCH 11/18] push a test that works --- tests/callbacks/test_prediction_writer.py | 27 +++++++++++++---------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 557d4dfc9fefc..b0fa0b5eb70f7 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -27,15 +27,17 @@ def write_on_batch_end(self, *args, **kwargs): pass def write_on_epoch_end(self, *args, **kwargs): - print(*args, **kwargs) + pass def test_prediction_writer_invalid_write_interval(): + """Test that configuring an unknown interval name raises an error.""" with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"): DummyPredictionWriter("something") def test_prediction_writer_hook_call_intervals(tmpdir): + """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval.""" DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() @@ -77,11 +79,12 @@ def test_prediction_writer_hook_call_intervals(tmpdir): assert cb.write_on_epoch_end.call_count == 1 -def test_prediction_writer_batch_indices(tmpdir): +@pytest.mark.parametrize("num_workers", [0, 2]) # TODO: configure slow CI for num_workers=2 +def test_prediction_writer_batch_indices(tmpdir, num_workers): DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() - dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=2) + dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers) model = BoringModel() writer = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=writer) @@ -89,15 +92,15 @@ def test_prediction_writer_batch_indices(tmpdir): writer.write_on_batch_end.assert_has_calls( [ - call(trainer, model, ANY, [[0, 1, 2, 3]], ANY, 0, 0), - call(trainer, 
model, ANY, [[4, 5, 6, 7]], ANY, 1, 0), - call(trainer, model, ANY, [[8, 9, 10, 11]], ANY, 2, 0), - call(trainer, model, ANY, [[12, 13, 14, 15]], ANY, 3, 0), + call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), + call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), + call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), + call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), ] ) - writer.write_on_epoch_end.assert_has_calls( - [ - call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), - ] - ) + # writer.write_on_epoch_end.assert_has_calls( + # [ + # call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), + # ] + # ) From d06ea029ed1a645dee249552ac11716587400fe5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Dec 2021 13:33:00 +0000 Subject: [PATCH 12/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_prediction_writer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index b0fa0b5eb70f7..14fdbc12f8494 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -37,7 +37,8 @@ def test_prediction_writer_invalid_write_interval(): def test_prediction_writer_hook_call_intervals(tmpdir): - """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval.""" + """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined + interval.""" DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() From d8c4ea180672ccbd4a81ef041158026575753d1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 18:19:05 +0100 Subject: [PATCH 13/18] truncate list 
to avoid interference with prefetching --- .../loops/epoch/prediction_epoch_loop.py | 3 +++ tests/callbacks/test_prediction_writer.py | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 4b2a698da19e7..30cc6a53ac661 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -90,6 +90,9 @@ def advance( # type: ignore[override] """ batch_idx, batch = next(dataloader_iter) self._seen_batch_indices = self._get_batch_indices(dataloader_idx) + # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning + self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)] + if batch is None: raise StopIteration diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 14fdbc12f8494..5abe4bac23b5b 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -80,7 +80,7 @@ def test_prediction_writer_hook_call_intervals(tmpdir): assert cb.write_on_batch_end.call_count == 0 assert cb.write_on_epoch_end.call_count == 1 -@pytest.mark.parametrize("num_workers", [0, 2]) # TODO: configure slow CI for num_workers=2 +@pytest.mark.parametrize("num_workers", [0, 2]) # TODO: configure slow CI for num_workers > 0 def test_prediction_writer_batch_indices(tmpdir, num_workers): DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() @@ -100,8 +100,8 @@ def test_prediction_writer_batch_indices(tmpdir, num_workers): ] ) - # writer.write_on_epoch_end.assert_has_calls( - # [ - # call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), - # ] - # ) + writer.write_on_epoch_end.assert_has_calls( + [ + call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14,
15]]]), + ] + ) From b6ce97356ceb28e5fd0aeb6cae1f2135470b5248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 18:28:01 +0100 Subject: [PATCH 14/18] unused import --- pytorch_lightning/overrides/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 0b2a952fa8b39..ab3d49bcffd2f 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, cast, Iterator, List, Optional, Sized, Union +from typing import Any, cast, Iterator, List, Sized, Union import torch from torch import Tensor From 390f391079260f5208121ddc7f03de8a9f16546b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 22:56:11 +0100 Subject: [PATCH 15/18] update to slow test --- tests/callbacks/test_prediction_writer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 5abe4bac23b5b..2cd3738ca875f 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -20,6 +20,7 @@ from pytorch_lightning.callbacks import BasePredictionWriter from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf class DummyPredictionWriter(BasePredictionWriter): @@ -80,7 +81,7 @@ def test_prediction_writer_hook_call_intervals(tmpdir): assert cb.write_on_epoch_end.call_count == 1 -@pytest.mark.parametrize("num_workers", [0, 2]) # TODO: configure slow CI for num_workers > 0 +@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))]) def 
test_prediction_writer_batch_indices(tmpdir, num_workers): DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() From 24ff3bbabb6568161967262d2aadc208375b3718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 04:20:16 +0100 Subject: [PATCH 16/18] adjust deprecation version --- pytorch_lightning/overrides/distributed.py | 8 ++++---- tests/deprecated_api/test_remove_1-7.py | 11 +++++++++++ tests/deprecated_api/test_remove_1-8.py | 11 ----------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index ab3d49bcffd2f..5fa9e6b87dbc4 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -131,16 +131,16 @@ def __init__(self, sampler: BatchSampler) -> None: @property def batch_indices(self) -> List[int]: rank_zero_deprecation( - "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.6 and will be removed in v1.8." - " Access the full list `seen_batch_indices` instead." + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5.5 and will be removed in" + " v1.7. Access the full list `seen_batch_indices` instead." ) return self._batch_indices @batch_indices.setter def batch_indices(self, indices: List[int]) -> None: rank_zero_deprecation( - "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.6 and will be removed in v1.8." - " Access the full list `seen_batch_indices` instead." + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5.5 and will be removed in" + " v1.7. Access the full list `seen_batch_indices` instead." 
) self._batch_indices = indices diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 0065d5947dc26..517db38c2a44c 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -14,6 +14,7 @@ """Test deprecated functionality which will be removed in v1.7.0.""" import os from unittest import mock +from unittest.mock import Mock import pytest @@ -23,6 +24,7 @@ from pytorch_lightning.callbacks.progress import ProgressBar from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger +from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, @@ -528,3 +530,12 @@ def is_using_torchelastic(): match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7" ): MyClusterEnvironment() + + +def test_v1_7_0_index_batch_sampler_wrapper_batch_indices(): + sampler = IndexBatchSamplerWrapper(Mock()) + with pytest.deprecated_call(match="was deprecated in v1.5.5 and will be removed in v1.7"): + _ = sampler.batch_indices + + with pytest.deprecated_call(match="was deprecated in v1.5.5 and will be removed in v1.7"): + sampler.batch_indices = [] diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 11b2c5e7a3b28..7ef0fe2a15e4f 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test deprecated functionality which will be removed in v1.8.0.""" -from unittest.mock import Mock import pytest import torch -from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import DeviceType, DistributedType from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY @@ -43,12 +41,3 @@ def test_v1_8_0_deprecated_torchtext_batch(): data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3) batch = next(iter(data_iterator)) _ = move_data_to_device(batch=batch, device=torch.device("cpu")) - - -def test_v1_8_0_index_batch_sampler_wrapper_batch_indices(): - sampler = IndexBatchSamplerWrapper(Mock()) - with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): - _ = sampler.batch_indices - - with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): - sampler.batch_indices = [] From 6e49f996ec0b4f4b7bb5fc2782fc940634a2aa7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 04:20:40 +0100 Subject: [PATCH 17/18] reduce diff to prev. 
version --- pytorch_lightning/loops/epoch/prediction_epoch_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 30cc6a53ac661..985779a17c54d 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -121,10 +121,12 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None """ # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - model_ref = self.trainer.lightning_module + # extract batch_indices and store them self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else [] + model_ref = self.trainer.lightning_module + self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) self.batch_progress.increment_started() From ed735c06be7d318380450f166c8d715822476d8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 10:39:12 +0100 Subject: [PATCH 18/18] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/overrides/distributed.py | 4 ++-- tests/deprecated_api/test_remove_1-7.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 5fa9e6b87dbc4..66644a91d5eea 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -131,7 +131,7 @@ def __init__(self, sampler: BatchSampler) -> None: @property def batch_indices(self) -> List[int]: rank_zero_deprecation( - "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5.5 and will be removed in" + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in" " v1.7. 
Access the full list `seen_batch_indices` instead." ) return self._batch_indices @@ -139,7 +139,7 @@ def batch_indices(self) -> List[int]: @batch_indices.setter def batch_indices(self, indices: List[int]) -> None: rank_zero_deprecation( - "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5.5 and will be removed in" + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in" " v1.7. Access the full list `seen_batch_indices` instead." ) self._batch_indices = indices diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 517db38c2a44c..6f7e1199ab438 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -534,8 +534,8 @@ def is_using_torchelastic(): def test_v1_7_0_index_batch_sampler_wrapper_batch_indices(): sampler = IndexBatchSamplerWrapper(Mock()) - with pytest.deprecated_call(match="was deprecated in v1.5.5 and will be removed in v1.7"): + with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"): _ = sampler.batch_indices - with pytest.deprecated_call(match="was deprecated in v1.5.5 and will be removed in v1.7"): + with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"): sampler.batch_indices = []