@@ -39,12 +39,12 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
3939
4040 /**
4141 * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
42- * A bucket defined by splits x,y holds values in the range ( x,y] .
42+ * A bucket defined by splits x,y holds values in the range [ x,y) .
4343 * @group param
4444 */
4545 val splits : Param [Array [Double ]] = new Param [Array [Double ]](this , " splits" ,
4646 " Split points for mapping continuous features into buckets. With n splits, there are n+1" +
47- " buckets. A bucket defined by splits x,y holds values in the range ( x,y] ." ,
47+ " buckets. A bucket defined by splits x,y holds values in the range [ x,y) ." ,
4848 Bucketizer .checkSplits)
4949
5050 /** @group getParam */
@@ -85,7 +85,8 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
8585 transformSchema(dataset.schema)
8686 val wrappedSplits = Array (Double .MinValue ) ++ $(splits) ++ Array (Double .MaxValue )
8787 val bucketizer = udf { feature : Double =>
88- Bucketizer .binarySearchForBuckets(wrappedSplits, feature) }
88+ Bucketizer
89+ .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
8990 val newCol = bucketizer(dataset($(inputCol)))
9091 val newField = prepOutputField(dataset.schema)
9192 dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
@@ -95,7 +96,6 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
9596 val attr = new NominalAttribute (
9697 name = Some ($(outputCol)),
9798 isOrdinal = Some (true ),
98- numValues = Some ($(splits).size),
9999 values = Some ($(splits).map(_.toString)))
100100
101101 attr.toStructField()
@@ -131,20 +131,27 @@ object Bucketizer {
/**
 * Binary search over the bucket boundaries to find the bucket that holds `feature`.
 *
 * Bucket `i` covers the half-open range `[splits(i), splits(i + 1))`, so `n` split points
 * define `n - 1` buckets indexed `0` to `n - 2`. The search assumes `splits` is sorted in
 * ascending order; this is not verified here.
 *
 * @param splits bucket boundaries; must contain at least two values
 * @param feature the value to place into a bucket
 * @param lowerInclusive if true, a feature below `splits.head` is assigned to the first
 *                       bucket; if false, such a feature is rejected
 * @param upperInclusive if true, a feature at or above `splits.last` is assigned to the last
 *                       bucket; if false, such a feature is rejected
 * @return the index of the matching bucket, as a Double
 * @throws IllegalArgumentException if `splits` has fewer than two values
 * @throws Exception if the feature lies outside the splits and the matching inclusive flag
 *                   is false, or if no bucket matches (for example, a NaN feature)
 */
private[feature] def binarySearchForBuckets(
    splits: Array[Double],
    feature: Double,
    lowerInclusive: Boolean,
    upperInclusive: Boolean): Double = {
  require(splits.length >= 2, "At least two split points are needed to define a bucket.")

  val lastBucket = splits.length - 2

  def outOfBound: Exception =
    new Exception(s"Feature $feature out of bound, check your features or loose the " +
      "lower/upper bound constraint.")

  // Plain half-open search over [splits(lo), splits(hi + 1)). NaN fails every comparison,
  // so it walks off the right end and is reported as "no bucket".
  @scala.annotation.tailrec
  def search(lo: Int, hi: Int): Double = {
    if (lo > hi) {
      throw new Exception(s"Failed to find a bucket for feature $feature.")
    }
    val mid = lo + (hi - lo) / 2
    if (feature < splits(mid)) {
      search(lo, mid - 1)
    } else if (feature < splits(mid + 1)) {
      mid.toDouble
    } else {
      search(mid + 1, hi)
    }
  }

  if (feature < splits.head) {
    // Below the lowest boundary: only acceptable when the lower bound is relaxed.
    if (lowerInclusive) 0.0 else throw outOfBound
  } else if (upperInclusive && feature >= splits.last) {
    // At or beyond the highest boundary: the half-open search can never match these
    // values, so map them to the last bucket explicitly.
    lastBucket.toDouble
  } else if (feature > splits.last) {
    throw outOfBound
  } else {
    search(0, lastBucket)
  }
}
150157}
0 commit comments