120 changes: 37 additions & 83 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -39,14 +39,17 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])

/**
* Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
* A bucket defined by splits x,y holds values in the range [x,y). Note that the splits should be
* strictly increasing.
* A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
* increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
* otherwise, values outside the splits specified will be treated as errors.
* @group param
*/
val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
"Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
"buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
"should be strictly increasing.",
"should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
" all Double values; otherwise, values outside the splits specified will be treated as" +
" errors.",
Bucketizer.checkSplits)

/** @group getParam */
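
For illustration (the values below are made up, not part of this patch): under the new contract, n+1 splits define n buckets, and the infinities must appear as the outer splits if all Double values are to be covered.

val splits = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)
// 4 splits => 3 buckets:
//   bucket 0 = [-inf, 0.0), bucket 1 = [0.0, 10.0), bucket 2 = [10.0, inf)
// -3.5 maps to 0.0, 0.0 maps to 1.0, 10.0 maps to 2.0.
// Without the infinities as outer splits, -3.5 would be rejected as out of bounds.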
@@ -55,40 +58,6 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
/** @group setParam */
def setSplits(value: Array[Double]): this.type = set(splits, value)

/**
* An indicator of the inclusiveness of negative infinity. If true, then use implicit bin
* (-inf, getSplits.head). If false, then throw exception if values < getSplits.head are
* encountered.
* @group Param */
val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
"An indicator of the inclusiveness of negative infinity. If true, then use implicit bin " +
"(-inf, getSplits.head). If false, then throw exception if values < getSplits.head are " +
"encountered.")
setDefault(lowerInclusive -> true)

/** @group getParam */
def getLowerInclusive: Boolean = $(lowerInclusive)

/** @group setParam */
def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)

/**
* An indicator of the inclusiveness of positive infinity. If true, then use implicit bin
* [getSplits.last, inf). If false, then throw exception if values > getSplits.last are
* encountered.
* @group Param */
val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
"An indicator of the inclusiveness of positive infinity. If true, then use implicit bin " +
"[getSplits.last, inf). If false, then throw exception if values > getSplits.last are " +
"encountered.")
setDefault(upperInclusive -> true)

/** @group getParam */
def getUpperInclusive: Boolean = $(upperInclusive)

/** @group setParam */
def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

@@ -97,81 +66,66 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])

override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema)
val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
val bucketizer = udf { feature: Double =>
Bucketizer
.binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
Bucketizer.binarySearchForBuckets($(splits), feature)
}
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
}

private def prepOutputField(schema: StructType): StructField = {
val innerRanges = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
val values = ($(lowerInclusive), $(upperInclusive)) match {
case (true, true) =>
Array(s"-inf, ${$(splits).head}") ++ innerRanges ++ Array(s"${$(splits).last}, inf")
case (true, false) => Array(s"-inf, ${$(splits).head}") ++ innerRanges
case (false, true) => innerRanges ++ Array(s"${$(splits).last}, inf")
case _ => innerRanges
}
val attr =
new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), values = Some(values))
val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
values = Some(buckets))
attr.toStructField()
}
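
To make the label construction above concrete (an illustration, not part of this patch), the sliding(2) pass turns each pair of consecutive splits into one nominal value per bucket:

Array(-0.5, 0.0, 0.5).sliding(2).map(_.mkString(", ")).toArray
// => Array("-0.5, 0.0", "0.0, 0.5")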

override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
require(schema.fields.forall(_.name != $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
StructType(schema.fields :+ prepOutputField(schema))
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
}

private[feature] object Bucketizer {
/**
* The given splits must satisfy two conditions: 1) the array is non-empty; 2) the values
* are strictly increasing.
*/
private def checkSplits(splits: Array[Double]): Boolean = {
if (splits.size == 0) false
else if (splits.size == 1) true
else {
splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
if (validator && prevValue < currValue) {
(true, currValue)
} else {
(false, currValue)
}
}._1
/** We require splits to be of length >= 3 and to be in strictly increasing order. */
def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
} else {
var i = 0
while (i < splits.length - 1) {
if (splits(i) >= splits(i + 1)) return false
i += 1
}
true
}
}
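
For illustration (inputs are made up; checkSplits is visible to callers within the ml.feature package): the new rule accepts exactly the arrays that have at least three elements and are strictly increasing.

Bucketizer.checkSplits(Array(-0.5, 0.0, 0.5))       // true
Bucketizer.checkSplits(Array(-0.5, 0.0))            // false: fewer than 3 splits
Bucketizer.checkSplits(Array(-0.5, 0.0, 0.0, 0.5))  // false: not strictly increasing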

/**
* Binary search over the sorted splits to place each feature value in a bucket.
* @throws RuntimeException if a feature is < splits.head or >= splits.last,
*                          except that +inf is allowed and falls in the last bucket
*/
private[feature] def binarySearchForBuckets(
def binarySearchForBuckets(
splits: Array[Double],
feature: Double,
lowerInclusive: Boolean,
upperInclusive: Boolean): Double = {
if ((feature < splits.head && !lowerInclusive) || (feature > splits.last && !upperInclusive)) {
throw new RuntimeException(s"Feature $feature out of bound, check your features or loosen " +
s"the lower/upper bound constraint.")
feature: Double): Double = {
// Check bounds. We make an exception for +inf so that it can exist in some bin.
if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
s"the lower/upper bound constraints.")
}
var left = 0
var right = splits.length - 2
while (left <= right) {
val mid = left + (right - left) / 2
val split = splits(mid)
if ((feature >= split) && (feature < splits(mid + 1))) {
return mid
} else if (feature < split) {
right = mid - 1
while (left < right) {
val mid = (left + right) / 2
val split = splits(mid + 1)
if (feature < split) {
right = mid
} else {
left = mid + 1
}
}
throw new RuntimeException(s"Unexpected error: failed to find a bucket for feature $feature.")
left
}
}
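
The loop above maintains the invariant that the answer lies in [left, right], where bucket i covers [splits(i), splits(i + 1)). Here is a self-contained sketch of the same lookup (plain Scala, independent of Spark; the object name and the main harness are made up):

object BucketLookupSketch {
  // Returns the index of the bucket [splits(i), splits(i + 1)) containing feature.
  def bucketFor(splits: Array[Double], feature: Double): Int = {
    require(splits.length >= 3, "need at least 3 splits, as checkSplits enforces")
    require(feature >= splits.head &&
      (feature < splits.last || feature == Double.PositiveInfinity),
      s"feature $feature is outside [${splits.head}, ${splits.last})")
    var left = 0
    var right = splits.length - 2 // index of the last bucket
    while (left < right) {
      val mid = (left + right) / 2
      if (feature < splits(mid + 1)) right = mid // answer is bucket mid or to its left
      else left = mid + 1                        // answer is strictly right of bucket mid
    }
    left
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    assert(bucketFor(splits, -0.9) == 0)
    assert(bucketFor(splits, 0.0) == 2)
    assert(bucketFor(splits, Double.PositiveInfinity) == 3)
  }
}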
11 changes: 11 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -58,4 +58,15 @@ object SchemaUtils {
val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
StructType(outputFields)
}

/**
* Appends a new column to the input schema. This fails if the given output column already exists.
* @param schema input schema
* @param col New column schema
* @return new schema with the input column appended
*/
def appendColumn(schema: StructType, col: StructField): StructType = {
require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
StructType(schema.fields :+ col)
}
}
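
A possible call site for the new helper (illustrative only; SchemaUtils is an internal Spark utility, and the column names here are made up):

import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val schema = StructType(Seq(StructField("feature", DoubleType, nullable = false)))
val appended = SchemaUtils.appendColumn(schema, StructField("bucket", DoubleType, nullable = false))
// appended.fieldNames is Array("feature", "bucket").
// Appending another "bucket" column would fail the require with
// "Column bucket already exists."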
138 changes: 93 additions & 45 deletions mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,20 +21,28 @@ import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

test("Bucket continuous features with setter") {
val sqlContext = new SQLContext(sc)
val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
@transient private var sqlContext: SQLContext = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
}

test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
data.zip(bucketizedData)).toDF("feature", "expected")
val validData = Array(-0.5, -0.3, 0.0, 0.2)
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
@@ -43,58 +51,98 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y, "The feature value is not correct after bucketing.")
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}
}

test("Binary search correctness in contrast with linear search") {
val data = Array.fill(100)(Random.nextDouble())
val splits = Array.fill(10)(Random.nextDouble()).sorted
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val bsResult = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
// Check for exceptions when using a set of invalid feature values.
val invalidData1: Array[Double] = Array(-0.9) ++ validData
val invalidData2 = Array(0.5) ++ validData
val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
intercept[RuntimeException]{
bucketizer.transform(badDF1).collect()
println("Invalid feature value -0.9 was not caught as an invalid feature!")
}
val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
intercept[RuntimeException]{
bucketizer.transform(badDF2).collect()
println("Invalid feature value 0.5 was not caught as an invalid feature!")
}
}

test("Binary search of features at splits") {
val splits = Array.fill(10)(Random.nextDouble()).sorted
val data = splits
val expected = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val result = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
assert(result ~== expected absTol 1e-5)
test("Bucket continuous features, with -inf,inf") {
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)

bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}
}

test("Binary search of features between splits") {
val data = Array.fill(10)(Random.nextDouble())
val splits = Array(-0.1, 1.1)
val expected = Vectors.dense(Array.fill(10)(1.0))
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val result = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
assert(result ~== expected absTol 1e-5)
test("Binary search correctness on hand-picked examples") {
import BucketizerSuite.checkBinarySearch
// length 3, with -inf
checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
// length 4
checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
// length 5
checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
// length 3, with inf
checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
// length 3, with -inf and inf
checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
// length 4, with -inf and inf
checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity))
}

test("Binary search of features outside splits") {
val data = Array.fill(5)(Random.nextDouble() + 1.1) ++ Array.fill(5)(Random.nextDouble() - 1.1)
val splits = Array(0.0, 1.1)
val expected = Vectors.dense(Array.fill(5)(2.0) ++ Array.fill(5)(0.0))
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val result = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
assert(result ~== expected absTol 1e-5)
test("Binary search correctness in contrast with linear search, on random data") {
val data = Array.fill(100)(Random.nextDouble())
val splits: Array[Double] = Double.NegativeInfinity +:
Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
}
}

private object BucketizerSuite {
private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
private object BucketizerSuite extends FunSuite {
/** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
require(feature >= splits.head)
var i = 0
while (i < splits.size) {
if (feature < splits(i)) return i
while (i < splits.length - 1) {
if (feature < splits(i + 1)) return i
i += 1
}
i
throw new RuntimeException(
s"linearSearchForBuckets failed to find bucket for feature value $feature")
}

/** Check all values in splits, plus values between all splits. */
def checkBinarySearch(splits: Array[Double]): Unit = {
def testFeature(feature: Double, expectedBucket: Double): Unit = {
assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
s" ${splits.mkString(", ")}")
}
var i = 0
while (i < splits.length - 1) {
testFeature(splits(i), i) // Split i should fall in bucket i.
testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
i += 1
}
if (splits.last === Double.PositiveInfinity) {
testFeature(Double.PositiveInfinity, splits.length - 2)
}
}
}