Skip to content

Commit 4f68c63

Browse files
Refactor tests, skip if sklearn not available (#12093)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f530489 commit 4f68c63

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/helpers/datamodules.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from typing import Optional
1515

16+
import pytest
1617
import torch
1718
from torch.utils.data import DataLoader
1819

@@ -24,10 +25,6 @@
2425
if _SKLEARN_AVAILABLE:
2526
from sklearn.datasets import make_classification, make_regression
2627
from sklearn.model_selection import train_test_split
27-
else:
28-
make_classification = None
29-
make_regression = None
30-
train_test_split = None
3128

3229

3330
class MNISTDataModule(LightningDataModule):
@@ -60,6 +57,8 @@ def test_dataloader(self):
6057

6158
class SklearnDataModule(LightningDataModule):
6259
def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 10):
60+
if not _SKLEARN_AVAILABLE:
61+
pytest.skip("`sklearn` is not available.")
6362
super().__init__()
6463
self.batch_size = batch_size
6564
self._x, self._y = sklearn_dataset
@@ -102,6 +101,8 @@ def sample(self):
102101

103102
class ClassifDataModule(SklearnDataModule):
104103
def __init__(self, num_features=32, length=800, num_classes=3, batch_size=10):
104+
if not _SKLEARN_AVAILABLE:
105+
pytest.skip("`sklearn` is not available.")
105106
data = make_classification(
106107
n_samples=length, n_features=num_features, n_classes=num_classes, n_clusters_per_class=1, random_state=42
107108
)
@@ -110,6 +111,8 @@ def __init__(self, num_features=32, length=800, num_classes=3, batch_size=10):
110111

111112
class RegressDataModule(SklearnDataModule):
112113
def __init__(self, num_features=16, length=800, batch_size=10):
114+
if not _SKLEARN_AVAILABLE:
115+
pytest.skip("`sklearn` is not available.")
113116
x, y = make_regression(n_samples=length, n_features=num_features, random_state=42)
114117
y = [[v] for v in y]
115118
super().__init__((x, y), x_type=torch.float32, y_type=torch.float32, batch_size=batch_size)

0 commit comments

Comments
 (0)