@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
3131import org .apache .spark .mllib .tree .model .{DecisionTreeModel , Node }
3232import org .apache .spark .mllib .util .LocalSparkContext
3333
34-
3534class DecisionTreeSuite extends FunSuite with LocalSparkContext {
3635
3736 def validateClassifier (
@@ -353,8 +352,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
353352 assert(splits(0 ).length === 99 )
354353 assert(bins.length === 2 )
355354 assert(bins(0 ).length === 100 )
356- assert(splits(0 ).length === 99 )
357- assert(bins(0 ).length === 100 )
358355
359356 val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
360357 val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (8 ), metadata, 0 ,
@@ -381,8 +378,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
381378 assert(splits(0 ).length === 99 )
382379 assert(bins.length === 2 )
383380 assert(bins(0 ).length === 100 )
384- assert(splits(0 ).length === 99 )
385- assert(bins(0 ).length === 100 )
386381
387382 val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
388383 val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (2 ), metadata, 0 ,
@@ -410,8 +405,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
410405 assert(splits(0 ).length === 99 )
411406 assert(bins.length === 2 )
412407 assert(bins(0 ).length === 100 )
413- assert(splits(0 ).length === 99 )
414- assert(bins(0 ).length === 100 )
415408
416409 val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
417410 val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (2 ), metadata, 0 ,
@@ -439,8 +432,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
439432 assert(splits(0 ).length === 99 )
440433 assert(bins.length === 2 )
441434 assert(bins(0 ).length === 100 )
442- assert(splits(0 ).length === 99 )
443- assert(bins(0 ).length === 100 )
444435
445436 val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
446437 val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (2 ), metadata, 0 ,
@@ -464,8 +455,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
464455 assert(splits(0 ).length === 99 )
465456 assert(bins.length === 2 )
466457 assert(bins(0 ).length === 100 )
467- assert(splits(0 ).length === 99 )
468- assert(bins(0 ).length === 100 )
469458
470459 // Train a 1-node model
471460 val strategyOneNode = new Strategy (Classification , Entropy , 1 , 2 , 100 )
@@ -600,7 +589,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
600589 val arr = DecisionTreeSuite .generateContinuousDataPointsForMulticlass()
601590 val rdd = sc.parallelize(arr)
602591 val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
603- numClassesForClassification = 3 )
592+ numClassesForClassification = 3 , maxBins = 100 )
604593 assert(strategy.isMulticlassClassification)
605594 val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
606595
@@ -626,7 +615,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
626615 val arr = DecisionTreeSuite .generateContinuousDataPointsForMulticlass()
627616 val rdd = sc.parallelize(arr)
628617 val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
629- numClassesForClassification = 3 , categoricalFeaturesInfo = Map (0 -> 3 ))
618+ numClassesForClassification = 3 , maxBins = 100 , categoricalFeaturesInfo = Map (0 -> 3 ))
630619 assert(strategy.isMulticlassClassification)
631620 val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
632621 assert(metadata.isUnordered(featureIndex = 0 ))
@@ -652,7 +641,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
652641 val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlassForOrderedFeatures()
653642 val rdd = sc.parallelize(arr)
654643 val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
655- numClassesForClassification = 3 , categoricalFeaturesInfo = Map (0 -> 10 , 1 -> 10 ))
644+ numClassesForClassification = 3 , maxBins = 100 ,
645+ categoricalFeaturesInfo = Map (0 -> 10 , 1 -> 10 ))
656646 assert(strategy.isMulticlassClassification)
657647 val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
658648 assert(! metadata.isUnordered(featureIndex = 0 ))
0 commit comments