3333from pyspark .tests import ReusedPySparkTestCase as PySparkTestCase
3434from pyspark .sql import DataFrame
3535from pyspark .ml .param import Param
36+ from pyspark .ml .param .shared import HasMaxIter , HasInputCol
3637from pyspark .ml .pipeline import Transformer , Estimator , Pipeline
3738
3839
@@ -46,7 +47,7 @@ class MockTransformer(Transformer):
4647
4748 def __init__ (self ):
4849 super (MockTransformer , self ).__init__ ()
49- self .fake = Param (self , "fake" , "fake" , None )
50+ self .fake = Param (self , "fake" , "fake" )
5051 self .dataset_index = None
5152 self .fake_param_value = None
5253
@@ -62,7 +63,7 @@ class MockEstimator(Estimator):
6263
6364 def __init__ (self ):
6465 super (MockEstimator , self ).__init__ ()
65- self .fake = Param (self , "fake" , "fake" , None )
66+ self .fake = Param (self , "fake" , "fake" )
6667 self .dataset_index = None
6768 self .fake_param_value = None
6869 self .model = None
@@ -111,5 +112,47 @@ def test_pipeline(self):
111112 self .assertEqual (6 , dataset .index )
112113
113114
115+ class TestParams (HasMaxIter , HasInputCol ):
116+ """
117+ A subclass of Params mixed with HasMaxIter and HasInputCol.
118+ """
119+
120+ def __init__ (self ):
121+ super (TestParams , self ).__init__ ()
122+ self ._setDefault (maxIter = 10 )
123+
124+
125+ class ParamTests (PySparkTestCase ):
126+
127+ def test_param (self ):
128+ testParams = TestParams ()
129+ maxIter = testParams .maxIter
130+ self .assertEqual (maxIter .name , "maxIter" )
131+ self .assertEqual (maxIter .doc , "max number of iterations" )
132+ self .assertTrue (maxIter .parent is testParams )
133+
134+ def test_params (self ):
135+ testParams = TestParams ()
136+ maxIter = testParams .maxIter
137+ inputCol = testParams .inputCol
138+
139+ params = testParams .params
140+ self .assertEqual (params , [inputCol , maxIter ])
141+
142+ self .assertTrue (testParams .hasDefault (maxIter ))
143+ self .assertFalse (testParams .isSet (maxIter ))
144+ self .assertTrue (testParams .isDefined (maxIter ))
145+ self .assertEqual (testParams .getMaxIter (), 10 )
146+ testParams .setMaxIter (100 )
147+ self .assertTrue (testParams .isSet (maxIter ))
148+ self .assertEquals (testParams .getMaxIter (), 100 )
149+
150+ self .assertFalse (testParams .hasDefault (inputCol ))
151+ self .assertFalse (testParams .isSet (inputCol ))
152+ self .assertFalse (testParams .isDefined (inputCol ))
153+ with self .assertRaises (KeyError ):
154+ testParams .getInputCol ()
155+
156+
114157if __name__ == "__main__" :
115158 unittest .main ()
0 commit comments