diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index e9374f8ea4be1..f1ccf2a2726a2 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -21,8 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from tests_pytorch.helpers.datasets import RandomIterableDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 859cf2fa98c0c..ae70b3d205a20 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -22,10 +22,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import StochasticWeightAveraging -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.strategies import DDPSpawnStrategy, Strategy from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 3443020d4528f..c9d185313e85e 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -19,7 +19,7 @@ from typing import Optional, Sequence, Tuple import torch -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import Dataset class MNIST(Dataset): @@ -212,40 +212,3 @@ def __getitem__(self, idx): def __len__(self): return len(self.y) - - -class RandomDictDataset(Dataset): - def __init__(self, size: int, length: int): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - a = self.data[index] - b = a + 2 - return {"a": a, "b": b} - - def __len__(self): - return self.len - - -class RandomIterableDataset(IterableDataset): - def __init__(self, size: int, count: int): - self.count = count - self.size = size - - def __iter__(self): - for _ in range(self.count): - yield torch.randn(self.size) - - -class RandomIterableDatasetWithLen(IterableDataset): - def __init__(self, size: int, count: int): - self.count = count - self.size = size - - def __iter__(self): - for _ in range(len(self)): - yield torch.randn(self.size) - - def __len__(self): - return self.count diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 272b03a846688..e3c6f95f3ff47 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -28,13 +28,12 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.datamodules import ClassifDataModule -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf if _DEEPSPEED_AVAILABLE: diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index 9414fd1c5096f..e5fd9b5dd2706 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -16,10 +16,9 @@ import pytest from torch.utils.data import DataLoader -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset @pytest.mark.parametrize("max_epochs", [1, 2, 3]) diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index d16be306b9365..85ed3d8e3471d 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -28,9 +28,8 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomDictDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 92a1126294dfc..846a39a748a60 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -22,11 +22,10 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler -from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.demos.boring_classes import BoringModel, RandomIterableDataset from pytorch_lightning.strategies.ipu import IPUStrategy from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 5bea5a4cbbe1c..34504392dc0c1 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -25,12 +25,16 @@ from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import ( + BoringModel, + RandomDataset, + RandomIterableDataset, + RandomIterableDatasetWithLen, +) from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader -from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index e4be8929f9c7e..9506acee425d0 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -41,7 +41,12 @@ from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import ( + BoringModel, + RandomDataset, + RandomIterableDataset, + RandomIterableDatasetWithLen, +) from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.strategies import ( @@ -60,7 +65,6 @@ from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.seed import seed_everything from tests_pytorch.helpers.datamodules import ClassifDataModule -from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index ffb898efaa815..3700feaba9992 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -6,7 +6,7 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from pytorch_lightning import Trainer -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( @@ -23,7 +23,6 @@ warning_cache, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.utils import no_warning_call