@@ -34,12 +34,30 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
3434@ AlphaComponent
3535final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
3636
/**
 * Validates the bucket split points.
 *
 * A valid `buckets` array must 1) be non-empty and 2) be ordered in a
 * non-descending way (equal adjacent splits are tolerated here).
 *
 * @param buckets candidate split points for the continuous feature
 * @return true if the splits form a usable, non-descending sequence
 */
private def checkBuckets(buckets: Array[Double]): Boolean = {
  if (buckets.isEmpty) {
    // No splits at all cannot define any bucket.
    false
  } else {
    // Every adjacent pair must be non-decreasing. Unlike the original
    // foldLeft (which used non-short-circuiting `&` and always scanned the
    // whole array), `forall` stops at the first violation. A single split
    // yields an empty pair sequence and is trivially valid, matching the
    // original `size == 1` branch.
    buckets.zip(buckets.tail).forall { case (prev, curr) => prev <= curr }
  }
}
54+
/**
 * Split points used to map continuous features into buckets. Validated by
 * [[checkBuckets]] when the parameter is set.
 * @group param
 */
val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
  "Split points for mapping continuous features into buckets.", checkBuckets)
4361
/**
 * Returns the currently configured split points.
 * @group getParam
 */
def getBuckets: Array[Double] = $(buckets)
@@ -55,7 +73,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
5573
5674 override def transform (dataset : DataFrame ): DataFrame = {
5775 transformSchema(dataset.schema)
58- val bucketizer = udf { feature : Double => binarySearchForBins ($(buckets), feature) }
76+ val bucketizer = udf { feature : Double => binarySearchForBuckets ($(buckets), feature) }
5977 val outputColName = $(outputCol)
6078 val metadata = NominalAttribute .defaultAttr
6179 .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
@@ -65,7 +83,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
6583 /**
6684 * Binary searching in several buckets to place each data point.
6785 */
68- private def binarySearchForBins (splits : Array [Double ], feature : Double ): Double = {
86+ private def binarySearchForBuckets (splits : Array [Double ], feature : Double ): Double = {
6987 val wrappedSplits = Array (Double .MinValue ) ++ splits ++ Array (Double .MaxValue )
7088 var left = 0
7189 var right = wrappedSplits.length - 2
0 commit comments