@@ -200,7 +200,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
200200 val data = (0 until validData1.length).map { idx =>
201201 (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
202202 }
203- val dataFrame : DataFrame = data.toSeq. toDF(" feature1" , " feature2" , " expected1" , " expected2" )
203+ val dataFrame : DataFrame = data.toDF(" feature1" , " feature2" , " expected1" , " expected2" )
204204
205205 val bucketizer1 : Bucketizer = new Bucketizer ()
206206 .setInputCols(Array (" feature1" , " feature2" ))
@@ -210,16 +210,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
210210 assert(bucketizer1.isBucketizeMultipleColumns())
211211
212212 bucketizer1.transform(dataFrame).select(" result1" , " expected1" , " result2" , " expected2" )
213- .collect().foreach {
214- case Row (r1 : Double , e1 : Double , r2 : Double , e2 : Double ) =>
215- assert(r1 === e1,
216- s " The feature value is not correct after bucketing. Expected $e1 but found $r1" )
217- assert(r2 === e2,
218- s " The feature value is not correct after bucketing. Expected $e2 but found $r2" )
219- }
213+ BucketizerSuite .checkBucketResults(bucketizer1.transform(dataFrame),
214+ Seq (" result1" , " result2" ),
215+ Seq (" expected1" , " expected2" ))
220216
221217 // Check for exceptions when using a set of invalid feature values.
222- val invalidData1 : Array [ Double ] = Array (- 0.9 ) ++ validData1
218+ val invalidData1 = Array (- 0.9 ) ++ validData1
223219 val invalidData2 = Array (0.51 ) ++ validData1
224220 val badDF1 = invalidData1.zipWithIndex.toSeq.toDF(" feature" , " idx" )
225221
@@ -256,7 +252,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
256252 val data = (0 until validData1.length).map { idx =>
257253 (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
258254 }
259- val dataFrame : DataFrame = data.toSeq. toDF(" feature1" , " feature2" , " expected1" , " expected2" )
255+ val dataFrame : DataFrame = data.toDF(" feature1" , " feature2" , " expected1" , " expected2" )
260256
261257 val bucketizer : Bucketizer = new Bucketizer ()
262258 .setInputCols(Array (" feature1" , " feature2" ))
@@ -265,14 +261,9 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
265261
266262 assert(bucketizer.isBucketizeMultipleColumns())
267263
268- bucketizer.transform(dataFrame).select(" result1" , " expected1" , " result2" , " expected2" )
269- .collect().foreach {
270- case Row (r1 : Double , e1 : Double , r2 : Double , e2 : Double ) =>
271- assert(r1 === e1,
272- s " The feature value is not correct after bucketing. Expected $e1 but found $r1" )
273- assert(r2 === e2,
274- s " The feature value is not correct after bucketing. Expected $e2 but found $r2" )
275- }
264+ BucketizerSuite .checkBucketResults(bucketizer.transform(dataFrame),
265+ Seq (" result1" , " result2" ),
266+ Seq (" expected1" , " expected2" ))
276267 }
277268
278269 test(" multiple columns: Bucket continuous features, with NaN data but non-NaN splits" ) {
@@ -288,7 +279,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
288279 val data = (0 until validData1.length).map { idx =>
289280 (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
290281 }
291- val dataFrame : DataFrame = data.toSeq. toDF(" feature1" , " feature2" , " expected1" , " expected2" )
282+ val dataFrame : DataFrame = data.toDF(" feature1" , " feature2" , " expected1" , " expected2" )
292283
293284 val bucketizer : Bucketizer = new Bucketizer ()
294285 .setInputCols(Array (" feature1" , " feature2" ))
@@ -298,14 +289,9 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
298289 assert(bucketizer.isBucketizeMultipleColumns())
299290
300291 bucketizer.setHandleInvalid(" keep" )
301- bucketizer.transform(dataFrame).select(" result1" , " expected1" , " result2" , " expected2" )
302- .collect().foreach {
303- case Row (r1 : Double , e1 : Double , r2 : Double , e2 : Double ) =>
304- assert(r1 === e1,
305- s " The feature value is not correct after bucketing. Expected $e1 but found $r1" )
306- assert(r2 === e2,
307- s " The feature value is not correct after bucketing. Expected $e2 but found $r2" )
308- }
292+ BucketizerSuite .checkBucketResults(bucketizer.transform(dataFrame),
293+ Seq (" result1" , " result2" ),
294+ Seq (" expected1" , " expected2" ))
309295
310296 bucketizer.setHandleInvalid(" skip" )
311297 val skipResults1 : Array [Double ] = bucketizer.transform(dataFrame)
@@ -335,7 +321,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
335321 }
336322 }
337323
338- test(" multiple columns:: read/write" ) {
324+ test(" multiple columns: read/write" ) {
339325 val t = new Bucketizer ()
340326 .setInputCols(Array (" myInputCol" ))
341327 .setOutputCols(Array (" myOutputCol" ))
@@ -359,13 +345,51 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
359345 .setStages(Array (bucket))
360346 .fit(df)
361347 pl.transform(df).select(" result1" , " expected1" , " result2" , " expected2" )
362- .collect().foreach {
363- case Row (r1 : Double , e1 : Double , r2 : Double , e2 : Double ) =>
364- assert(r1 === e1,
365- s " The feature value is not correct after bucketing. Expected $e1 but found $r1" )
366- assert(r2 === e2,
367- s " The feature value is not correct after bucketing. Expected $e2 but found $r2" )
368- }
348+
349+ BucketizerSuite .checkBucketResults(pl.transform(df),
350+ Seq (" result1" , " result2" ), Seq (" expected1" , " expected2" ))
351+ }
352+
353+ test(" Compare single/multiple column(s) Bucketizer in pipeline" ) {
354+ val df = Seq ((0.5 , 0.3 , 1.0 , 1.0 ), (0.5 , - 0.4 , 1.0 , 0.0 ))
355+ .toDF(" feature1" , " feature2" , " expected1" , " expected2" )
356+
357+ val multiColsBucket = new Bucketizer ()
358+ .setInputCols(Array (" feature1" , " feature2" ))
359+ .setOutputCols(Array (" result1" , " result2" ))
360+ .setSplitsArray(Array (Array (- 0.5 , 0.0 , 0.5 ), Array (- 0.5 , 0.0 , 0.5 )))
361+
362+ val plForMultiCols = new Pipeline ()
363+ .setStages(Array (multiColsBucket))
364+ .fit(df)
365+
366+ val bucketForCol1 = new Bucketizer ()
367+ .setInputCol(" feature1" )
368+ .setOutputCol(" result1" )
369+ .setSplits(Array (- 0.5 , 0.0 , 0.5 ))
370+ val bucketForCol2 = new Bucketizer ()
371+ .setInputCol(" feature2" )
372+ .setOutputCol(" result2" )
373+ .setSplits(Array (- 0.5 , 0.0 , 0.5 ))
374+
375+ val plForSingleCol = new Pipeline ()
376+ .setStages(Array (bucketForCol1, bucketForCol2))
377+ .fit(df)
378+
379+ val resultForSingleCol = plForSingleCol.transform(df)
380+ .select(" result1" , " expected1" , " result2" , " expected2" )
381+ .collect()
382+ val resultForMultiCols = plForMultiCols.transform(df)
383+ .select(" result1" , " expected1" , " result2" , " expected2" )
384+ .collect()
385+
386+ resultForSingleCol.zip(resultForMultiCols).foreach {
387+ case (rowForSingle, rowForMultiCols) =>
388+ assert(rowForSingle.getDouble(0 ) == rowForMultiCols.getDouble(0 ) &&
389+ rowForSingle.getDouble(1 ) == rowForMultiCols.getDouble(1 ) &&
390+ rowForSingle.getDouble(2 ) == rowForMultiCols.getDouble(2 ) &&
391+ rowForSingle.getDouble(3 ) == rowForMultiCols.getDouble(3 ))
392+ }
369393 }
370394
371395 test(" Both inputCol and inputCols are set" ) {
@@ -411,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite {
411435 i += 1
412436 }
413437 }
438+
439+ /** Checks if bucketized results match expected ones. */
440+ def checkBucketResults (
441+ bucketResult : DataFrame ,
442+ resultColumns : Seq [String ],
443+ expectedColumns : Seq [String ]): Unit = {
444+ assert(resultColumns.length == expectedColumns.length,
445+ s " Given ${resultColumns.length} result columns doesn't match " +
446+ s " ${expectedColumns.length} expected columns. " )
447+ assert(resultColumns.length > 0 , " At least one result and expected columns are needed." )
448+
449+ val allColumns = resultColumns ++ expectedColumns
450+ bucketResult.select(allColumns.head, allColumns.tail: _* ).collect().foreach {
451+ case row =>
452+ for (idx <- 0 until row.length / 2 ) {
453+ val result = row.getDouble(idx)
454+ val expected = row.getDouble(idx + row.length / 2 )
455+ assert(result === expected, " The feature value is not correct after bucketing. " +
456+ s " Expected $expected but found $result. " )
457+ }
458+ }
459+ }
414460}
0 commit comments