diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 1064e603bee1f..3a4277a10fff5 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -159,6 +159,8 @@ jobs: - name: Examples run: | python -m pytest pl_examples -v --durations=10 + env: + PL_USE_MOCKED_MNIST: "1" - name: Upload pytest results uses: actions/upload-artifact@v2 diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index b86c2d4800d5e..6bc341744c0ba 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -20,7 +20,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.plugins import DDPSpawnShardedPlugin -from tests.helpers.boring_model import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 693514e0a3620..711af99f7a92d 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -14,30 +14,21 @@ import os import platform from typing import Optional -from urllib.error import HTTPError from warnings import warn from torch.utils.data import DataLoader, random_split from pl_examples import _DATASETS_PATH from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities.debugging_examples import MNIST as PL_MNIST from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE -if _TORCHVISION_AVAILABLE: - from torchvision import transforms as transform_lib - -_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False)) -if _TORCHVISION_MNIST_AVAILABLE: - try: - from torchvision.datasets import MNIST - - MNIST(_DATASETS_PATH, download=True) - except HTTPError as e: - print(f"Error {e} downloading `torchvision.datasets.MNIST`") - _TORCHVISION_MNIST_AVAILABLE = False -if not _TORCHVISION_MNIST_AVAILABLE: - print("`torchvision.datasets.MNIST` not available. Using our hosted version") - from tests.helpers.datasets import MNIST +# check whether we are in CI. Users running this should get the `torchvision` implementation +_USE_MOCKED_MNIST = bool(os.getenv("PL_USE_MOCKED_MNIST", False)) +if not _USE_MOCKED_MNIST and PL_MNIST._torchvision_available(_DATASETS_PATH, should_raise=False): + from torchvision.datasets import MNIST +else: + MNIST = PL_MNIST class MNISTDataModule(LightningDataModule): @@ -148,11 +139,12 @@ def test_dataloader(self): def default_transforms(self): if not _TORCHVISION_AVAILABLE: return None + from torchvision import transforms + if self.normalize: - mnist_transforms = transform_lib.Compose( - [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] + mnist_transforms = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))] ) else: - mnist_transforms = transform_lib.ToTensor() - + mnist_transforms = transforms.ToTensor() return mnist_transforms diff --git a/pytorch_lightning/utilities/debugging_examples.py b/pytorch_lightning/utilities/debugging_examples.py new file mode 100644 index 0000000000000..3f1364ab13e66 --- /dev/null +++ b/pytorch_lightning/utilities/debugging_examples.py @@ -0,0 +1,293 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Debugging helpers meant to be used ONLY in the ``pl_examples`` and ``tests`` directories. + +Production usage is discouraged. No backwards-compatibility guarantees. +""" +import logging +import os +import random +import time +import urllib.request +from typing import Optional, Tuple +from urllib.error import HTTPError + +import torch +from torch.utils.data import DataLoader, Dataset, Subset + +from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class RandomDataset(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + def __init__(self): + """ + Testing PL Module + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x["x"] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class BoringDataModule(LightningDataModule): + def __init__(self, data_dir: str = "./"): + super().__init__() + self.data_dir = data_dir + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + def prepare_data(self): + self.random_full = RandomDataset(32, 64 * 4) + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + self.random_train = Subset(self.random_full, indices=range(64)) + self.dims = self.random_train[0].shape + + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) + + if stage == "test" or stage is None: + self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) + self.dims = getattr(self, "dims", self.random_test[0].shape) + + if stage == "predict" or stage is None: + self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) + self.dims = getattr(self, "dims", self.random_predict[0].shape) + + def train_dataloader(self): + return DataLoader(self.random_train) + + def val_dataloader(self): + return DataLoader(self.random_val) + + def test_dataloader(self): + return DataLoader(self.random_test) + + def predict_dataloader(self): + return DataLoader(self.random_predict) + + +class MNIST(torch.utils.data.Dataset): + """ + Customized `MNIST `_ dataset for testing Pytorch Lightning + without the torchvision dependency. + + Part of the code was copied from + https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py + + Args: + root: Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + normalize: mean and std deviation of the MNIST dataset. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + RESOURCES = ( + "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt", + "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt", + ) + + TRAIN_FILE_NAME = "training.pt" + TEST_FILE_NAME = "test.pt" + cache_folder_name = "complete" + + def __init__( + self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs + ): + super().__init__() + + self.root = root + self.train = train # training set or test set + self.normalize = normalize + + _USE_MOCKED_MNIST = bool(os.getenv("PL_USE_MOCKED_MNIST", False)) + if not _USE_MOCKED_MNIST: + # avoid users accidentally importing this MNIST implementation + self._torchvision_available(self.cached_folder_path) + self.prepare_data(download) + + data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME + self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: + img = self.data[idx].float().unsqueeze(0) + target = int(self.targets[idx]) + + if self.normalize is not None and len(self.normalize) == 2: + img = self.normalize_tensor(img, *self.normalize) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + @property + def cached_folder_path(self) -> str: + return os.path.join(self.root, "MNIST", self.cache_folder_name) + + def _check_exists(self, data_folder: str) -> bool: + existing = True + for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): + existing = existing and os.path.isfile(os.path.join(data_folder, fname)) + return existing + + def prepare_data(self, download: bool = True): + if download and not self._check_exists(self.cached_folder_path): + self._download(self.cached_folder_path) + if not self._check_exists(self.cached_folder_path): + raise RuntimeError("Dataset not found.") + + def _download(self, data_folder: str) -> None: + os.makedirs(data_folder, exist_ok=True) + for url in self.RESOURCES: + logging.info(f"Downloading {url}") + fpath = os.path.join(data_folder, os.path.basename(url)) + urllib.request.urlretrieve(url, fpath) + + @staticmethod + def _try_load(path_data, trials: int = 30, delta: float = 1.0): + """Resolving loading from the same time from multiple concurrent processes.""" + res, exception = None, None + assert trials, "at least some trial has to be set" + assert os.path.isfile(path_data), f"missing file: {path_data}" + for _ in range(trials): + try: + res = torch.load(path_data) + # todo: specify the possible exception + except Exception as e: + exception = e + time.sleep(delta * random.random()) + else: + break + if exception is not None: + # raise the caught exception + raise exception + return res + + @staticmethod + def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor: + mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) + return tensor.sub(mean).div(std) + + @staticmethod + def _torchvision_available(root: str, should_raise: bool = True) -> bool: + # local import to facilitate mocking + from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE + + try: + from torchvision.datasets import MNIST + + MNIST(root, download=True) + download_successful = True + except HTTPError as e: + # allow using our implementation if torchvision hosting is down + print(f"Error {e} downloading `torchvision.datasets.MNIST`") + download_successful = False + _TORCHVISION_AVAILABLE &= download_successful + + if _TORCHVISION_AVAILABLE and should_raise: + raise MisconfigurationException( + "The `torchvision` package is available. This implementation is meant for internal use." + " Please import MNIST with `from torchvision.datasets import MNIST`" + " instead of `from pytorch_lightning.utilities.debug_examples import MNIST`" + ) + return _TORCHVISION_AVAILABLE diff --git a/tests/__init__.py b/tests/__init__.py index 9039a6e4b16e9..8e43f2fa13559 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -34,3 +34,6 @@ os.mkdir(_TEMP_PATH) logging.basicConfig(level=logging.ERROR) + +# Use our MNIST implementation on tests to avoid the torchvision dependency +os.environ.setdefault("PL_USE_MOCKED_MNIST", "1") diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index e27b873a63941..02dba54c61454 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -43,8 +43,8 @@ TorchElasticEnvironment, ) from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index e67fb166f815b..47dd4cb0c2861 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -17,8 +17,8 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.plugins import SingleDevicePlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.accelerators.test_dp import CustomClassificationModelDP -from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 7280afea02f76..1b82d718ae0fa 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -12,8 +12,8 @@ from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel def test_unsupported_precision_plugins(): diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index dc83e4ad4f02e..8faf0066d3ec5 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -23,8 +23,8 @@ import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.accelerators import ddp_model -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf from tests.utilities.distributed import call_training_script diff --git a/tests/accelerators/test_ddp_spawn.py b/tests/accelerators/test_ddp_spawn.py index a21078cf55542..3163aa19c9e12 100644 --- a/tests/accelerators/test_ddp_spawn.py +++ b/tests/accelerators/test_ddp_spawn.py @@ -16,7 +16,7 @@ from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities import memory -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index efaf761cb7116..21afb07295bdd 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -22,8 +22,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.utilities import memory +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index ba2de43da110e..55b86b91a2b1f 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -25,8 +25,8 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import _IPU_AVAILABLE +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index ea591e47041f8..5daccd3f1d367 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -25,7 +25,7 @@ from pytorch_lightning import LightningModule # noqa: E402 from pytorch_lightning import Trainer # noqa: E402 -from tests.helpers.boring_model import BoringModel # noqa: E402 +from pytorch_lightning.utilities.debugging_examples import BoringModel # noqa: E402 # TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml) diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 99ac579eb99b0..afb2294c3a7a8 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -23,8 +23,8 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import TPUSpawnPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf from tests.helpers.utils import pl_multi_process_test diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index 1a2ecae7b94c0..186196d2f36b1 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -14,7 +14,7 @@ import pytest from pytorch_lightning import Callback, Trainer -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize("single_cb", [False, True]) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index d190feed7e1f7..e086dcc2340f2 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -15,7 +15,7 @@ from unittest.mock import call, Mock from pytorch_lightning import Callback, Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_callbacks_configured_in_model(tmpdir): diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 4c3b990dd1b13..98c9f291e49fe 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -24,8 +24,8 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 9a28c7a8fc478..099e78050dbeb 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -22,7 +22,7 @@ from pytorch_lightning import LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint from pytorch_lightning.callbacks.base import Callback -from tests.helpers import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset class TestBackboneFinetuningCallback(BackboneFinetuning): diff --git a/tests/callbacks/test_gpu_stats_monitor.py b/tests/callbacks/test_gpu_stats_monitor.py index eaba4d30684f3..cb0c8c347707f 100644 --- a/tests/callbacks/test_gpu_stats_monitor.py +++ b/tests/callbacks/test_gpu_stats_monitor.py @@ -22,8 +22,8 @@ from pytorch_lightning.callbacks import GPUStatsMonitor from pytorch_lightning.loggers import CSVLogger from pytorch_lightning.loggers.csv_logs import ExperimentWriter +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 82f64d676c774..5d2ced197e359 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -16,7 +16,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_lambda_call(tmpdir): diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index d742781599d77..2d5bc23c79ac3 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -20,8 +20,8 @@ from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.finetuning import BackboneFinetuning +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.simple_models import ClassificationModel diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 75e0dbd31ec79..ae7b110e4d695 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -16,8 +16,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import BasePredictionWriter +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel def test_prediction_writer(tmpdir): diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 1c3176f39a886..58bc234774819 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -25,8 +25,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.progress import tqdm +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index fe6e14d1084d9..9793bdc259059 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -24,8 +24,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index c6f44759ba371..10ff0f7cd9ea6 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -17,7 +17,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 0bfaa359bb1a8..8fbfbfd15cd14 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -25,8 +25,9 @@ from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset +from tests.helpers.datasets import RandomIterableDataset from tests.helpers.runif import RunIf diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index 44d3c305bb1ac..c6a69a7c15ca6 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -21,8 +21,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.timer import Timer +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/callbacks/test_xla_stats_monitor.py b/tests/callbacks/test_xla_stats_monitor.py index bf4dde3983921..7c7a070569a90 100644 --- a/tests/callbacks/test_xla_stats_monitor.py +++ b/tests/callbacks/test_xla_stats_monitor.py @@ -20,8 +20,8 @@ from pytorch_lightning.callbacks import XLAStatsMonitor from pytorch_lightning.loggers import CSVLogger from pytorch_lightning.loggers.csv_logs import ExperimentWriter +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 2e7fcfb8c2253..553bd7a5bd5fc 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -18,7 +18,7 @@ import torch from pytorch_lightning import callbacks, seed_everything, Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 314ed899c588a..8ce81ca8bda20 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -38,8 +38,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index e95ce1c91d6f4..2910cf60a0abd 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -16,7 +16,7 @@ import torch from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index f76e76b2f9dd9..9213b39008d71 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -19,7 +19,7 @@ import pytorch_lightning as pl from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_finetuning_with_resume_from_checkpoint(tmpdir): diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 3bfe3aaa6cf80..b13bc95ee9945 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -23,9 +23,9 @@ from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from tests.helpers import BoringDataModule, BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/core/test_decorators.py b/tests/core/test_decorators.py index c26836e6a5c90..5f00c06faca43 100644 --- a/tests/core/test_decorators.py +++ b/tests/core/test_decorators.py @@ -15,7 +15,7 @@ import torch from pytorch_lightning.core.decorators import auto_move_data -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 8b787e0f57fcb..0721b3c7b56fe 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -22,7 +22,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index f8c3317b6c595..0a6660ed7bfe7 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -20,7 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.optimizer import LightningOptimizer -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_lightning_optimizer(tmpdir): diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 7c6a985d09f33..c288bd4cdc1e9 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -28,8 +28,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7 -from tests.helpers import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 40dfc069ac449..f70d9719ba232 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -19,8 +19,8 @@ from pytorch_lightning.core.decorators import auto_move_data from pytorch_lightning.plugins import DeepSpeedPlugin from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel from tests.deprecated_api import no_deprecated_call -from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf from tests.helpers.utils import no_warning_call diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 98c5c4e0320ea..3814800fb4b9e 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -18,11 +18,11 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary from tests.deprecated_api import _soft_unimport_module -from tests.helpers import BoringDataModule, BoringModel def test_v1_6_0_trainer_model_hook_mixin(tmpdir): diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 7581bf2b0c142..84ffde97ec1c3 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,8 +16,8 @@ import pytest from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.deprecated_api import _soft_unimport_module -from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py deleted file mode 100644 index e6fa5cfa70795..0000000000000 --- a/tests/helpers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset # noqa: F401 -from tests.helpers.datasets import TrialMNIST # noqa: F401 diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py deleted file mode 100644 index aeaf85ecf254c..0000000000000 --- a/tests/helpers/boring_model.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import torch -from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset - -from pytorch_lightning import LightningDataModule, LightningModule - - -class RandomDictDataset(Dataset): - def __init__(self, size: int, length: int): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - a = self.data[index] - b = a + 2 - return {"a": a, "b": b} - - def __len__(self): - return self.len - - -class RandomDataset(Dataset): - def __init__(self, size: int, length: int): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return self.data[index] - - def __len__(self): - return self.len - - -class RandomIterableDataset(IterableDataset): - def __init__(self, size: int, count: int): - self.count = count - self.size = size - - def __iter__(self): - for _ in range(self.count): - yield torch.randn(self.size) - - -class RandomIterableDatasetWithLen(IterableDataset): - def __init__(self, size: int, count: int): - self.count = count - self.size = size - - def __iter__(self): - for _ in range(len(self)): - yield torch.randn(self.size) - - def __len__(self): - return self.count - - -class BoringModel(LightningModule): - def __init__(self): - """ - Testing PL Module - - Use as follows: - - subclass - - modify the behavior for what you want - - class TestModel(BaseTestModel): - def training_step(...): - # do your own thing - - or: - - model = BaseTestModel() - model.training_epoch_end = None - - """ - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - return self.layer(x) - - def loss(self, batch, prediction): - # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls - return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) - - def step(self, x): - x = self(x) - out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) - return out - - def training_step(self, batch, batch_idx): - output = self(batch) - loss = self.loss(batch, output) - return {"loss": loss} - - def training_step_end(self, training_step_outputs): - return training_step_outputs - - def training_epoch_end(self, outputs) -> None: - torch.stack([x["loss"] for x in outputs]).mean() - - def validation_step(self, batch, batch_idx): - output = self(batch) - loss = self.loss(batch, output) - return {"x": loss} - - def validation_epoch_end(self, outputs) -> None: - torch.stack([x["x"] for x in outputs]).mean() - - def test_step(self, batch, batch_idx): - output = self(batch) - loss = self.loss(batch, output) - return {"y": loss} - - def test_epoch_end(self, outputs) -> None: - torch.stack([x["y"] for x in outputs]).mean() - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] - - def train_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def val_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def test_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def predict_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - -class BoringDataModule(LightningDataModule): - def __init__(self, data_dir: str = "./"): - super().__init__() - self.data_dir = data_dir - self.non_picklable = None - self.checkpoint_state: Optional[str] = None - - def prepare_data(self): - self.random_full = RandomDataset(32, 64 * 4) - - def setup(self, stage: Optional[str] = None): - if stage == "fit" or stage is None: - self.random_train = Subset(self.random_full, indices=range(64)) - self.dims = self.random_train[0].shape - - if stage in ("fit", "validate") or stage is None: - self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) - - if stage == "test" or stage is None: - self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) - self.dims = getattr(self, "dims", self.random_test[0].shape) - - if stage == "predict" or stage is None: - self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) - self.dims = getattr(self, "dims", self.random_predict[0].shape) - - def train_dataloader(self): - return DataLoader(self.random_train) - - def val_dataloader(self): - return DataLoader(self.random_val) - - def test_dataloader(self): - return DataLoader(self.random_test) - - def predict_dataloader(self): - return DataLoader(self.random_predict) diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 10a4c6b6e7ca7..1dfea134de3b1 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -11,126 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import os -import random -import time -import urllib.request -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import torch -from torch import Tensor -from torch.utils.data import Dataset +from torch.utils.data import Dataset, IterableDataset - -class MNIST(Dataset): - """ - Customized `MNIST `_ dataset for testing Pytorch Lightning - without the torchvision dependency. - - Part of the code was copied from - https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py - - Args: - root: Root directory of dataset where ``MNIST/processed/training.pt`` - and ``MNIST/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - normalize: mean and std deviation of the MNIST dataset. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - - Examples: - >>> dataset = MNIST(".", download=True) - >>> len(dataset) - 60000 - >>> torch.bincount(dataset.targets) - tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]) - """ - - RESOURCES = ( - "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt", - "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt", - ) - - TRAIN_FILE_NAME = "training.pt" - TEST_FILE_NAME = "test.pt" - cache_folder_name = "complete" - - def __init__( - self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs - ): - super().__init__() - self.root = root - self.train = train # training set or test set - self.normalize = normalize - - self.prepare_data(download) - - data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME - self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: - img = self.data[idx].float().unsqueeze(0) - target = int(self.targets[idx]) - - if self.normalize is not None and len(self.normalize) == 2: - img = self.normalize_tensor(img, *self.normalize) - - return img, target - - def __len__(self) -> int: - return len(self.data) - - @property - def cached_folder_path(self) -> str: - return os.path.join(self.root, "MNIST", self.cache_folder_name) - - def _check_exists(self, data_folder: str) -> bool: - existing = True - for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): - existing = existing and os.path.isfile(os.path.join(data_folder, fname)) - return existing - - def prepare_data(self, download: bool = True): - if download and not self._check_exists(self.cached_folder_path): - self._download(self.cached_folder_path) - if not self._check_exists(self.cached_folder_path): - raise RuntimeError("Dataset not found.") - - def _download(self, data_folder: str) -> None: - os.makedirs(data_folder, exist_ok=True) - for url in self.RESOURCES: - logging.info(f"Downloading {url}") - fpath = os.path.join(data_folder, os.path.basename(url)) - urllib.request.urlretrieve(url, fpath) - - @staticmethod - def _try_load(path_data, trials: int = 30, delta: float = 1.0): - """Resolving loading from the same time from multiple concurrent processes.""" - res, exception = None, None - assert trials, "at least some trial has to be set" - assert os.path.isfile(path_data), f"missing file: {path_data}" - for _ in range(trials): - try: - res = torch.load(path_data) - # todo: specify the possible exception - except Exception as e: - exception = e - time.sleep(delta * random.random()) - else: - break - if exception is not None: - # raise the caught exception - raise exception - return res - - @staticmethod - def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: - mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) - std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) - return tensor.sub(mean).div(std) +from pytorch_lightning.utilities.debugging_examples import MNIST class TrialMNIST(MNIST): @@ -214,3 +101,40 @@ def __getitem__(self, idx): def __len__(self): return len(self.y) + + +class RandomDictDataset(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + a = self.data[index] + b = a + 2 + return {"a": a, "b": b} + + def __len__(self): + return self.len + + +class RandomIterableDataset(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self): + return self.count diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 3e5066d708da0..02983c2686a92 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -16,7 +16,7 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities import DistributedType -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed diff --git a/tests/helpers/test_models.py b/tests/helpers/test_models.py index b6d853f2ac594..bc4cf28d8b035 100644 --- a/tests/helpers/test_models.py +++ b/tests/helpers/test_models.py @@ -16,8 +16,8 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN -from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule, RegressDataModule from tests.helpers.simple_models import ClassificationModel, RegressionModel diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index eb1d53ccc62e3..1a29f2f8e9ad5 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -32,7 +32,7 @@ WandbLogger, ) from pytorch_lightning.loggers.base import DummyExperiment -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf from tests.loggers.test_comet import _patch_comet_atexit from tests.loggers.test_mlflow import mock_mlflow_run_creation diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 7c8673c34956f..c884e14818274 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -23,7 +23,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger from pytorch_lightning.utilities import rank_zero_only -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_logger_collection(): diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index cb3641e669d50..6f10ae5d4472f 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -18,8 +18,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel def _patch_comet_atexit(monkeypatch): diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 8523edf69a980..7069428128b50 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -20,7 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None): diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index c58bb4ef59fbe..a60a56149c9c7 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -17,7 +17,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import NeptuneLogger -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @patch("pytorch_lightning.loggers.neptune.neptune") diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index a1c66c0559d75..47d3a221d0c2b 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -25,8 +25,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.imports import _compare_version -from tests.helpers import BoringModel @pytest.mark.skipif( diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 0684727e84ac8..d8411b7d201fe 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -20,8 +20,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel @mock.patch("pytorch_lightning.loggers.wandb.wandb") diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py index 2cd6a172f6941..4beeb7ad19557 100644 --- a/tests/loops/test_iterator_batch_processor.py +++ b/tests/loops/test_iterator_batch_processor.py @@ -17,9 +17,9 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers import BoringModel, RandomDataset _BATCH_SIZE = 32 _DATASET_LEN = 64 diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 65cbebc8203e5..e62c1f070c04c 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -24,7 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loops import Loop, TrainingBatchLoop from pytorch_lightning.trainer.progress import BaseProgress -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index f851a7de0837e..4fde65d529d83 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -37,7 +37,7 @@ else: print("You requested to import Horovod which is missing or not supported for your OS.") -from tests.helpers import BoringModel # noqa: E402 +from pytorch_lightning.utilities.debugging_examples import BoringModel # noqa: E402 from tests.helpers.utils import reset_seed, set_random_master_port # noqa: E402 parser = argparse.ArgumentParser() diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 79c0cf7c12f15..60725d6981297 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -23,8 +23,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 015c79458e1aa..59c595b259884 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -19,7 +19,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 1d23ed2f76907..23c892c8301aa 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -25,9 +25,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import TorchElasticEnvironment from pytorch_lightning.utilities import device_parser +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version -from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField from tests.helpers.runif import RunIf diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index fb0c0d4a8da76..3e6b61c83b436 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -17,7 +17,7 @@ import pytest from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.utils import reset_seed diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 990178b09d07f..fbfaeb016762e 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer -from tests.helpers import BoringDataModule, BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index c3f3cdcf7ffc2..3494aa6e3e418 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -30,7 +30,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.advanced_models import BasicGAN from tests.helpers.runif import RunIf diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 7e37fab22cd3f..3d70ef3956d48 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -31,8 +31,8 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset if _HYDRA_EXPERIMENTAL_AVAILABLE: from hydra.experimental import compose, initialize diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index cec01e828d1ed..10f7598521dd3 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -21,7 +21,7 @@ import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index d1d870fa30116..6662da51ad88b 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -28,7 +28,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.states import RunningStage -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 5748f6d9a1095..912124c2ad80b 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -20,7 +20,7 @@ from fsspec.implementations.local import LocalFileSystem from pytorch_lightning.utilities.cloud_io import get_filesystem -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN from tests.helpers.datamodules import MNISTDataModule from tests.helpers.runif import RunIf diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 5aa605cdf38bb..340d58b4020ea 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -27,9 +27,9 @@ from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf from tests.helpers.utils import pl_multi_process_test diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py index ab10a527e3cda..ac7eda8db98c0 100644 --- a/tests/models/test_truncated_bptt.py +++ b/tests/models/test_truncated_bptt.py @@ -16,7 +16,7 @@ import torch from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize("n_hidden_states", (1, 2)) diff --git a/tests/overrides/test_base.py b/tests/overrides/test_base.py index 4b76fd028af66..90dba69d5ab17 100644 --- a/tests/overrides/test_base.py +++ b/tests/overrides/test_base.py @@ -20,7 +20,7 @@ _LightningPrecisionModuleWrapperBase, unwrap_lightning_module, ) -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase]) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index c481fdbef0f33..5dec4319ba366 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -27,7 +27,7 @@ unsqueeze_scalar_tensor, ) from pytorch_lightning.trainer.states import RunningStage -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/environments/torch_elastic_deadlock.py b/tests/plugins/environments/torch_elastic_deadlock.py index ac2348285d9af..40d04932c4587 100644 --- a/tests/plugins/environments/torch_elastic_deadlock.py +++ b/tests/plugins/environments/torch_elastic_deadlock.py @@ -4,8 +4,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import DeadlockDetectedException -from tests.helpers.boring_model import BoringModel if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1": diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 15ec43973b0ed..2e237f70ab155 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -21,8 +21,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 810127a03f361..f405d4039bd4b 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -20,9 +20,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins import CheckpointIO, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PATH -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 939c05d1b7afe..b84aded417511 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin, SingleDevicePlugin -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 3dd4897864947..4b8f9639fed5d 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -9,8 +9,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins import DDPFullyShardedPlugin, FullyShardedNativeMixedPrecisionPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 60ec1930852d0..9dae8d432f4d7 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -21,7 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.plugins.environments import LightningEnvironment -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 49c4cd18ef316..7f11d225e1cbc 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -16,7 +16,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8: diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 1ab94446c8176..495e8cc1077cd 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -15,7 +15,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPSpawnPlugin -from tests.helpers.boring_model import BoringDataModule, BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index a5e4e1d189aaa..2aafb41d33ce4 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -15,10 +15,11 @@ from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE -from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.datasets import RandomIterableDataset from tests.helpers.runif import RunIf if _DEEPSPEED_AVAILABLE: diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index 71595024c80af..22aba697e65a5 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -21,7 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DoublePrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 -from tests.helpers.boring_model import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 82e899a6f4aac..ab0767eb24d17 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -7,8 +7,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index 8d42bbb3e99ec..e2d53f62ca918 100644 --- a/tests/plugins/test_single_device_plugin.py +++ b/tests/plugins/test_single_device_plugin.py @@ -15,7 +15,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import SingleDevicePlugin -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 036e26a7c4a2f..c4191eca8ae52 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -22,8 +22,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.training_type import TPUSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests.helpers.runif import RunIf from tests.helpers.utils import pl_multi_process_test diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 2145ab83e9cdb..5477456e49c39 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -27,9 +27,9 @@ from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE -from tests.helpers import BoringModel from tests.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 diff --git a/tests/profiler/test_xla_profiler.py b/tests/profiler/test_xla_profiler.py index 2afbf69a6d0b0..435851867d129 100644 --- a/tests/profiler/test_xla_profiler.py +++ b/tests/profiler/test_xla_profiler.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.profiler import XLAProfiler from pytorch_lightning.utilities import _TPU_AVAILABLE -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf if _TPU_AVAILABLE: diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 43158865f9e75..586c7ce503e9a 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -24,7 +24,7 @@ ProgressBar, ) from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_checkpoint_callbacks_are_last(tmpdir): diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 83a45f02224d5..f59345a4f5117 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -17,7 +17,7 @@ import torch from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel class HPCHookdedModel(BoringModel): diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py index 15e817da975be..3c24a42b4e0c2 100644 --- a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -15,7 +15,7 @@ from torch.utils.data import Dataset from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel class RandomDatasetA(Dataset): diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index 97c6ddf7803ab..fc133f7452f3c 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -14,7 +14,7 @@ import pytest from pytorch_lightning.trainer import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize( diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index cff0c8a43727d..5266f92a87832 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -7,7 +7,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers.base import DummyLogger -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"]) diff --git a/tests/trainer/flags/test_min_max_epochs.py b/tests/trainer/flags/test_min_max_epochs.py index 059a447e10edb..e57aa3e134ad6 100644 --- a/tests/trainer/flags/test_min_max_epochs.py +++ b/tests/trainer/flags/test_min_max_epochs.py @@ -1,7 +1,7 @@ import pytest from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize( diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py index 798b9988469df..043174ddc7028 100644 --- a/tests/trainer/flags/test_overfit_batches.py +++ b/tests/trainer/flags/test_overfit_batches.py @@ -15,7 +15,7 @@ import torch from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset def test_overfit_multiple_val_loaders(tmpdir): diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 8bc6ef774f4a1..97823438d8d52 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -14,7 +14,7 @@ import pytest from pytorch_lightning.trainer import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel @pytest.mark.parametrize("max_epochs", [1, 2, 3]) diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py index 727e95b894060..c3b825b2e4003 100644 --- a/tests/trainer/logging_/test_distributed_logging.py +++ b/tests/trainer/logging_/test_distributed_logging.py @@ -18,7 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers.base import LightningLoggerBase -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 8579bc044734a..a996ae0578072 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -25,7 +25,7 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.loggers import TensorBoardLogger -from tests.helpers import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset def test__validation_step__log(tmpdir): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index ed7711b32ffda..ded4cf6662dae 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -24,8 +24,8 @@ from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf from tests.models.test_hooks import get_members diff --git a/tests/trainer/logging_/test_progress_bar_logging.py b/tests/trainer/logging_/test_progress_bar_logging.py index d19251c02d37c..b8fce774f4ba0 100644 --- a/tests/trainer/logging_/test_progress_bar_logging.py +++ b/tests/trainer/logging_/test_progress_bar_logging.py @@ -1,7 +1,7 @@ import pytest from pytorch_lightning import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_logging_to_progress_bar_with_reserved_key(tmpdir): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index b218df9e3b15d..b34cd48e08168 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -26,8 +26,9 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel, RandomDictDataset +from tests.helpers.datasets import RandomDictDataset from tests.helpers.runif import RunIf diff --git a/tests/trainer/loops/test_all.py b/tests/trainer/loops/test_all.py index 5975937018e16..9d229125086fc 100644 --- a/tests/trainer/loops/test_all.py +++ b/tests/trainer/loops/test_all.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning import Callback, Trainer -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index d7acd7e65727e..92bf724e0f029 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/tests/trainer/loops/test_flow_warnings.py b/tests/trainer/loops/test_flow_warnings.py index e14bd8825510a..199e7e9c5a363 100644 --- a/tests/trainer/loops/test_flow_warnings.py +++ b/tests/trainer/loops/test_flow_warnings.py @@ -14,7 +14,7 @@ import warnings from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel class TestModel(BoringModel): diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index 22258b8e52eea..0c8e148f00038 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -17,8 +17,8 @@ import torch from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel def test_outputs_format(tmpdir): diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 4ee9d858d44c9..fcf12c7de0336 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import RunningStage -from tests.helpers.boring_model import BoringModel, RandomDataset +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel from tests.helpers.utils import no_warning_call diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 670e8b4842a89..27904d6048a39 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -24,7 +24,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index fccb4e60657d9..1ff6a6242219b 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -18,7 +18,7 @@ import torch import pytorch_lightning as pl -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel class MultiOptModel(BoringModel): diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 3c8a3d5ae8e68..e16d6b76fecad 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -19,8 +19,8 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/log_dir.py index d940dabd99c09..d339885ee6ecd 100644 --- a/tests/trainer/properties/log_dir.py +++ b/tests/trainer/properties/log_dir.py @@ -16,7 +16,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel class TestModel(BoringModel): diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 9a0527d46330c..2411a82086e7e 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -13,7 +13,7 @@ # limitations under the License. from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index daa01b5abe7b5..c0bdf910b440d 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -15,8 +15,8 @@ import torch from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset def test_wrong_train_setting(tmpdir): diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index e9d5d3cc047cb..cfaa872b0c8e1 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -19,9 +19,9 @@ from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler from pytorch_lightning import Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 -from tests.helpers import BoringModel, RandomDataset @pytest.mark.skipif( diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index e6686cf8117e0..1a97ee859a1ec 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,9 +26,10 @@ from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.data import has_iterable_dataset, has_len +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen +from tests.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen from tests.helpers.runif import RunIf diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 861885b4c052b..3d34a1cf25cb4 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -15,7 +15,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel def test_initialize_state(): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0c4833c903a66..27c44a15d4ebb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -41,10 +41,10 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.debugging_examples import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything from tests.base import EvalModelTemplate -from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index de1873ee391d8..1b12e4586ca07 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -18,9 +18,9 @@ import torch from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.simple_models import ClassificationModel diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index 3d4aa35a7da03..06b61107b2a5c 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -22,9 +22,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.datamodules import MNISTDataModule from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 62c543619ee4d..0e9591336ac39 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -6,7 +6,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities import AllGatherGrad -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e665fc79e4323..76ea0d2f4872b 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -37,10 +37,10 @@ CaptureIterableDataset, FastForwardSampler, ) +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index b72eabc14a814..f6b25df35ec7f 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -33,8 +33,8 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE -from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf torchvision_version = version.parse("0") @@ -365,8 +365,11 @@ def test_lightning_cli_save_config_cases(tmpdir): def test_lightning_cli_config_and_subclass_mode(tmpdir): config = dict( - model=dict(class_path="tests.helpers.BoringModel"), - data=dict(class_path="tests.helpers.BoringDataModule", init_args=dict(data_dir=str(tmpdir))), + model=dict(class_path="pytorch_lightning.utilities.debugging_examples.BoringModel"), + data=dict( + class_path="pytorch_lightning.utilities.debugging_examples.BoringDataModule", + init_args=dict(data_dir=str(tmpdir)), + ), trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None), ) config_path = tmpdir / "config.yaml" @@ -414,7 +417,7 @@ def test_lightning_cli_help(): if param not in skip_params: assert f"--trainer.{param}" in out - cli_args = ["any.py", "--data.help=tests.helpers.BoringDataModule"] + cli_args = ["any.py", "--data.help=pytorch_lightning.utilities.debugging_examples.BoringDataModule"] out = StringIO() with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() @@ -426,8 +429,8 @@ def test_lightning_cli_print_config(): cli_args = [ "any.py", "--seed_everything=1234", - "--model=tests.helpers.BoringModel", - "--data=tests.helpers.BoringDataModule", + "--model=pytorch_lightning.utilities.debugging_examples.BoringModel", + "--data=pytorch_lightning.utilities.debugging_examples.BoringDataModule", "--print_config", ] @@ -437,8 +440,8 @@ def test_lightning_cli_print_config(): outval = yaml.safe_load(out.getvalue()) assert outval["seed_everything"] == 1234 - assert outval["model"]["class_path"] == "tests.helpers.BoringModel" - assert outval["data"]["class_path"] == "tests.helpers.BoringDataModule" + assert outval["model"]["class_path"] == "pytorch_lightning.utilities.debugging_examples.BoringModel" + assert outval["data"]["class_path"] == "pytorch_lightning.utilities.debugging_examples.BoringDataModule" def test_lightning_cli_submodules(tmpdir): @@ -451,9 +454,9 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai config = """model: main_param: 2 submodule1: - class_path: tests.helpers.BoringModel + class_path: pytorch_lightning.utilities.debugging_examples.BoringModel submodule2: - class_path: tests.helpers.BoringModel + class_path: pytorch_lightning.utilities.debugging_examples.BoringModel """ config_path = tmpdir / "config.yaml" with open(config_path, "w") as f: diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 6c5b5a8e33ffb..a5cc421d860c8 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -3,7 +3,8 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning.utilities.data import extract_batch_size, get_len, has_iterable_dataset, has_len -from tests.helpers.boring_model import RandomDataset, RandomIterableDataset +from pytorch_lightning.utilities.debugging_examples import RandomDataset +from tests.helpers.datasets import RandomIterableDataset def test_extract_batch_size(): diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index 45c8f1a9a1d4f..0c911f2f1a5c9 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -17,8 +17,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DeepSpeedPlugin +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index f209a310eea39..745e5f2547cbf 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -17,7 +17,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index b351165e03fd8..7cd6169119629 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -23,9 +23,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_memory.py b/tests/utilities/test_memory.py index b486157480877..82cde9e234b2b 100644 --- a/tests/utilities/test_memory.py +++ b/tests/utilities/test_memory.py @@ -16,8 +16,8 @@ import torch import torch.nn as nn +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.memory import get_model_size_mb, recursive_detach -from tests.helpers import BoringModel def test_recursive_detach(): diff --git a/tests/utilities/test_model_helpers.py b/tests/utilities/test_model_helpers.py index 1319e6b44fd8f..3e57fb32ac538 100644 --- a/tests/utilities/test_model_helpers.py +++ b/tests/utilities/test_model_helpers.py @@ -17,8 +17,8 @@ import pytest from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning.utilities.debugging_examples import BoringDataModule, BoringModel from pytorch_lightning.utilities.model_helpers import is_overridden -from tests.helpers import BoringDataModule, BoringModel def test_is_overridden(): diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py index 0d993bee18ff2..ee4b90c7cd642 100644 --- a/tests/utilities/test_model_summary.py +++ b/tests/utilities/test_model_summary.py @@ -17,9 +17,9 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_9 +from pytorch_lightning.utilities.debugging_examples import BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_summary import ModelSummary, summarize, UNKNOWN_SIZE -from tests.helpers import BoringModel from tests.helpers.advanced_models import ParityModuleRNN from tests.helpers.runif import RunIf diff --git a/tests/utilities/test_remote_filesystem.py b/tests/utilities/test_remote_filesystem.py index 75173a69ae84c..1534e52f5620f 100644 --- a/tests/utilities/test_remote_filesystem.py +++ b/tests/utilities/test_remote_filesystem.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from tests.helpers import BoringModel +from pytorch_lightning.utilities.debugging_examples import BoringModel GCS_BUCKET_PATH = os.getenv("GCS_BUCKET_PATH", None) _GCS_BUCKET_PATH_AVAILABLE = GCS_BUCKET_PATH is not None