Skip to content

Commit e869e75

Browse files
committed
update tests
1 parent 76de8e6 commit e869e75

File tree

2 files changed

+37
-33
lines changed

2 files changed

+37
-33
lines changed

python/pyspark/ml/feature.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -329,20 +329,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
329329
330330
>>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
331331
... (float("nan"), 1.0), (float("nan"), 0.0)]
332-
>>> df = spark.createDataFrame(values, ["values", "numbers"])
332+
>>> df = spark.createDataFrame(values, ["values1", "values2"])
333333
>>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
334-
... inputCol="values", outputCol="buckets")
335-
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
336-
>>> len(bucketed)
337-
6
338-
>>> bucketed[0].buckets
339-
0.0
340-
>>> bucketed[1].buckets
341-
0.0
342-
>>> bucketed[2].buckets
343-
1.0
344-
>>> bucketed[3].buckets
345-
2.0
334+
... inputCol="values1", outputCol="buckets")
335+
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df)
336+
>>> bucketed.show(truncate=False)
337+
+-------+-------+-------+
338+
|values1|values2|buckets|
339+
+-------+-------+-------+
340+
|0.1 |0.0 |0.0 |
341+
|0.4 |1.0 |0.0 |
342+
|1.2 |1.3 |1.0 |
343+
|1.5 |NaN |2.0 |
344+
|NaN |1.0 |3.0 |
345+
|NaN |0.0 |3.0 |
346+
+-------+-------+-------+
347+
...
346348
>>> bucketizer.setParams(outputCol="b").transform(df).head().b
347349
0.0
348350
>>> bucketizerPath = temp_path + "/bucketizer"
@@ -355,26 +357,20 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
355357
4
356358
>>> bucketizer2 = Bucketizer(splitsArray=
357359
... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
358-
... inputCols=["values", "numbers"], outputCols=["buckets1", "buckets2"])
359-
>>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df).collect()
360-
>>> len(bucketed2)
361-
6
362-
>>> bucketed2[0].buckets1
363-
0.0
364-
>>> bucketed2[1].buckets1
365-
0.0
366-
>>> bucketed2[2].buckets1
367-
1.0
368-
>>> bucketed2[3].buckets1
369-
2.0
370-
>>> bucketed2[0].buckets2
371-
0.0
372-
>>> bucketed2[1].buckets2
373-
1.0
374-
>>> bucketed2[2].buckets2
375-
1.0
376-
>>> bucketed2[3].buckets2
377-
2.0
360+
... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
361+
>>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
362+
>>> bucketed2.show(truncate=False)
363+
+-------+-------+--------+--------+
364+
|values1|values2|buckets1|buckets2|
365+
+-------+-------+--------+--------+
366+
|0.1 |0.0 |0.0 |0.0 |
367+
|0.4 |1.0 |0.0 |1.0 |
368+
|1.2 |1.3 |1.0 |1.0 |
369+
|1.5 |NaN |2.0 |2.0 |
370+
|NaN |1.0 |3.0 |1.0 |
371+
|NaN |0.0 |3.0 |0.0 |
372+
+-------+-------+--------+--------+
373+
...
378374
379375
.. versionadded:: 1.4.0
380376
"""

python/pyspark/ml/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,14 @@ def test_bool(self):
238238
self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1))
239239
self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false"))
240240

241+
def test_list_list_float(self):
242+
b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]])
243+
self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]])
244+
self.assertTrue(all([type(v) == list for v in b.getSplitsArray()]))
245+
self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]]))
246+
self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]]))
247+
self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0]))
248+
241249

242250
class PipelineTests(PySparkTestCase):
243251

0 commit comments

Comments (0)