1313# limitations under the License.
1414from typing import Optional
1515
16+ import pytest
1617import torch
1718from torch .utils .data import DataLoader
1819
2425if _SKLEARN_AVAILABLE :
2526 from sklearn .datasets import make_classification , make_regression
2627 from sklearn .model_selection import train_test_split
27- else :
28- make_classification = None
29- make_regression = None
30- train_test_split = None
3128
3229
3330class MNISTDataModule (LightningDataModule ):
@@ -60,6 +57,8 @@ def test_dataloader(self):
6057
6158class SklearnDataModule (LightningDataModule ):
6259 def __init__ (self , sklearn_dataset , x_type , y_type , batch_size : int = 10 ):
60+ if not _SKLEARN_AVAILABLE :
61+ pytest .skip ("`sklearn` is not available." )
6362 super ().__init__ ()
6463 self .batch_size = batch_size
6564 self ._x , self ._y = sklearn_dataset
@@ -102,6 +101,8 @@ def sample(self):
102101
103102class ClassifDataModule (SklearnDataModule ):
104103 def __init__ (self , num_features = 32 , length = 800 , num_classes = 3 , batch_size = 10 ):
104+ if not _SKLEARN_AVAILABLE :
105+ pytest .skip ("`sklearn` is not available." )
105106 data = make_classification (
106107 n_samples = length , n_features = num_features , n_classes = num_classes , n_clusters_per_class = 1 , random_state = 42
107108 )
@@ -110,6 +111,8 @@ def __init__(self, num_features=32, length=800, num_classes=3, batch_size=10):
110111
111112class RegressDataModule (SklearnDataModule ):
112113 def __init__ (self , num_features = 16 , length = 800 , batch_size = 10 ):
114+ if not _SKLEARN_AVAILABLE :
115+ pytest .skip ("`sklearn` is not available." )
113116 x , y = make_regression (n_samples = length , n_features = num_features , random_state = 42 )
114117 y = [[v ] for v in y ]
115118 super ().__init__ ((x , y ), x_type = torch .float32 , y_type = torch .float32 , batch_size = batch_size )
0 commit comments