Commit 6ff9c79

Add test for MultipleBucketizer.
1 parent: 38dce8b · commit: 6ff9c79

1 file changed: 184 additions, 0 deletions
@@ -0,0 +1,184 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import scala.util.Random

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

class MultipleBucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
  with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new MultipleBucketizer)
  }

  test("Bucket continuous features, without -inf,inf") {
    // Check a set of valid feature values.
    val splits = Array(Array(-0.5, 0.0, 0.5), Array(-0.1, 0.3, 0.5))
    val validData1 = Array(-0.5, -0.3, 0.0, 0.2)
    val validData2 = Array(0.5, 0.3, 0.0, -0.1)
    val expectedBuckets1 = Array(0.0, 0.0, 1.0, 1.0)
    val expectedBuckets2 = Array(1.0, 1.0, 0.0, 0.0)

    val data = (0 until validData1.length).map { idx =>
      (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
    }
    val dataFrame: DataFrame = data.toSeq.toDF("feature1", "feature2", "expected1", "expected2")

    val bucketizer1: MultipleBucketizer = new MultipleBucketizer()
      .setInputCols(Array("feature1", "feature2"))
      .setOutputCols(Array("result1", "result2"))
      .setSplitsArray(splits)

    bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
      .collect().foreach {
        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
          assert(r1 === e1,
            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
          assert(r2 === e2,
            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
      }

    // Check for exceptions when using a set of invalid feature values.
    val invalidData1: Array[Double] = Array(-0.9) ++ validData1
    val invalidData2 = Array(0.51) ++ validData1
    val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx")

    val bucketizer2: MultipleBucketizer = new MultipleBucketizer()
      .setInputCols(Array("feature"))
      .setOutputCols(Array("result"))
      .setSplitsArray(Array(splits(0)))

    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
      intercept[SparkException] {
        bucketizer2.transform(badDF1).collect()
      }
    }
    val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx")
    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
      intercept[SparkException] {
        bucketizer2.transform(badDF2).collect()
      }
    }
  }

  test("Bucket continuous features, with -inf,inf") {
    val splits = Array(
      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
      Array(Double.NegativeInfinity, -0.3, 0.2, 0.5, Double.PositiveInfinity))

    val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
    val validData2 = Array(-0.1, -0.5, -0.2, 0.0, 0.1, 0.3, 0.5)
    val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val expectedBuckets2 = Array(1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 3.0)

    val data = (0 until validData1.length).map { idx =>
      (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
    }
    val dataFrame: DataFrame = data.toSeq.toDF("feature1", "feature2", "expected1", "expected2")

    val bucketizer: MultipleBucketizer = new MultipleBucketizer()
      .setInputCols(Array("feature1", "feature2"))
      .setOutputCols(Array("result1", "result2"))
      .setSplitsArray(splits)

    bucketizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
      .collect().foreach {
        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
          assert(r1 === e1,
            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
          assert(r2 === e2,
            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
      }
  }

  test("Bucket continuous features, with NaN data but non-NaN splits") {
    val splits = Array(
      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
      Array(Double.NegativeInfinity, -0.1, 0.2, 0.6, Double.PositiveInfinity))

    val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
    val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN)
    val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0)
    val expectedBuckets2 = Array(2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0)

    val data = (0 until validData1.length).map { idx =>
      (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
    }
    val dataFrame: DataFrame = data.toSeq.toDF("feature1", "feature2", "expected1", "expected2")

    val bucketizer: MultipleBucketizer = new MultipleBucketizer()
      .setInputCols(Array("feature1", "feature2"))
      .setOutputCols(Array("result1", "result2"))
      .setSplitsArray(splits)

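    // handleInvalid = "keep" retains invalid values (NaN here) and assigns them
    // to one extra bucket past the last regular bucket: index 4 for these splits.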
    bucketizer.setHandleInvalid("keep")
    bucketizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
      .collect().foreach {
        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
          assert(r1 === e1,
            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
          assert(r2 === e2,
            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
      }

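    // handleInvalid = "skip" filters out any row with an invalid value in an
    // input column, leaving 7 of the 10 rows and no NaN bucket (index 4).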
    bucketizer.setHandleInvalid("skip")
    val skipResults1: Array[Double] = bucketizer.transform(dataFrame)
      .select("result1").as[Double].collect()
    assert(skipResults1.length === 7)
    assert(skipResults1.forall(_ !== 4.0))

    val skipResults2: Array[Double] = bucketizer.transform(dataFrame)
      .select("result2").as[Double].collect()
    assert(skipResults2.length === 7)
    assert(skipResults2.forall(_ !== 4.0))

    bucketizer.setHandleInvalid("error")
    withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") {
      intercept[SparkException] {
        bucketizer.transform(dataFrame).collect()
      }
    }
  }

  test("Bucket continuous features, with NaN splits") {
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
    withClue("Invalid NaN split was not caught during Bucketizer initialization") {
      intercept[IllegalArgumentException] {
        new MultipleBucketizer().setSplitsArray(Array(splits))
      }
    }
  }

  test("read/write") {
    val t = new MultipleBucketizer()
      .setInputCols(Array("myInputCol"))
      .setOutputCols(Array("myOutputCol"))
      .setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
    testDefaultReadWrite(t)
  }
}
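
For context, here is a minimal usage sketch of the transformer this suite exercises, using only the API shown above (setInputCols, setOutputCols, setSplitsArray, setHandleInvalid). The data, column names, and split points are illustrative only, and "spark" stands in for an active SparkSession:

import org.apache.spark.ml.feature.MultipleBucketizer
import spark.implicits._

// Two numeric columns bucketized in a single pass.
val df = Seq((-0.7, 0.1), (0.2, 0.4), (0.9, Double.NaN)).toDF("x", "y")

val bucketizer = new MultipleBucketizer()
  .setInputCols(Array("x", "y"))
  .setOutputCols(Array("xBucket", "yBucket"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity), // two buckets for "x"
    Array(0.0, 0.3, 0.5)))                                        // two buckets for "y"
  .setHandleInvalid("keep") // NaN lands in an extra bucket rather than failing

bucketizer.transform(df).show()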
