diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml
index 1064e603bee1f..3a4277a10fff5 100644
--- a/.github/workflows/ci_test-full.yml
+++ b/.github/workflows/ci_test-full.yml
@@ -159,6 +159,8 @@ jobs:
- name: Examples
run: |
python -m pytest pl_examples -v --durations=10
+ env:
+ PL_USE_MOCKED_MNIST: "1"
- name: Upload pytest results
uses: actions/upload-artifact@v2
diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
index b86c2d4800d5e..6bc341744c0ba 100644
--- a/benchmarks/test_sharded_parity.py
+++ b/benchmarks/test_sharded_parity.py
@@ -20,7 +20,7 @@
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.plugins import DDPSpawnShardedPlugin
-from tests.helpers.boring_model import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py
index 693514e0a3620..711af99f7a92d 100644
--- a/pl_examples/basic_examples/mnist_datamodule.py
+++ b/pl_examples/basic_examples/mnist_datamodule.py
@@ -14,30 +14,21 @@
import os
import platform
from typing import Optional
-from urllib.error import HTTPError
from warnings import warn
from torch.utils.data import DataLoader, random_split
from pl_examples import _DATASETS_PATH
from pytorch_lightning import LightningDataModule
+from pytorch_lightning.utilities.debugging_examples import MNIST as PL_MNIST
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
-if _TORCHVISION_AVAILABLE:
- from torchvision import transforms as transform_lib
-
-_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
-if _TORCHVISION_MNIST_AVAILABLE:
- try:
- from torchvision.datasets import MNIST
-
- MNIST(_DATASETS_PATH, download=True)
- except HTTPError as e:
- print(f"Error {e} downloading `torchvision.datasets.MNIST`")
- _TORCHVISION_MNIST_AVAILABLE = False
-if not _TORCHVISION_MNIST_AVAILABLE:
- print("`torchvision.datasets.MNIST` not available. Using our hosted version")
- from tests.helpers.datasets import MNIST
+# check whether we are in CI. Users running this should get the `torchvision` implementation
+_USE_MOCKED_MNIST = bool(os.getenv("PL_USE_MOCKED_MNIST", False))
+if not _USE_MOCKED_MNIST and PL_MNIST._torchvision_available(_DATASETS_PATH, should_raise=False):
+ from torchvision.datasets import MNIST
+else:
+ MNIST = PL_MNIST
class MNISTDataModule(LightningDataModule):
@@ -148,11 +139,12 @@ def test_dataloader(self):
def default_transforms(self):
if not _TORCHVISION_AVAILABLE:
return None
+ from torchvision import transforms
+
if self.normalize:
- mnist_transforms = transform_lib.Compose(
- [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
+ mnist_transforms = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
else:
- mnist_transforms = transform_lib.ToTensor()
-
+ mnist_transforms = transforms.ToTensor()
return mnist_transforms
diff --git a/pytorch_lightning/utilities/debugging_examples.py b/pytorch_lightning/utilities/debugging_examples.py
new file mode 100644
index 0000000000000..3f1364ab13e66
--- /dev/null
+++ b/pytorch_lightning/utilities/debugging_examples.py
@@ -0,0 +1,293 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Debugging helpers meant to be used ONLY in the ``pl_examples`` and ``tests`` directories.
+
+Production usage is discouraged. No backwards-compatibility guarantees.
+"""
+import logging
+import os
+import random
+import time
+import urllib.request
+from typing import Optional, Tuple
+from urllib.error import HTTPError
+
+import torch
+from torch.utils.data import DataLoader, Dataset, Subset
+
+from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class RandomDataset(Dataset):
+ def __init__(self, size: int, length: int):
+ self.len = length
+ self.data = torch.randn(length, size)
+
+ def __getitem__(self, index):
+ return self.data[index]
+
+ def __len__(self):
+ return self.len
+
+
+class BoringModel(LightningModule):
+ def __init__(self):
+ """
+ Testing PL Module
+
+ Use as follows:
+ - subclass
+ - modify the behavior for what you want
+
+ class TestModel(BaseTestModel):
+ def training_step(...):
+ # do your own thing
+
+ or:
+
+ model = BaseTestModel()
+ model.training_epoch_end = None
+
+ """
+ super().__init__()
+ self.layer = torch.nn.Linear(32, 2)
+
+ def forward(self, x):
+ return self.layer(x)
+
+ def loss(self, batch, prediction):
+ # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+ return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+ def step(self, x):
+ x = self(x)
+ out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+ return out
+
+ def training_step(self, batch, batch_idx):
+ output = self(batch)
+ loss = self.loss(batch, output)
+ return {"loss": loss}
+
+ def training_step_end(self, training_step_outputs):
+ return training_step_outputs
+
+ def training_epoch_end(self, outputs) -> None:
+ torch.stack([x["loss"] for x in outputs]).mean()
+
+ def validation_step(self, batch, batch_idx):
+ output = self(batch)
+ loss = self.loss(batch, output)
+ return {"x": loss}
+
+ def validation_epoch_end(self, outputs) -> None:
+ torch.stack([x["x"] for x in outputs]).mean()
+
+ def test_step(self, batch, batch_idx):
+ output = self(batch)
+ loss = self.loss(batch, output)
+ return {"y": loss}
+
+ def test_epoch_end(self, outputs) -> None:
+ torch.stack([x["y"] for x in outputs]).mean()
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+ return [optimizer], [lr_scheduler]
+
+ def train_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+ def val_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+ def test_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+ def predict_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+
+class BoringDataModule(LightningDataModule):
+ def __init__(self, data_dir: str = "./"):
+ super().__init__()
+ self.data_dir = data_dir
+ self.non_picklable = None
+ self.checkpoint_state: Optional[str] = None
+
+ def prepare_data(self):
+ self.random_full = RandomDataset(32, 64 * 4)
+
+ def setup(self, stage: Optional[str] = None):
+ if stage == "fit" or stage is None:
+ self.random_train = Subset(self.random_full, indices=range(64))
+ self.dims = self.random_train[0].shape
+
+ if stage in ("fit", "validate") or stage is None:
+ self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))
+
+ if stage == "test" or stage is None:
+ self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
+ self.dims = getattr(self, "dims", self.random_test[0].shape)
+
+ if stage == "predict" or stage is None:
+ self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
+ self.dims = getattr(self, "dims", self.random_predict[0].shape)
+
+ def train_dataloader(self):
+ return DataLoader(self.random_train)
+
+ def val_dataloader(self):
+ return DataLoader(self.random_val)
+
+ def test_dataloader(self):
+ return DataLoader(self.random_test)
+
+ def predict_dataloader(self):
+ return DataLoader(self.random_predict)
+
+
+class MNIST(torch.utils.data.Dataset):
+ """
+ Customized `MNIST `_ dataset for testing Pytorch Lightning
+ without the torchvision dependency.
+
+ Part of the code was copied from
+ https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py
+
+ Args:
+ root: Root directory of dataset where ``MNIST/processed/training.pt``
+ and ``MNIST/processed/test.pt`` exist.
+ train: If ``True``, creates dataset from ``training.pt``,
+ otherwise from ``test.pt``.
+ normalize: mean and std deviation of the MNIST dataset.
+ download: If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ RESOURCES = (
+ "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
+ "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
+ )
+
+ TRAIN_FILE_NAME = "training.pt"
+ TEST_FILE_NAME = "test.pt"
+ cache_folder_name = "complete"
+
+ def __init__(
+ self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
+ ):
+ super().__init__()
+
+ self.root = root
+ self.train = train # training set or test set
+ self.normalize = normalize
+
+ _USE_MOCKED_MNIST = bool(os.getenv("PL_USE_MOCKED_MNIST", False))
+ if not _USE_MOCKED_MNIST:
+ # avoid users accidentally importing this MNIST implementation
+ self._torchvision_available(self.cached_folder_path)
+ self.prepare_data(download)
+
+ data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
+ self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
+
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
+ img = self.data[idx].float().unsqueeze(0)
+ target = int(self.targets[idx])
+
+ if self.normalize is not None and len(self.normalize) == 2:
+ img = self.normalize_tensor(img, *self.normalize)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ @property
+ def cached_folder_path(self) -> str:
+ return os.path.join(self.root, "MNIST", self.cache_folder_name)
+
+ def _check_exists(self, data_folder: str) -> bool:
+ existing = True
+ for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
+ existing = existing and os.path.isfile(os.path.join(data_folder, fname))
+ return existing
+
+ def prepare_data(self, download: bool = True):
+ if download and not self._check_exists(self.cached_folder_path):
+ self._download(self.cached_folder_path)
+ if not self._check_exists(self.cached_folder_path):
+ raise RuntimeError("Dataset not found.")
+
+ def _download(self, data_folder: str) -> None:
+ os.makedirs(data_folder, exist_ok=True)
+ for url in self.RESOURCES:
+ logging.info(f"Downloading {url}")
+ fpath = os.path.join(data_folder, os.path.basename(url))
+ urllib.request.urlretrieve(url, fpath)
+
+ @staticmethod
+ def _try_load(path_data, trials: int = 30, delta: float = 1.0):
+ """Resolving loading from the same time from multiple concurrent processes."""
+ res, exception = None, None
+ assert trials, "at least some trial has to be set"
+ assert os.path.isfile(path_data), f"missing file: {path_data}"
+ for _ in range(trials):
+ try:
+ res = torch.load(path_data)
+ # todo: specify the possible exception
+ except Exception as e:
+ exception = e
+ time.sleep(delta * random.random())
+ else:
+ break
+ if exception is not None:
+ # raise the caught exception
+ raise exception
+ return res
+
+ @staticmethod
+ def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
+ mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
+ std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
+ return tensor.sub(mean).div(std)
+
+ @staticmethod
+ def _torchvision_available(root: str, should_raise: bool = True) -> bool:
+ # local import to facilitate mocking
+ from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
+
+ try:
+ from torchvision.datasets import MNIST
+
+ MNIST(root, download=True)
+ download_successful = True
+ except HTTPError as e:
+ # allow using our implementation if torchvision hosting is down
+ print(f"Error {e} downloading `torchvision.datasets.MNIST`")
+ download_successful = False
+ _TORCHVISION_AVAILABLE &= download_successful
+
+ if _TORCHVISION_AVAILABLE and should_raise:
+ raise MisconfigurationException(
+ "The `torchvision` package is available. This implementation is meant for internal use."
+ " Please import MNIST with `from torchvision.datasets import MNIST`"
+ " instead of `from pytorch_lightning.utilities.debug_examples import MNIST`"
+ )
+ return _TORCHVISION_AVAILABLE
diff --git a/tests/__init__.py b/tests/__init__.py
index 9039a6e4b16e9..8e43f2fa13559 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -34,3 +34,6 @@
os.mkdir(_TEMP_PATH)
logging.basicConfig(level=logging.ERROR)
+
+# Use our MNIST implementation on tests to avoid the torchvision dependency
+os.environ.setdefault("PL_USE_MOCKED_MNIST", "1")
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index e27b873a63941..02dba54c61454 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -43,8 +43,8 @@
TorchElasticEnvironment,
)
from pytorch_lightning.utilities import DistributedType
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py
index e67fb166f815b..47dd4cb0c2861 100644
--- a/tests/accelerators/test_common.py
+++ b/tests/accelerators/test_common.py
@@ -17,8 +17,8 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.accelerators.test_dp import CustomClassificationModelDP
-from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 7280afea02f76..1b82d718ae0fa 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -12,8 +12,8 @@
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
def test_unsupported_precision_plugins():
diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index dc83e4ad4f02e..8faf0066d3ec5 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -23,8 +23,8 @@
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.accelerators import ddp_model
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
from tests.utilities.distributed import call_training_script
diff --git a/tests/accelerators/test_ddp_spawn.py b/tests/accelerators/test_ddp_spawn.py
index a21078cf55542..3163aa19c9e12 100644
--- a/tests/accelerators/test_ddp_spawn.py
+++ b/tests/accelerators/test_ddp_spawn.py
@@ -16,7 +16,7 @@
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import memory
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py
index efaf761cb7116..21afb07295bdd 100644
--- a/tests/accelerators/test_dp.py
+++ b/tests/accelerators/test_dp.py
@@ -22,8 +22,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities import memory
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index ba2de43da110e..55b86b91a2b1f 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -25,8 +25,8 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _IPU_AVAILABLE
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py
index ea591e47041f8..5daccd3f1d367 100644
--- a/tests/accelerators/test_multi_nodes_gpu.py
+++ b/tests/accelerators/test_multi_nodes_gpu.py
@@ -25,7 +25,7 @@
from pytorch_lightning import LightningModule # noqa: E402
from pytorch_lightning import Trainer # noqa: E402
-from tests.helpers.boring_model import BoringModel # noqa: E402
+from pytorch_lightning.utilities.debugging_examples import BoringModel # noqa: E402
# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml)
diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py
index 99ac579eb99b0..afb2294c3a7a8 100644
--- a/tests/accelerators/test_tpu_backend.py
+++ b/tests/accelerators/test_tpu_backend.py
@@ -23,8 +23,8 @@
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import TPUSpawnPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test
diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py
index 1a2ecae7b94c0..186196d2f36b1 100644
--- a/tests/callbacks/test_callback_hook_outputs.py
+++ b/tests/callbacks/test_callback_hook_outputs.py
@@ -14,7 +14,7 @@
import pytest
from pytorch_lightning import Callback, Trainer
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize("single_cb", [False, True])
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index d190feed7e1f7..e086dcc2340f2 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -15,7 +15,7 @@
from unittest.mock import call, Mock
from pytorch_lightning import Callback, Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_callbacks_configured_in_model(tmpdir):
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 4c3b990dd1b13..98c9f291e49fe 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -24,8 +24,8 @@
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 9a28c7a8fc478..099e78050dbeb 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -22,7 +22,7 @@
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
from pytorch_lightning.callbacks.base import Callback
-from tests.helpers import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
class TestBackboneFinetuningCallback(BackboneFinetuning):
diff --git a/tests/callbacks/test_gpu_stats_monitor.py b/tests/callbacks/test_gpu_stats_monitor.py
index eaba4d30684f3..cb0c8c347707f 100644
--- a/tests/callbacks/test_gpu_stats_monitor.py
+++ b/tests/callbacks/test_gpu_stats_monitor.py
@@ -22,8 +22,8 @@
from pytorch_lightning.callbacks import GPUStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers.csv_logs import ExperimentWriter
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
index 82f64d676c774..5d2ced197e359 100644
--- a/tests/callbacks/test_lambda_function.py
+++ b/tests/callbacks/test_lambda_function.py
@@ -16,7 +16,7 @@
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_lambda_call(tmpdir):
diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py
index d742781599d77..2d5bc23c79ac3 100644
--- a/tests/callbacks/test_lr_monitor.py
+++ b/tests/callbacks/test_lr_monitor.py
@@ -20,8 +20,8 @@
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py
index 75e0dbd31ec79..ae7b110e4d695 100644
--- a/tests/callbacks/test_prediction_writer.py
+++ b/tests/callbacks/test_prediction_writer.py
@@ -16,8 +16,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
def test_prediction_writer(tmpdir):
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 1c3176f39a886..58bc234774819 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -25,8 +25,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import tqdm
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index fe6e14d1084d9..9793bdc259059 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -24,8 +24,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py
index c6f44759ba371..10ff0f7cd9ea6 100644
--- a/tests/callbacks/test_rich_progress_bar.py
+++ b/tests/callbacks/test_rich_progress_bar.py
@@ -17,7 +17,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 0bfaa359bb1a8..8fbfbfd15cd14 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -25,8 +25,9 @@
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
+from tests.helpers.datasets import RandomIterableDataset
from tests.helpers.runif import RunIf
diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py
index 44d3c305bb1ac..c6a69a7c15ca6 100644
--- a/tests/callbacks/test_timer.py
+++ b/tests/callbacks/test_timer.py
@@ -21,8 +21,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.timer import Timer
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/callbacks/test_xla_stats_monitor.py b/tests/callbacks/test_xla_stats_monitor.py
index bf4dde3983921..7c7a070569a90 100644
--- a/tests/callbacks/test_xla_stats_monitor.py
+++ b/tests/callbacks/test_xla_stats_monitor.py
@@ -20,8 +20,8 @@
from pytorch_lightning.callbacks import XLAStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers.csv_logs import ExperimentWriter
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py
index 2e7fcfb8c2253..553bd7a5bd5fc 100644
--- a/tests/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -18,7 +18,7 @@
import torch
from pytorch_lightning import callbacks, seed_everything, Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 314ed899c588a..8ce81ca8bda20 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -38,8 +38,8 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py
index e95ce1c91d6f4..2910cf60a0abd 100644
--- a/tests/checkpointing/test_torch_saving.py
+++ b/tests/checkpointing/test_torch_saving.py
@@ -16,7 +16,7 @@
import torch
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py
index f76e76b2f9dd9..9213b39008d71 100644
--- a/tests/checkpointing/test_trainer_checkpoint.py
+++ b/tests/checkpointing/test_trainer_checkpoint.py
@@ -19,7 +19,7 @@
import pytorch_lightning as pl
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_finetuning_with_resume_from_checkpoint(tmpdir):
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 3bfe3aaa6cf80..b13bc95ee9945 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -23,9 +23,9 @@
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
-from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/core/test_decorators.py b/tests/core/test_decorators.py
index c26836e6a5c90..5f00c06faca43 100644
--- a/tests/core/test_decorators.py
+++ b/tests/core/test_decorators.py
@@ -15,7 +15,7 @@
import torch
from pytorch_lightning.core.decorators import auto_move_data
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 8b787e0f57fcb..0721b3c7b56fe 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -22,7 +22,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index f8c3317b6c595..0a6660ed7bfe7 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -20,7 +20,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_lightning_optimizer(tmpdir):
diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 7c6a985d09f33..c288bd4cdc1e9 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -28,8 +28,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 40dfc069ac449..f70d9719ba232 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -19,8 +19,8 @@
from pytorch_lightning.core.decorators import auto_move_data
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel
from tests.deprecated_api import no_deprecated_call
-from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.utils import no_warning_call
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index 98c5c4e0320ea..3814800fb4b9e 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -18,11 +18,11 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary
from tests.deprecated_api import _soft_unimport_module
-from tests.helpers import BoringDataModule, BoringModel
def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 7581bf2b0c142..84ffde97ec1c3 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -16,8 +16,8 @@
import pytest
from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.deprecated_api import _soft_unimport_module
-from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
deleted file mode 100644
index e6fa5cfa70795..0000000000000
--- a/tests/helpers/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset # noqa: F401
-from tests.helpers.datasets import TrialMNIST # noqa: F401
diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py
deleted file mode 100644
index aeaf85ecf254c..0000000000000
--- a/tests/helpers/boring_model.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Optional
-
-import torch
-from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
-
-from pytorch_lightning import LightningDataModule, LightningModule
-
-
-class RandomDictDataset(Dataset):
- def __init__(self, size: int, length: int):
- self.len = length
- self.data = torch.randn(length, size)
-
- def __getitem__(self, index):
- a = self.data[index]
- b = a + 2
- return {"a": a, "b": b}
-
- def __len__(self):
- return self.len
-
-
-class RandomDataset(Dataset):
- def __init__(self, size: int, length: int):
- self.len = length
- self.data = torch.randn(length, size)
-
- def __getitem__(self, index):
- return self.data[index]
-
- def __len__(self):
- return self.len
-
-
-class RandomIterableDataset(IterableDataset):
- def __init__(self, size: int, count: int):
- self.count = count
- self.size = size
-
- def __iter__(self):
- for _ in range(self.count):
- yield torch.randn(self.size)
-
-
-class RandomIterableDatasetWithLen(IterableDataset):
- def __init__(self, size: int, count: int):
- self.count = count
- self.size = size
-
- def __iter__(self):
- for _ in range(len(self)):
- yield torch.randn(self.size)
-
- def __len__(self):
- return self.count
-
-
-class BoringModel(LightningModule):
- def __init__(self):
- """
- Testing PL Module
-
- Use as follows:
- - subclass
- - modify the behavior for what you want
-
- class TestModel(BaseTestModel):
- def training_step(...):
- # do your own thing
-
- or:
-
- model = BaseTestModel()
- model.training_epoch_end = None
-
- """
- super().__init__()
- self.layer = torch.nn.Linear(32, 2)
-
- def forward(self, x):
- return self.layer(x)
-
- def loss(self, batch, prediction):
- # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
- return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
-
- def step(self, x):
- x = self(x)
- out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
- return out
-
- def training_step(self, batch, batch_idx):
- output = self(batch)
- loss = self.loss(batch, output)
- return {"loss": loss}
-
- def training_step_end(self, training_step_outputs):
- return training_step_outputs
-
- def training_epoch_end(self, outputs) -> None:
- torch.stack([x["loss"] for x in outputs]).mean()
-
- def validation_step(self, batch, batch_idx):
- output = self(batch)
- loss = self.loss(batch, output)
- return {"x": loss}
-
- def validation_epoch_end(self, outputs) -> None:
- torch.stack([x["x"] for x in outputs]).mean()
-
- def test_step(self, batch, batch_idx):
- output = self(batch)
- loss = self.loss(batch, output)
- return {"y": loss}
-
- def test_epoch_end(self, outputs) -> None:
- torch.stack([x["y"] for x in outputs]).mean()
-
- def configure_optimizers(self):
- optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
- lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
- return [optimizer], [lr_scheduler]
-
- def train_dataloader(self):
- return DataLoader(RandomDataset(32, 64))
-
- def val_dataloader(self):
- return DataLoader(RandomDataset(32, 64))
-
- def test_dataloader(self):
- return DataLoader(RandomDataset(32, 64))
-
- def predict_dataloader(self):
- return DataLoader(RandomDataset(32, 64))
-
-
-class BoringDataModule(LightningDataModule):
- def __init__(self, data_dir: str = "./"):
- super().__init__()
- self.data_dir = data_dir
- self.non_picklable = None
- self.checkpoint_state: Optional[str] = None
-
- def prepare_data(self):
- self.random_full = RandomDataset(32, 64 * 4)
-
- def setup(self, stage: Optional[str] = None):
- if stage == "fit" or stage is None:
- self.random_train = Subset(self.random_full, indices=range(64))
- self.dims = self.random_train[0].shape
-
- if stage in ("fit", "validate") or stage is None:
- self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))
-
- if stage == "test" or stage is None:
- self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
- self.dims = getattr(self, "dims", self.random_test[0].shape)
-
- if stage == "predict" or stage is None:
- self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
- self.dims = getattr(self, "dims", self.random_predict[0].shape)
-
- def train_dataloader(self):
- return DataLoader(self.random_train)
-
- def val_dataloader(self):
- return DataLoader(self.random_val)
-
- def test_dataloader(self):
- return DataLoader(self.random_test)
-
- def predict_dataloader(self):
- return DataLoader(self.random_predict)
diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py
index 10a4c6b6e7ca7..1dfea134de3b1 100644
--- a/tests/helpers/datasets.py
+++ b/tests/helpers/datasets.py
@@ -11,126 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import os
-import random
-import time
-import urllib.request
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence
import torch
-from torch import Tensor
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, IterableDataset
-
-class MNIST(Dataset):
- """
- Customized `MNIST `_ dataset for testing Pytorch Lightning
- without the torchvision dependency.
-
- Part of the code was copied from
- https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py
-
- Args:
- root: Root directory of dataset where ``MNIST/processed/training.pt``
- and ``MNIST/processed/test.pt`` exist.
- train: If ``True``, creates dataset from ``training.pt``,
- otherwise from ``test.pt``.
- normalize: mean and std deviation of the MNIST dataset.
- download: If true, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
-
- Examples:
- >>> dataset = MNIST(".", download=True)
- >>> len(dataset)
- 60000
- >>> torch.bincount(dataset.targets)
- tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
- """
-
- RESOURCES = (
- "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
- "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
- )
-
- TRAIN_FILE_NAME = "training.pt"
- TEST_FILE_NAME = "test.pt"
- cache_folder_name = "complete"
-
- def __init__(
- self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
- ):
- super().__init__()
- self.root = root
- self.train = train # training set or test set
- self.normalize = normalize
-
- self.prepare_data(download)
-
- data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
- self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
-
- def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
- img = self.data[idx].float().unsqueeze(0)
- target = int(self.targets[idx])
-
- if self.normalize is not None and len(self.normalize) == 2:
- img = self.normalize_tensor(img, *self.normalize)
-
- return img, target
-
- def __len__(self) -> int:
- return len(self.data)
-
- @property
- def cached_folder_path(self) -> str:
- return os.path.join(self.root, "MNIST", self.cache_folder_name)
-
- def _check_exists(self, data_folder: str) -> bool:
- existing = True
- for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
- existing = existing and os.path.isfile(os.path.join(data_folder, fname))
- return existing
-
- def prepare_data(self, download: bool = True):
- if download and not self._check_exists(self.cached_folder_path):
- self._download(self.cached_folder_path)
- if not self._check_exists(self.cached_folder_path):
- raise RuntimeError("Dataset not found.")
-
- def _download(self, data_folder: str) -> None:
- os.makedirs(data_folder, exist_ok=True)
- for url in self.RESOURCES:
- logging.info(f"Downloading {url}")
- fpath = os.path.join(data_folder, os.path.basename(url))
- urllib.request.urlretrieve(url, fpath)
-
- @staticmethod
- def _try_load(path_data, trials: int = 30, delta: float = 1.0):
- """Resolving loading from the same time from multiple concurrent processes."""
- res, exception = None, None
- assert trials, "at least some trial has to be set"
- assert os.path.isfile(path_data), f"missing file: {path_data}"
- for _ in range(trials):
- try:
- res = torch.load(path_data)
- # todo: specify the possible exception
- except Exception as e:
- exception = e
- time.sleep(delta * random.random())
- else:
- break
- if exception is not None:
- # raise the caught exception
- raise exception
- return res
-
- @staticmethod
- def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
- mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
- std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
- return tensor.sub(mean).div(std)
+from pytorch_lightning.utilities.debugging_examples import MNIST
class TrialMNIST(MNIST):
@@ -214,3 +101,40 @@ def __getitem__(self, idx):
def __len__(self):
return len(self.y)
+
+
+class RandomDictDataset(Dataset):
+ def __init__(self, size: int, length: int):
+ self.len = length
+ self.data = torch.randn(length, size)
+
+ def __getitem__(self, index):
+ a = self.data[index]
+ b = a + 2
+ return {"a": a, "b": b}
+
+ def __len__(self):
+ return self.len
+
+
+class RandomIterableDataset(IterableDataset):
+ def __init__(self, size: int, count: int):
+ self.count = count
+ self.size = size
+
+ def __iter__(self):
+ for _ in range(self.count):
+ yield torch.randn(self.size)
+
+
+class RandomIterableDatasetWithLen(IterableDataset):
+ def __init__(self, size: int, count: int):
+ self.count = count
+ self.size = size
+
+ def __iter__(self):
+ for _ in range(len(self)):
+ yield torch.randn(self.size)
+
+ def __len__(self):
+ return self.count
diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py
index 3e5066d708da0..02983c2686a92 100644
--- a/tests/helpers/pipelines.py
+++ b/tests/helpers/pipelines.py
@@ -16,7 +16,7 @@
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import DistributedType
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed
diff --git a/tests/helpers/test_models.py b/tests/helpers/test_models.py
index b6d853f2ac594..bc4cf28d8b035 100644
--- a/tests/helpers/test_models.py
+++ b/tests/helpers/test_models.py
@@ -16,8 +16,8 @@
import pytest
from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
-from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule, RegressDataModule
from tests.helpers.simple_models import ClassificationModel, RegressionModel
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index eb1d53ccc62e3..1a29f2f8e9ad5 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -32,7 +32,7 @@
WandbLogger,
)
from pytorch_lightning.loggers.base import DummyExperiment
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
from tests.loggers.test_comet import _patch_comet_atexit
from tests.loggers.test_mlflow import mock_mlflow_run_creation
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index 7c8673c34956f..c884e14818274 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -23,7 +23,7 @@
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
from pytorch_lightning.utilities import rank_zero_only
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_logger_collection():
diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py
index cb3641e669d50..6f10ae5d4472f 100644
--- a/tests/loggers/test_comet.py
+++ b/tests/loggers/test_comet.py
@@ -18,8 +18,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
def _patch_comet_atexit(monkeypatch):
diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py
index 8523edf69a980..7069428128b50 100644
--- a/tests/loggers/test_mlflow.py
+++ b/tests/loggers/test_mlflow.py
@@ -20,7 +20,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py
index c58bb4ef59fbe..a60a56149c9c7 100644
--- a/tests/loggers/test_neptune.py
+++ b/tests/loggers/test_neptune.py
@@ -17,7 +17,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@patch("pytorch_lightning.loggers.neptune.neptune")
diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index a1c66c0559d75..47d3a221d0c2b 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -25,8 +25,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.imports import _compare_version
-from tests.helpers import BoringModel
@pytest.mark.skipif(
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 0684727e84ac8..d8411b7d201fe 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -20,8 +20,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py
index 2cd6a172f6941..4beeb7ad19557 100644
--- a/tests/loops/test_iterator_batch_processor.py
+++ b/tests/loops/test_iterator_batch_processor.py
@@ -17,9 +17,9 @@
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
-from tests.helpers import BoringModel, RandomDataset
_BATCH_SIZE = 32
_DATASET_LEN = 64
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 65cbebc8203e5..e62c1f070c04c 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -24,7 +24,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.trainer.progress import BaseProgress
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index f851a7de0837e..4fde65d529d83 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -37,7 +37,7 @@
else:
print("You requested to import Horovod which is missing or not supported for your OS.")
-from tests.helpers import BoringModel # noqa: E402
+from pytorch_lightning.utilities.debugging_examples import BoringModel # noqa: E402
from tests.helpers.utils import reset_seed, set_random_master_port # noqa: E402
parser = argparse.ArgumentParser()
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 79c0cf7c12f15..60725d6981297 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -23,8 +23,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index 015c79458e1aa..59c595b259884 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -19,7 +19,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 1d23ed2f76907..23c892c8301aa 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -25,9 +25,9 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
-from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
from tests.helpers.runif import RunIf
diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index fb0c0d4a8da76..3e6b61c83b436 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -17,7 +17,7 @@
import pytest
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.utils import reset_seed
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 990178b09d07f..fbfaeb016762e 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -21,7 +21,7 @@
from torch.utils.data import DataLoader
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
-from tests.helpers import BoringDataModule, BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index c3f3cdcf7ffc2..3494aa6e3e418 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -30,7 +30,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.advanced_models import BasicGAN
from tests.helpers.runif import RunIf
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index 7e37fab22cd3f..3d70ef3956d48 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -31,8 +31,8 @@
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
if _HYDRA_EXPERIMENTAL_AVAILABLE:
from hydra.experimental import compose, initialize
diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py
index cec01e828d1ed..10f7598521dd3 100644
--- a/tests/models/test_onnx.py
+++ b/tests/models/test_onnx.py
@@ -21,7 +21,7 @@
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index d1d870fa30116..6662da51ad88b 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -28,7 +28,7 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.states import RunningStage
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py
index 5748f6d9a1095..912124c2ad80b 100644
--- a/tests/models/test_torchscript.py
+++ b/tests/models/test_torchscript.py
@@ -20,7 +20,7 @@
from fsspec.implementations.local import LocalFileSystem
from pytorch_lightning.utilities.cloud_io import get_filesystem
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
from tests.helpers.datamodules import MNISTDataModule
from tests.helpers.runif import RunIf
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 5aa605cdf38bb..340d58b4020ea 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -27,9 +27,9 @@
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync
from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test
diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py
index ab10a527e3cda..ac7eda8db98c0 100644
--- a/tests/models/test_truncated_bptt.py
+++ b/tests/models/test_truncated_bptt.py
@@ -16,7 +16,7 @@
import torch
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize("n_hidden_states", (1, 2))
diff --git a/tests/overrides/test_base.py b/tests/overrides/test_base.py
index 4b76fd028af66..90dba69d5ab17 100644
--- a/tests/overrides/test_base.py
+++ b/tests/overrides/test_base.py
@@ -20,7 +20,7 @@
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase])
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index c481fdbef0f33..5dec4319ba366 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -27,7 +27,7 @@
unsqueeze_scalar_tensor,
)
from pytorch_lightning.trainer.states import RunningStage
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/environments/torch_elastic_deadlock.py b/tests/plugins/environments/torch_elastic_deadlock.py
index ac2348285d9af..40d04932c4587 100644
--- a/tests/plugins/environments/torch_elastic_deadlock.py
+++ b/tests/plugins/environments/torch_elastic_deadlock.py
@@ -4,8 +4,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
-from tests.helpers.boring_model import BoringModel
if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1":
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 15ec43973b0ed..2e237f70ab155 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -21,8 +21,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py
index 810127a03f361..f405d4039bd4b 100644
--- a/tests/plugins/test_checkpoint_io_plugin.py
+++ b/tests/plugins/test_checkpoint_io_plugin.py
@@ -20,9 +20,9 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PATH
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py
index 939c05d1b7afe..b84aded417511 100644
--- a/tests/plugins/test_custom_plugin.py
+++ b/tests/plugins/test_custom_plugin.py
@@ -19,7 +19,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, SingleDevicePlugin
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py
index 3dd4897864947..4b8f9639fed5d 100644
--- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py
+++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py
@@ -9,8 +9,8 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DDPFullyShardedPlugin, FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py
index 60ec1930852d0..9dae8d432f4d7 100644
--- a/tests/plugins/test_ddp_plugin.py
+++ b/tests/plugins/test_ddp_plugin.py
@@ -21,7 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py
index 49c4cd18ef316..7f11d225e1cbc 100644
--- a/tests/plugins/test_ddp_plugin_with_comm_hook.py
+++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py
@@ -16,7 +16,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py
index 1ab94446c8176..495e8cc1077cd 100644
--- a/tests/plugins/test_ddp_spawn_plugin.py
+++ b/tests/plugins/test_ddp_spawn_plugin.py
@@ -15,7 +15,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
-from tests.helpers.boring_model import BoringDataModule, BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index a5e4e1d189aaa..2aafb41d33ce4 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -15,10 +15,11 @@
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
-from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
from tests.helpers.datamodules import ClassifDataModule
+from tests.helpers.datasets import RandomIterableDataset
from tests.helpers.runif import RunIf
if _DEEPSPEED_AVAILABLE:
diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py
index 71595024c80af..22aba697e65a5 100644
--- a/tests/plugins/test_double_plugin.py
+++ b/tests/plugins/test_double_plugin.py
@@ -21,7 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DoublePrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
-from tests.helpers.boring_model import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 82e899a6f4aac..ab0767eb24d17 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -7,8 +7,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py
index 8d42bbb3e99ec..e2d53f62ca918 100644
--- a/tests/plugins/test_single_device_plugin.py
+++ b/tests/plugins/test_single_device_plugin.py
@@ -15,7 +15,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py
index 036e26a7c4a2f..c4191eca8ae52 100644
--- a/tests/plugins/test_tpu_spawn.py
+++ b/tests/plugins/test_tpu_spawn.py
@@ -22,8 +22,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 2145ab83e9cdb..5477456e49c39 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -27,9 +27,9 @@
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
-from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
diff --git a/tests/profiler/test_xla_profiler.py b/tests/profiler/test_xla_profiler.py
index 2afbf69a6d0b0..435851867d129 100644
--- a/tests/profiler/test_xla_profiler.py
+++ b/tests/profiler/test_xla_profiler.py
@@ -19,7 +19,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import XLAProfiler
from pytorch_lightning.utilities import _TPU_AVAILABLE
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
if _TPU_AVAILABLE:
diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 43158865f9e75..586c7ce503e9a 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -24,7 +24,7 @@
ProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_checkpoint_callbacks_are_last(tmpdir):
diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py
index 83a45f02224d5..f59345a4f5117 100644
--- a/tests/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/trainer/connectors/test_checkpoint_connector.py
@@ -17,7 +17,7 @@
import torch
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
class HPCHookdedModel(BoringModel):
diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py
index 15e817da975be..3c24a42b4e0c2 100644
--- a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py
+++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py
@@ -15,7 +15,7 @@
from torch.utils.data import Dataset
from pytorch_lightning import Trainer
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
class RandomDatasetA(Dataset):
diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py
index 97c6ddf7803ab..fc133f7452f3c 100644
--- a/tests/trainer/flags/test_check_val_every_n_epoch.py
+++ b/tests/trainer/flags/test_check_val_every_n_epoch.py
@@ -14,7 +14,7 @@
import pytest
from pytorch_lightning.trainer import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize(
diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
index cff0c8a43727d..5266f92a87832 100644
--- a/tests/trainer/flags/test_fast_dev_run.py
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -7,7 +7,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.base import DummyLogger
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
diff --git a/tests/trainer/flags/test_min_max_epochs.py b/tests/trainer/flags/test_min_max_epochs.py
index 059a447e10edb..e57aa3e134ad6 100644
--- a/tests/trainer/flags/test_min_max_epochs.py
+++ b/tests/trainer/flags/test_min_max_epochs.py
@@ -1,7 +1,7 @@
import pytest
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize(
diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py
index 798b9988469df..043174ddc7028 100644
--- a/tests/trainer/flags/test_overfit_batches.py
+++ b/tests/trainer/flags/test_overfit_batches.py
@@ -15,7 +15,7 @@
import torch
from pytorch_lightning import Trainer
-from tests.helpers.boring_model import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
def test_overfit_multiple_val_loaders(tmpdir):
diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py
index 8bc6ef774f4a1..97823438d8d52 100644
--- a/tests/trainer/flags/test_val_check_interval.py
+++ b/tests/trainer/flags/test_val_check_interval.py
@@ -14,7 +14,7 @@
import pytest
from pytorch_lightning.trainer import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
@pytest.mark.parametrize("max_epochs", [1, 2, 3])
diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py
index 727e95b894060..c3b825b2e4003 100644
--- a/tests/trainer/logging_/test_distributed_logging.py
+++ b/tests/trainer/logging_/test_distributed_logging.py
@@ -18,7 +18,7 @@
import pytorch_lightning as pl
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers.base import LightningLoggerBase
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 8579bc044734a..a996ae0578072 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -25,7 +25,7 @@
from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
-from tests.helpers import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
def test__validation_step__log(tmpdir):
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index ed7711b32ffda..ded4cf6662dae 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -24,8 +24,8 @@
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
from tests.models.test_hooks import get_members
diff --git a/tests/trainer/logging_/test_progress_bar_logging.py b/tests/trainer/logging_/test_progress_bar_logging.py
index d19251c02d37c..b8fce774f4ba0 100644
--- a/tests/trainer/logging_/test_progress_bar_logging.py
+++ b/tests/trainer/logging_/test_progress_bar_logging.py
@@ -1,7 +1,7 @@
import pytest
from pytorch_lightning import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_logging_to_progress_bar_with_reserved_key(tmpdir):
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index b218df9e3b15d..b34cd48e08168 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -26,8 +26,9 @@
from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel, RandomDictDataset
+from tests.helpers.datasets import RandomDictDataset
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/loops/test_all.py b/tests/trainer/loops/test_all.py
index 5975937018e16..9d229125086fc 100644
--- a/tests/trainer/loops/test_all.py
+++ b/tests/trainer/loops/test_all.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py
index d7acd7e65727e..92bf724e0f029 100644
--- a/tests/trainer/loops/test_evaluation_loop.py
+++ b/tests/trainer/loops/test_evaluation_loop.py
@@ -17,7 +17,7 @@
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
-from tests.helpers.boring_model import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/loops/test_flow_warnings.py b/tests/trainer/loops/test_flow_warnings.py
index e14bd8825510a..199e7e9c5a363 100644
--- a/tests/trainer/loops/test_flow_warnings.py
+++ b/tests/trainer/loops/test_flow_warnings.py
@@ -14,7 +14,7 @@
import warnings
from pytorch_lightning import Trainer
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
class TestModel(BoringModel):
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index 22258b8e52eea..0c8e148f00038 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -17,8 +17,8 @@
import torch
from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
def test_outputs_format(tmpdir):
diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py
index 4ee9d858d44c9..fcf12c7de0336 100644
--- a/tests/trainer/loops/test_training_loop_flow_scalar.py
+++ b/tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -19,7 +19,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
-from tests.helpers.boring_model import BoringModel, RandomDataset
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from tests.helpers.deterministic_model import DeterministicModel
from tests.helpers.utils import no_warning_call
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 670e8b4842a89..27904d6048a39 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -24,7 +24,7 @@
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import Callback
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
index fccb4e60657d9..1ff6a6242219b 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -18,7 +18,7 @@
import torch
import pytorch_lightning as pl
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
class MultiOptModel(BoringModel):
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 3c8a3d5ae8e68..e16d6b76fecad 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -19,8 +19,8 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/log_dir.py
index d940dabd99c09..d339885ee6ecd 100644
--- a/tests/trainer/properties/log_dir.py
+++ b/tests/trainer/properties/log_dir.py
@@ -16,7 +16,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
class TestModel(BoringModel):
diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py
index 9a0527d46330c..2411a82086e7e 100644
--- a/tests/trainer/properties/test_get_model.py
+++ b/tests/trainer/properties/test_get_model.py
@@ -13,7 +13,7 @@
# limitations under the License.
from pytorch_lightning import Trainer
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py
index daa01b5abe7b5..c0bdf910b440d 100644
--- a/tests/trainer/test_config_validator.py
+++ b/tests/trainer/test_config_validator.py
@@ -15,8 +15,8 @@
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
def test_wrong_train_setting(tmpdir):
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index e9d5d3cc047cb..cfaa872b0c8e1 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -19,9 +19,9 @@
from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
-from tests.helpers import BoringModel, RandomDataset
@pytest.mark.skipif(
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index e6686cf8117e0..1a97ee859a1ec 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -26,9 +26,10 @@
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
-from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen
+from tests.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
from tests.helpers.runif import RunIf
diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py
index 861885b4c052b..3d34a1cf25cb4 100644
--- a/tests/trainer/test_states.py
+++ b/tests/trainer/test_states.py
@@ -15,7 +15,7 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
def test_initialize_state():
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 0c4833c903a66..27c44a15d4ebb 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -41,10 +41,10 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from tests.base import EvalModelTemplate
-from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index de1873ee391d8..1b12e4586ca07 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -18,9 +18,9 @@
import torch
from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
-from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 3d4aa35a7da03..06b61107b2a5c 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -22,9 +22,9 @@
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel, RandomDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.datamodules import MNISTDataModule
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 62c543619ee4d..0e9591336ac39 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -6,7 +6,7 @@
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.utilities import AllGatherGrad
-from tests.helpers.boring_model import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index e665fc79e4323..76ea0d2f4872b 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -37,10 +37,10 @@
CaptureIterableDataset,
FastForwardSampler,
)
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index b72eabc14a814..f6b25df35ec7f 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -33,8 +33,8 @@
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
-from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
torchvision_version = version.parse("0")
@@ -365,8 +365,11 @@ def test_lightning_cli_save_config_cases(tmpdir):
def test_lightning_cli_config_and_subclass_mode(tmpdir):
config = dict(
- model=dict(class_path="tests.helpers.BoringModel"),
- data=dict(class_path="tests.helpers.BoringDataModule", init_args=dict(data_dir=str(tmpdir))),
+ model=dict(class_path="pytorch_lightning.utilities.debugging_examples.BoringModel"),
+ data=dict(
+ class_path="pytorch_lightning.utilities.debugging_examples.BoringDataModule",
+ init_args=dict(data_dir=str(tmpdir)),
+ ),
trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None),
)
config_path = tmpdir / "config.yaml"
@@ -414,7 +417,7 @@ def test_lightning_cli_help():
if param not in skip_params:
assert f"--trainer.{param}" in out
- cli_args = ["any.py", "--data.help=tests.helpers.BoringDataModule"]
+ cli_args = ["any.py", "--data.help=pytorch_lightning.utilities.debugging_examples.BoringDataModule"]
out = StringIO()
with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
any_model_any_data_cli()
@@ -426,8 +429,8 @@ def test_lightning_cli_print_config():
cli_args = [
"any.py",
"--seed_everything=1234",
- "--model=tests.helpers.BoringModel",
- "--data=tests.helpers.BoringDataModule",
+ "--model=pytorch_lightning.utilities.debugging_examples.BoringModel",
+ "--data=pytorch_lightning.utilities.debugging_examples.BoringDataModule",
"--print_config",
]
@@ -437,8 +440,8 @@ def test_lightning_cli_print_config():
outval = yaml.safe_load(out.getvalue())
assert outval["seed_everything"] == 1234
- assert outval["model"]["class_path"] == "tests.helpers.BoringModel"
- assert outval["data"]["class_path"] == "tests.helpers.BoringDataModule"
+ assert outval["model"]["class_path"] == "pytorch_lightning.utilities.debugging_examples.BoringModel"
+ assert outval["data"]["class_path"] == "pytorch_lightning.utilities.debugging_examples.BoringDataModule"
def test_lightning_cli_submodules(tmpdir):
@@ -451,9 +454,9 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
config = """model:
main_param: 2
submodule1:
- class_path: tests.helpers.BoringModel
+ class_path: pytorch_lightning.utilities.debugging_examples.BoringModel
submodule2:
- class_path: tests.helpers.BoringModel
+ class_path: pytorch_lightning.utilities.debugging_examples.BoringModel
"""
config_path = tmpdir / "config.yaml"
with open(config_path, "w") as f:
diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py
index 6c5b5a8e33ffb..a5cc421d860c8 100644
--- a/tests/utilities/test_data.py
+++ b/tests/utilities/test_data.py
@@ -3,7 +3,8 @@
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.utilities.data import extract_batch_size, get_len, has_iterable_dataset, has_len
-from tests.helpers.boring_model import RandomDataset, RandomIterableDataset
+from pytorch_lightning.utilities.debugging_examples import RandomDataset
+from tests.helpers.datasets import RandomIterableDataset
def test_extract_batch_size():
diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py
index 45c8f1a9a1d4f..0c911f2f1a5c9 100644
--- a/tests/utilities/test_deepspeed_collate_checkpoint.py
+++ b/tests/utilities/test_deepspeed_collate_checkpoint.py
@@ -17,8 +17,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py
index f209a310eea39..745e5f2547cbf 100644
--- a/tests/utilities/test_dtype_device_mixin.py
+++ b/tests/utilities/test_dtype_device_mixin.py
@@ -17,7 +17,7 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index b351165e03fd8..7cd6169119629 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -23,9 +23,9 @@
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
-from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_memory.py b/tests/utilities/test_memory.py
index b486157480877..82cde9e234b2b 100644
--- a/tests/utilities/test_memory.py
+++ b/tests/utilities/test_memory.py
@@ -16,8 +16,8 @@
import torch
import torch.nn as nn
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.memory import get_model_size_mb, recursive_detach
-from tests.helpers import BoringModel
def test_recursive_detach():
diff --git a/tests/utilities/test_model_helpers.py b/tests/utilities/test_model_helpers.py
index 1319e6b44fd8f..3e57fb32ac538 100644
--- a/tests/utilities/test_model_helpers.py
+++ b/tests/utilities/test_model_helpers.py
@@ -17,8 +17,8 @@
import pytest
from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel
from pytorch_lightning.utilities.model_helpers import is_overridden
-from tests.helpers import BoringDataModule, BoringModel
def test_is_overridden():
diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py
index 0d993bee18ff2..ee4b90c7cd642 100644
--- a/tests/utilities/test_model_summary.py
+++ b/tests/utilities/test_model_summary.py
@@ -17,9 +17,9 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_9
+from pytorch_lightning.utilities.debugging_examples import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize, UNKNOWN_SIZE
-from tests.helpers import BoringModel
from tests.helpers.advanced_models import ParityModuleRNN
from tests.helpers.runif import RunIf
diff --git a/tests/utilities/test_remote_filesystem.py b/tests/utilities/test_remote_filesystem.py
index 75173a69ae84c..1534e52f5620f 100644
--- a/tests/utilities/test_remote_filesystem.py
+++ b/tests/utilities/test_remote_filesystem.py
@@ -19,7 +19,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.debugging_examples import BoringModel
GCS_BUCKET_PATH = os.getenv("GCS_BUCKET_PATH", None)
_GCS_BUCKET_PATH_AVAILABLE = GCS_BUCKET_PATH is not None