@@ -21,20 +21,28 @@ import scala.util.Random
2121
2222import org .scalatest .FunSuite
2323
24+ import org .apache .spark .SparkException
2425import org .apache .spark .mllib .linalg .Vectors
2526import org .apache .spark .mllib .util .MLlibTestSparkContext
2627import org .apache .spark .mllib .util .TestingUtils ._
2728import org .apache .spark .sql .{DataFrame , Row , SQLContext }
2829
2930class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
3031
31- test(" Bucket continuous features with setter" ) {
32- val sqlContext = new SQLContext (sc)
33- val data = Array (0.1 , - 0.5 , 0.2 , - 0.3 , 0.8 , 0.7 , - 0.1 , - 0.4 , - 0.9 )
32+ @ transient private var sqlContext : SQLContext = _
33+
34+ override def beforeAll (): Unit = {
35+ super .beforeAll()
36+ sqlContext = new SQLContext (sc)
37+ }
38+
39+ test(" Bucket continuous features, without -inf,inf" ) {
40+ // Check a set of valid feature values.
3441 val splits = Array (- 0.5 , 0.0 , 0.5 )
35- val bucketizedData = Array (2.0 , 1.0 , 2.0 , 1.0 , 3.0 , 3.0 , 1.0 , 1.0 , 0.0 )
36- val dataFrame : DataFrame = sqlContext.createDataFrame(
37- data.zip(bucketizedData)).toDF(" feature" , " expected" )
42+ val validData = Array (- 0.5 , - 0.3 , 0.0 , 0.2 )
43+ val expectedBuckets = Array (0.0 , 0.0 , 1.0 , 1.0 )
44+ val dataFrame : DataFrame =
45+ sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF(" feature" , " expected" )
3846
3947 val bucketizer : Bucketizer = new Bucketizer ()
4048 .setInputCol(" feature" )
@@ -43,58 +51,98 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
4351
4452 bucketizer.transform(dataFrame).select(" result" , " expected" ).collect().foreach {
4553 case Row (x : Double , y : Double ) =>
46- assert(x === y, " The feature value is not correct after bucketing." )
54+ assert(x === y,
55+ s " The feature value is not correct after bucketing. Expected $y but found $x" )
4756 }
48- }
4957
50- test(" Binary search correctness in contrast with linear search" ) {
51- val data = Array .fill(100 )(Random .nextDouble())
52- val splits = Array .fill(10 )(Random .nextDouble()).sorted
53- val wrappedSplits = Array (Double .MinValue ) ++ splits ++ Array (Double .MaxValue )
54- val bsResult = Vectors .dense(
55- data.map(x => Bucketizer .binarySearchForBuckets(wrappedSplits, x, true , true )))
56- val lsResult = Vectors .dense(data.map(x => BucketizerSuite .linearSearchForBuckets(splits, x)))
57- assert(bsResult ~== lsResult absTol 1e-5 )
58+ // Check for exceptions when using a set of invalid feature values.
59+ val invalidData1 : Array [Double ] = Array (- 0.9 ) ++ validData
60+ val invalidData2 = Array (0.5 ) ++ validData
61+ val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF(" feature" , " idx" )
62+ intercept[RuntimeException ]{
63+ bucketizer.transform(badDF1).collect()
64+ println(" Invalid feature value -0.9 was not caught as an invalid feature!" )
65+ }
66+ val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF(" feature" , " idx" )
67+ intercept[RuntimeException ]{
68+ bucketizer.transform(badDF2).collect()
69+ println(" Invalid feature value 0.5 was not caught as an invalid feature!" )
70+ }
5871 }
5972
60- test(" Binary search of features at splits" ) {
61- val splits = Array .fill(10 )(Random .nextDouble()).sorted
62- val data = splits
63- val expected = Vectors .dense(1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 , 9.0 , 10.0 )
64- val wrappedSplits = Array (Double .MinValue ) ++ splits ++ Array (Double .MaxValue )
65- val result = Vectors .dense(
66- data.map(x => Bucketizer .binarySearchForBuckets(wrappedSplits, x, true , true )))
67- assert(result ~== expected absTol 1e-5 )
73+ test(" Bucket continuous features, with -inf,inf" ) {
74+ val splits = Array (Double .NegativeInfinity , - 0.5 , 0.0 , 0.5 , Double .PositiveInfinity )
75+ val validData = Array (- 0.9 , - 0.5 , - 0.3 , 0.0 , 0.2 , 0.5 , 0.9 )
76+ val expectedBuckets = Array (0.0 , 1.0 , 1.0 , 2.0 , 2.0 , 3.0 , 3.0 )
77+ val dataFrame : DataFrame =
78+ sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF(" feature" , " expected" )
79+
80+ val bucketizer : Bucketizer = new Bucketizer ()
81+ .setInputCol(" feature" )
82+ .setOutputCol(" result" )
83+ .setSplits(splits)
84+
85+ bucketizer.transform(dataFrame).select(" result" , " expected" ).collect().foreach {
86+ case Row (x : Double , y : Double ) =>
87+ assert(x === y,
88+ s " The feature value is not correct after bucketing. Expected $y but found $x" )
89+ }
6890 }
6991
70- test(" Binary search of features between splits" ) {
71- val data = Array .fill(10 )(Random .nextDouble())
72- val splits = Array (- 0.1 , 1.1 )
73- val expected = Vectors .dense(Array .fill(10 )(1.0 ))
74- val wrappedSplits = Array (Double .MinValue ) ++ splits ++ Array (Double .MaxValue )
75- val result = Vectors .dense(
76- data.map(x => Bucketizer .binarySearchForBuckets(wrappedSplits, x, true , true )))
77- assert(result ~== expected absTol 1e-5 )
92+ test(" Binary search correctness on hand-picked examples" ) {
93+ import BucketizerSuite .checkBinarySearch
94+ // length 3, with -inf
95+ checkBinarySearch(Array (Double .NegativeInfinity , 0.0 , 1.0 ))
96+ // length 4
97+ checkBinarySearch(Array (- 1.0 , - 0.5 , 0.0 , 1.0 ))
98+ // length 5
99+ checkBinarySearch(Array (- 1.0 , - 0.5 , 0.0 , 1.0 , 1.5 ))
100+ // length 3, with inf
101+ checkBinarySearch(Array (0.0 , 1.0 , Double .PositiveInfinity ))
102+ // length 3, with -inf and inf
103+ checkBinarySearch(Array (Double .NegativeInfinity , 1.0 , Double .PositiveInfinity ))
104+ // length 4, with -inf and inf
105+ checkBinarySearch(Array (Double .NegativeInfinity , 0.0 , 1.0 , Double .PositiveInfinity ))
78106 }
79107
80- test(" Binary search of features outside splits" ) {
81- val data = Array .fill(5 )(Random .nextDouble() + 1.1 ) ++ Array .fill(5 )(Random .nextDouble() - 1.1 )
82- val splits = Array (0.0 , 1.1 )
83- val expected = Vectors .dense(Array .fill(5 )(2.0 ) ++ Array .fill(5 )(0.0 ))
84- val wrappedSplits = Array (Double .MinValue ) ++ splits ++ Array (Double .MaxValue )
85- val result = Vectors .dense(
86- data.map(x => Bucketizer .binarySearchForBuckets(wrappedSplits, x, true , true )))
87- assert(result ~== expected absTol 1e-5 )
108+ test(" Binary search correctness in contrast with linear search, on random data" ) {
109+ val data = Array .fill(100 )(Random .nextDouble())
110+ val splits : Array [Double ] = Double .NegativeInfinity +:
111+ Array .fill(10 )(Random .nextDouble()).sorted :+ Double .PositiveInfinity
112+ val bsResult = Vectors .dense(data.map(x => Bucketizer .binarySearchForBuckets(splits, x)))
113+ val lsResult = Vectors .dense(data.map(x => BucketizerSuite .linearSearchForBuckets(splits, x)))
114+ assert(bsResult ~== lsResult absTol 1e-5 )
88115 }
89116}
90117
91- private object BucketizerSuite {
92- private def linearSearchForBuckets (splits : Array [Double ], feature : Double ): Double = {
118+ private object BucketizerSuite extends FunSuite {
119+ /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
120+ def linearSearchForBuckets (splits : Array [Double ], feature : Double ): Double = {
121+ require(feature >= splits.head)
93122 var i = 0
94- while (i < splits.size ) {
95- if (feature < splits(i)) return i
123+ while (i < splits.length - 1 ) {
124+ if (feature < splits(i + 1 )) return i
96125 i += 1
97126 }
98- i
127+ throw new RuntimeException (
128+ s " linearSearchForBuckets failed to find bucket for feature value $feature" )
129+ }
130+
131+ /** Check all values in splits, plus values between all splits. */
132+ def checkBinarySearch (splits : Array [Double ]): Unit = {
133+ def testFeature (feature : Double , expectedBucket : Double ): Unit = {
134+ assert(Bucketizer .binarySearchForBuckets(splits, feature) === expectedBucket,
135+ s " Expected feature value $feature to be in bucket $expectedBucket with splits: " +
136+ s " ${splits.mkString(" , " )}" )
137+ }
138+ var i = 0
139+ while (i < splits.length - 1 ) {
140+ testFeature(splits(i), i) // Split i should fall in bucket i.
141+ testFeature((splits(i) + splits(i + 1 )) / 2 , i) // Value between splits i,i+1 should be in i.
142+ i += 1
143+ }
144+ if (splits.last === Double .PositiveInfinity ) {
145+ testFeature(Double .PositiveInfinity , splits.length - 2 )
146+ }
99147 }
100148}
0 commit comments