@@ -317,26 +317,33 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)
 
 
 @inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
-                 JavaMLReadable, JavaMLWritable):
-    """
-    Maps a column of continuous features to a column of feature buckets.
-
-    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
-    >>> df = spark.createDataFrame(values, ["values"])
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
+                 HasHandleInvalid, JavaMLReadable, JavaMLWritable):
+    """
+    Maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+    :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
+    parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters
+    are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single
+    column usage, and :py:attr:`splitsArray` is for multiple columns.
+
+    >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
+    ...     (float("nan"), 1.0), (float("nan"), 0.0)]
+    >>> df = spark.createDataFrame(values, ["values1", "values2"])
     >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
-    ...     inputCol="values", outputCol="buckets")
-    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
-    >>> len(bucketed)
-    6
-    >>> bucketed[0].buckets
-    0.0
-    >>> bucketed[1].buckets
-    0.0
-    >>> bucketed[2].buckets
-    1.0
-    >>> bucketed[3].buckets
-    2.0
+    ...     inputCol="values1", outputCol="buckets")
+    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
+    >>> bucketed.show(truncate=False)
+    +-------+-------+
+    |values1|buckets|
+    +-------+-------+
+    |0.1    |0.0    |
+    |0.4    |0.0    |
+    |1.2    |1.0    |
+    |1.5    |2.0    |
+    |NaN    |3.0    |
+    |NaN    |3.0    |
+    +-------+-------+
+    ...
     >>> bucketizer.setParams(outputCol="b").transform(df).head().b
     0.0
     >>> bucketizerPath = temp_path + "/bucketizer"
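
The docstring above says an Exception is thrown when both :py:attr:`inputCol` and :py:attr:`inputCols` are set. A minimal sketch of that rule, reusing the `df` built in the doctest; treating the failure as surfacing at transform() time, wrapped by Py4J, is an assumption, not something the patch states:

    both = Bucketizer(splits=[-float("inf"), 0.5, float("inf")],
                      inputCol="values1", outputCol="buckets",
                      inputCols=["values1", "values2"])
    try:
        both.transform(df)         # df as built in the doctest above
    except Exception as err:       # assumed: JVM-side error wrapped by Py4J
        print(type(err).__name__)
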
@@ -347,6 +354,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
     >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
     >>> len(bucketed)
     4
+    >>> bucketizer2 = Bucketizer(splitsArray=
+    ...     [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
+    ...     inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
+    >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
+    >>> bucketed2.show(truncate=False)
+    +-------+-------+--------+--------+
+    |values1|values2|buckets1|buckets2|
+    +-------+-------+--------+--------+
+    |0.1    |0.0    |0.0     |0.0     |
+    |0.4    |1.0    |0.0     |1.0     |
+    |1.2    |1.3    |1.0     |1.0     |
+    |1.5    |NaN    |2.0     |2.0     |
+    |NaN    |1.0    |3.0     |1.0     |
+    |NaN    |0.0    |3.0     |0.0     |
+    +-------+-------+--------+--------+
+    ...
 
     .. versionadded:: 1.4.0
     """
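
For intuition, the two-column doctest above should agree with chaining two single-column Bucketizers, one per input column, with the same splits and handleInvalid. This is a sketch of the semantics only, not a claim about how the multi-column path is implemented internally:

    b1 = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
                    inputCol="values1", outputCol="buckets1", handleInvalid="keep")
    b2 = Bucketizer(splits=[-float("inf"), 0.5, float("inf")],
                    inputCol="values2", outputCol="buckets2", handleInvalid="keep")
    chained = b2.transform(b1.transform(df))  # expected to match bucketed2 above
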
@@ -363,14 +386,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
 
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
                           "Options are 'skip' (filter out rows with invalid values), " +
-                          "'error' (throw an error), or 'keep' (keep invalid values in a special " +
-                          "additional bucket).",
+                          "'error' (throw an error), or 'keep' (keep invalid values in a " +
+                          "special additional bucket). Note that in the multiple column " +
+                          "case, the invalid handling is applied to all columns. That said, " +
+                          "for 'error' it will throw an error if any invalid values are " +
+                          "found in any column, and for 'skip' it will skip rows with any " +
+                          "invalid values in any column.",
                           typeConverter=TypeConverters.toString)
 
+    splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " +
+                        "continuous features into buckets for multiple columns. For each input " +
+                        "column, with n+1 splits, there are n buckets. A bucket defined by " +
+                        "splits x,y holds values in the range [x,y) except the last bucket, " +
+                        "which also includes y. The splits should be of length >= 3 and " +
+                        "strictly increasing. Values at -inf, inf must be explicitly provided " +
+                        "to cover all Double values; otherwise, values outside the splits " +
+                        "specified will be treated as errors.",
+                        typeConverter=TypeConverters.toListListFloat)
+
     @keyword_only
-    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
+    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
+                 splitsArray=None, inputCols=None, outputCols=None):
         """
-        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
+        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
+                 splitsArray=None, inputCols=None, outputCols=None)
         """
         super(Bucketizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
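
The expanded handleInvalid description says that in the multi-column case 'skip' drops a row when any input column holds an invalid value. Against the six-row `df` from the doctests, that rule would keep only the three rows with no NaN in either column; the count below is an expectation derived from the rule, not an output copied from the patch:

    skipper = Bucketizer(
        splitsArray=[[-float("inf"), 0.5, 1.4, float("inf")],
                     [-float("inf"), 0.5, float("inf")]],
        inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"],
        handleInvalid="skip")
    skipper.transform(df).count()  # expect 3: any row with a NaN is dropped
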
@@ -380,9 +419,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
+    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
+                  splitsArray=None, inputCols=None, outputCols=None):
         """
-        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
+        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
+                  splitsArray=None, inputCols=None, outputCols=None)
         Sets params for this Bucketizer.
         """
         kwargs = self._input_kwargs
@@ -402,6 +443,20 @@ def getSplits(self):
         """
         return self.getOrDefault(self.splits)
 
+    @since("2.3.0")
+    def setSplitsArray(self, value):
+        """
+        Sets the value of :py:attr:`splitsArray`.
+        """
+        return self._set(splitsArray=value)
+
+    @since("2.3.0")
+    def getSplitsArray(self):
+        """
+        Gets the array of split points or its default value.
+        """
+        return self.getOrDefault(self.splitsArray)
+
 
 @inherit_doc
 class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
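
A small usage sketch for the new accessors, continuing from the `bucketizer2` instance in the doctests; that the toListListFloat converter round-trips plain Python lists of floats is an assumption here:

    bucketizer2.setSplitsArray([[-float("inf"), 0.0, float("inf")],
                                [-float("inf"), 1.0, float("inf")]])
    bucketizer2.getSplitsArray()  # expected: [[-inf, 0.0, inf], [-inf, 1.0, inf]]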