Commit c22eaa9

zhengruifeng authored and Nick Pentreath committed
[SPARK-22797][PYSPARK] Bucketizer support multi-column
## What changes were proposed in this pull request?

Bucketizer now supports multiple columns on the Python side.

## How was this patch tested?

Existing tests and added tests.

Author: Zheng RuiFeng <[email protected]>

Closes #19892 from zhengruifeng/20542_py.
1 parent cd3956d commit c22eaa9

File tree

3 files changed: +99 −25 lines


python/pyspark/ml/feature.py

Lines changed: 80 additions & 25 deletions
@@ -317,26 +317,33 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)
 
 
 @inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
-                 JavaMLReadable, JavaMLWritable):
-    """
-    Maps a column of continuous features to a column of feature buckets.
-
-    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
-    >>> df = spark.createDataFrame(values, ["values"])
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
+                 HasHandleInvalid, JavaMLReadable, JavaMLWritable):
+    """
+    Maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+    :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
+    parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters
+    are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single
+    column usage, and :py:attr:`splitsArray` is for multiple columns.
+
+    >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
+    ...     (float("nan"), 1.0), (float("nan"), 0.0)]
+    >>> df = spark.createDataFrame(values, ["values1", "values2"])
     >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
-    ...     inputCol="values", outputCol="buckets")
-    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
-    >>> len(bucketed)
-    6
-    >>> bucketed[0].buckets
-    0.0
-    >>> bucketed[1].buckets
-    0.0
-    >>> bucketed[2].buckets
-    1.0
-    >>> bucketed[3].buckets
-    2.0
+    ...     inputCol="values1", outputCol="buckets")
+    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
+    >>> bucketed.show(truncate=False)
+    +-------+-------+
+    |values1|buckets|
+    +-------+-------+
+    |0.1    |0.0    |
+    |0.4    |0.0    |
+    |1.2    |1.0    |
+    |1.5    |2.0    |
+    |NaN    |3.0    |
+    |NaN    |3.0    |
+    +-------+-------+
+    ...
     >>> bucketizer.setParams(outputCol="b").transform(df).head().b
     0.0
     >>> bucketizerPath = temp_path + "/bucketizer"
@@ -347,6 +354,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
     >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
     >>> len(bucketed)
     4
+    >>> bucketizer2 = Bucketizer(splitsArray=
+    ...     [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
+    ...     inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
+    >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
+    >>> bucketed2.show(truncate=False)
+    +-------+-------+--------+--------+
+    |values1|values2|buckets1|buckets2|
+    +-------+-------+--------+--------+
+    |0.1    |0.0    |0.0     |0.0     |
+    |0.4    |1.0    |0.0     |1.0     |
+    |1.2    |1.3    |1.0     |1.0     |
+    |1.5    |NaN    |2.0     |2.0     |
+    |NaN    |1.0    |3.0     |1.0     |
+    |NaN    |0.0    |3.0     |0.0     |
+    +-------+-------+--------+--------+
+    ...
 
     .. versionadded:: 1.4.0
     """
@@ -363,14 +386,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
 
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
                           "Options are 'skip' (filter out rows with invalid values), " +
-                          "'error' (throw an error), or 'keep' (keep invalid values in a special " +
-                          "additional bucket).",
+                          "'error' (throw an error), or 'keep' (keep invalid values in a " +
+                          "special additional bucket). Note that in the multiple column " +
+                          "case, the invalid handling is applied to all columns. That said " +
+                          "for 'error' it will throw an error if any invalids are found in " +
+                          "any column, for 'skip' it will skip rows with any invalids in " +
+                          "any columns, etc.",
                           typeConverter=TypeConverters.toString)
 
+    splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " +
+                        "continuous features into buckets for multiple columns. For each input " +
+                        "column, with n+1 splits, there are n buckets. A bucket defined by " +
+                        "splits x,y holds values in the range [x,y) except the last bucket, " +
+                        "which also includes y. The splits should be of length >= 3 and " +
+                        "strictly increasing. Values at -inf, inf must be explicitly provided " +
+                        "to cover all Double values; otherwise, values outside the splits " +
+                        "specified will be treated as errors.",
+                        typeConverter=TypeConverters.toListListFloat)
+
     @keyword_only
-    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
+    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
+                 splitsArray=None, inputCols=None, outputCols=None):
         """
-        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
+        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
+                 splitsArray=None, inputCols=None, outputCols=None)
         """
         super(Bucketizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
@@ -380,9 +419,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
+    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
+                  splitsArray=None, inputCols=None, outputCols=None):
         """
-        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
+        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
+                  splitsArray=None, inputCols=None, outputCols=None)
         Sets params for this Bucketizer.
         """
         kwargs = self._input_kwargs
@@ -402,6 +443,20 @@ def getSplits(self):
         """
         return self.getOrDefault(self.splits)
 
+    @since("2.3.0")
+    def setSplitsArray(self, value):
+        """
+        Sets the value of :py:attr:`splitsArray`.
+        """
+        return self._set(splitsArray=value)
+
+    @since("2.3.0")
+    def getSplitsArray(self):
+        """
+        Gets the array of split points or its default value.
+        """
+        return self.getOrDefault(self.splitsArray)
+
 
 @inherit_doc
 class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
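For readers who want to try the new multi-column path outside of the doctest, here is a minimal sketch mirroring the example added above, assuming a local SparkSession; the session setup and column names are illustrative, not part of the diff:

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan"))],
    ["values1", "values2"])

# One splits array per input column, matched by position.
bucketizer = Bucketizer(
    splitsArray=[[-float("inf"), 0.5, 1.4, float("inf")],
                 [-float("inf"), 0.5, float("inf")]],
    inputCols=["values1", "values2"],
    outputCols=["buckets1", "buckets2"])

# handleInvalid="keep" routes NaN values into an extra bucket instead of failing.
bucketed = bucketizer.setHandleInvalid("keep").transform(df)
bucketed.show(truncate=False)
```

As the updated docstring notes, setting both `inputCol` and `inputCols` on the same instance raises an exception, and `splits` is only honored in single-column mode.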

python/pyspark/ml/param/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -134,6 +134,16 @@ def toListFloat(value):
             return [float(v) for v in value]
         raise TypeError("Could not convert %s to list of floats" % value)
 
+    @staticmethod
+    def toListListFloat(value):
+        """
+        Convert a value to list of list of floats, if possible.
+        """
+        if TypeConverters._can_convert_to_list(value):
+            value = TypeConverters.toList(value)
+            return [TypeConverters.toListFloat(v) for v in value]
+        raise TypeError("Could not convert %s to list of list of floats" % value)
+
     @staticmethod
     def toListInt(value):
         """

python/pyspark/ml/tests.py

Lines changed: 9 additions & 0 deletions
@@ -238,6 +238,15 @@ def test_bool(self):
         self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1))
         self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false"))
 
+    def test_list_list_float(self):
+        b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]])
+        self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]])
+        self.assertTrue(all([type(v) == list for v in b.getSplitsArray()]))
+        self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]]))
+        self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]]))
+        self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0]))
+        self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]]))
+
 
 class PipelineTests(PySparkTestCase):
