Skip to content

Commit 9b61b1c

Browse files
authored
Remove duplicated test classes (#14122)
Remove duplicated classes
1 parent 4e87a44 commit 9b61b1c

File tree

10 files changed

+20
-56
lines changed

10 files changed

+20
-56
lines changed

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
from pytorch_lightning import Trainer
2222
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
2323
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
24-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
25-
from tests_pytorch.helpers.datasets import RandomIterableDataset
24+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
2625
from tests_pytorch.helpers.runif import RunIf
2726

2827

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@
2626

2727
from pytorch_lightning import Trainer
2828
from pytorch_lightning.callbacks import StochasticWeightAveraging
29-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
29+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
3030
from pytorch_lightning.strategies import DDPSpawnStrategy, Strategy
3131
from pytorch_lightning.utilities.exceptions import MisconfigurationException
32-
from tests_pytorch.helpers.datasets import RandomIterableDataset
3332
from tests_pytorch.helpers.runif import RunIf
3433

3534

tests/tests_pytorch/helpers/datasets.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Optional, Sequence, Tuple
2020

2121
import torch
22-
from torch.utils.data import Dataset, IterableDataset
22+
from torch.utils.data import Dataset
2323

2424

2525
class MNIST(Dataset):
@@ -212,40 +212,3 @@ def __getitem__(self, idx):
212212

213213
def __len__(self):
214214
return len(self.y)
215-
216-
217-
class RandomDictDataset(Dataset):
218-
def __init__(self, size: int, length: int):
219-
self.len = length
220-
self.data = torch.randn(length, size)
221-
222-
def __getitem__(self, index):
223-
a = self.data[index]
224-
b = a + 2
225-
return {"a": a, "b": b}
226-
227-
def __len__(self):
228-
return self.len
229-
230-
231-
class RandomIterableDataset(IterableDataset):
232-
def __init__(self, size: int, count: int):
233-
self.count = count
234-
self.size = size
235-
236-
def __iter__(self):
237-
for _ in range(self.count):
238-
yield torch.randn(self.size)
239-
240-
241-
class RandomIterableDatasetWithLen(IterableDataset):
242-
def __init__(self, size: int, count: int):
243-
self.count = count
244-
self.size = size
245-
246-
def __iter__(self):
247-
for _ in range(len(self)):
248-
yield torch.randn(self.size)
249-
250-
def __len__(self):
251-
return self.count

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828

2929
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
3030
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
31-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
31+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
3232
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
3333
from pytorch_lightning.strategies import DeepSpeedStrategy
3434
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
3535
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3636
from tests_pytorch.helpers.datamodules import ClassifDataModule
37-
from tests_pytorch.helpers.datasets import RandomIterableDataset
3837
from tests_pytorch.helpers.runif import RunIf
3938

4039
if _DEEPSPEED_AVAILABLE:

tests/tests_pytorch/trainer/flags/test_val_check_interval.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
import pytest
1717
from torch.utils.data import DataLoader
1818

19-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
19+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
2020
from pytorch_lightning.trainer.trainer import Trainer
2121
from pytorch_lightning.utilities.exceptions import MisconfigurationException
22-
from tests_pytorch.helpers.datasets import RandomIterableDataset
2322

2423

2524
@pytest.mark.parametrize("max_epochs", [1, 2, 3])

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@
2828
from pytorch_lightning import callbacks, Trainer
2929
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
3030
from pytorch_lightning.core.module import LightningModule
31-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
31+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
3232
from pytorch_lightning.utilities.exceptions import MisconfigurationException
33-
from tests_pytorch.helpers.datasets import RandomDictDataset
3433
from tests_pytorch.helpers.runif import RunIf
3534

3635

tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@
2222

2323
from pytorch_lightning import Trainer
2424
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
25-
from pytorch_lightning.demos.boring_classes import BoringModel
25+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomIterableDataset
2626
from pytorch_lightning.strategies.ipu import IPUStrategy
2727
from pytorch_lightning.utilities import device_parser
2828
from pytorch_lightning.utilities.exceptions import MisconfigurationException
29-
from tests_pytorch.helpers.datasets import RandomIterableDataset
3029
from tests_pytorch.helpers.runif import RunIf
3130

3231

tests/tests_pytorch/trainer/test_dataloaders.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,16 @@
2525

2626
from pytorch_lightning import Callback, seed_everything, Trainer
2727
from pytorch_lightning.callbacks import ModelCheckpoint
28-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
28+
from pytorch_lightning.demos.boring_classes import (
29+
BoringModel,
30+
RandomDataset,
31+
RandomIterableDataset,
32+
RandomIterableDatasetWithLen,
33+
)
2934
from pytorch_lightning.trainer.states import RunningStage
3035
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks
3136
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3237
from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
33-
from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
3438
from tests_pytorch.helpers.runif import RunIf
3539

3640

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@
4141
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
4242
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
4343
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
44-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
44+
from pytorch_lightning.demos.boring_classes import (
45+
BoringModel,
46+
RandomDataset,
47+
RandomIterableDataset,
48+
RandomIterableDatasetWithLen,
49+
)
4550
from pytorch_lightning.loggers import TensorBoardLogger
4651
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
4752
from pytorch_lightning.strategies import (
@@ -60,7 +65,6 @@
6065
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
6166
from pytorch_lightning.utilities.seed import seed_everything
6267
from tests_pytorch.helpers.datamodules import ClassifDataModule
63-
from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
6468
from tests_pytorch.helpers.runif import RunIf
6569
from tests_pytorch.helpers.simple_models import ClassificationModel
6670

tests/tests_pytorch/utilities/test_data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
77

88
from pytorch_lightning import Trainer
9-
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
9+
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
1010
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
1111
from pytorch_lightning.trainer.states import RunningStage
1212
from pytorch_lightning.utilities.data import (
@@ -23,7 +23,6 @@
2323
warning_cache,
2424
)
2525
from pytorch_lightning.utilities.exceptions import MisconfigurationException
26-
from tests_pytorch.helpers.datasets import RandomIterableDataset
2726
from tests_pytorch.helpers.utils import no_warning_call
2827

2928

0 commit comments

Comments
 (0)