Skip to content

Commit 4d6b07a

Browse files
committed
add tests
1 parent 5294500 commit 4d6b07a

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

python/pyspark/ml/tests.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
3434
from pyspark.sql import DataFrame
3535
from pyspark.ml.param import Param
36+
from pyspark.ml.param.shared import HasMaxIter, HasInputCol
3637
from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
3738

3839

@@ -46,7 +47,7 @@ class MockTransformer(Transformer):
4647

4748
def __init__(self):
4849
super(MockTransformer, self).__init__()
49-
self.fake = Param(self, "fake", "fake", None)
50+
self.fake = Param(self, "fake", "fake")
5051
self.dataset_index = None
5152
self.fake_param_value = None
5253

@@ -62,7 +63,7 @@ class MockEstimator(Estimator):
6263

6364
def __init__(self):
6465
super(MockEstimator, self).__init__()
65-
self.fake = Param(self, "fake", "fake", None)
66+
self.fake = Param(self, "fake", "fake")
6667
self.dataset_index = None
6768
self.fake_param_value = None
6869
self.model = None
@@ -111,5 +112,47 @@ def test_pipeline(self):
111112
self.assertEqual(6, dataset.index)
112113

113114

115+
class TestParams(HasMaxIter, HasInputCol):
116+
"""
117+
A subclass of Params mixed with HasMaxIter and HasInputCol.
118+
"""
119+
120+
def __init__(self):
121+
super(TestParams, self).__init__()
122+
self._setDefault(maxIter=10)
123+
124+
125+
class ParamTests(PySparkTestCase):
126+
127+
def test_param(self):
128+
testParams = TestParams()
129+
maxIter = testParams.maxIter
130+
self.assertEqual(maxIter.name, "maxIter")
131+
self.assertEqual(maxIter.doc, "max number of iterations")
132+
self.assertTrue(maxIter.parent is testParams)
133+
134+
def test_params(self):
135+
testParams = TestParams()
136+
maxIter = testParams.maxIter
137+
inputCol = testParams.inputCol
138+
139+
params = testParams.params
140+
self.assertEqual(params, [inputCol, maxIter])
141+
142+
self.assertTrue(testParams.hasDefault(maxIter))
143+
self.assertFalse(testParams.isSet(maxIter))
144+
self.assertTrue(testParams.isDefined(maxIter))
145+
self.assertEqual(testParams.getMaxIter(), 10)
146+
testParams.setMaxIter(100)
147+
self.assertTrue(testParams.isSet(maxIter))
148+
self.assertEquals(testParams.getMaxIter(), 100)
149+
150+
self.assertFalse(testParams.hasDefault(inputCol))
151+
self.assertFalse(testParams.isSet(inputCol))
152+
self.assertFalse(testParams.isDefined(inputCol))
153+
with self.assertRaises(KeyError):
154+
testParams.getInputCol()
155+
156+
114157
if __name__ == "__main__":
115158
unittest.main()

0 commit comments

Comments
 (0)