From cd53eae11313fd30f71f5ec94b20fe8d4427b8cd Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 28 Nov 2013 02:20:27 -0800 Subject: [PATCH 01/48] skeletal framework Signed-off-by: Manish Amde --- .../classification/ClassificationTree.scala | 21 ++++++++ .../mllib/regression/RegressionTree.scala | 21 ++++++++ .../spark/mllib/tree/DecisionTree.scala | 54 +++++++++++++++++++ .../org/apache/spark/mllib/tree/README.md | 15 ++++++ .../apache/spark/mllib/tree/Strategy.scala | 28 ++++++++++ .../spark/mllib/tree/impurity/Entropy.scala | 34 ++++++++++++ .../spark/mllib/tree/impurity/Gini.scala | 28 ++++++++++ .../spark/mllib/tree/impurity/Impurity.scala | 23 ++++++++ .../spark/mllib/tree/impurity/Variance.scala | 23 ++++++++ .../apache/spark/mllib/tree/model/Bin.scala | 21 ++++++++ .../mllib/tree/model/DecisionTreeModel.scala | 21 ++++++++ .../apache/spark/mllib/tree/model/Split.scala | 29 ++++++++++ 12 files changed, 318 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/README.md create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala new file mode 100644 index 0000000000000..a6f27e7fd1111 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.classification + +class ClassificationTree { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala new file mode 100644 index 0000000000000..fd9beb79ab88f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.regression + +class RegressionTree { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala new file mode 100644 index 0000000000000..e3a52990bfb65 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel} + + +class DecisionTree(val strategy : Strategy) { + + def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { + + //Cache input RDD for speedup during multiple passes + input.cache() + + //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass + val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) + + //TODO: Level-wise training of tree and obtain Decision Tree model + + + return new DecisionTreeModel() + } + +} + +object DecisionTree { + def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { + val numSplits = strategy.numSplits + //TODO: Justify this calculation + val requiredSamples : Long = numSplits*numSplits + val count : Long = input.count() + val numSamples : Long = if (requiredSamples < count) requiredSamples else count + val numFeatures = input.take(1)(0).features.length + (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits)) + } + +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md new file mode 100644 index 0000000000000..61cba281cccfc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -0,0 +1,15 @@ +This package contains the default implementation of the decision tree algorithm. + +The decision tree algorithm supports: ++ information loss calculation with entropy and gini for classification and variance for regression ++ node model pruning ++ printing to dot files ++ unit tests + +#Performance testing + +#Future Extensions + ++ Random forests ++ Boosting ++ Extremely randomized trees \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala new file mode 100644 index 0000000000000..12f0bb37d793c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.tree + +import org.apache.spark.mllib.tree.impurity.Impurity + +class Strategy ( + val kind : String, + val impurity : Impurity, + val maxDepth : Int, + val numSplits : Int, + val quantileCalculationStrategy : String = "sampleAndSort") { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala new file mode 100644 index 0000000000000..00feb25e25322 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala new file mode 100644 index 0000000000000..a95f0431c6e8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.tree.impurity + +object Gini extends Impurity { + + def calculate(c0 : Double, c1 : Double): Double = { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0*f0 - f1*f1 + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala new file mode 100644 index 0000000000000..fadebb0c203eb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +trait Impurity { + + def calculate(c0 : Double, c1 : Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala new file mode 100644 index 0000000000000..98f332122785e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +import javax.naming.OperationNotSupportedException + +object Variance extends Impurity { + def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala new file mode 100644 index 0000000000000..638d65805a813 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +case class Bin(kind : String, lowSplit : Split, highSplit : Split) { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala new file mode 100644 index 0000000000000..d0465d8c6fb6f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +class DecisionTreeModel { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala new file mode 100644 index 0000000000000..4f8beb73bbb09 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.tree.model + +case class Split( + val feature: Int, + val threshold : Double, + val kind : String) { + +} + +class dummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind) + +class dummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind) + From 92cedce2eb5055e0164c90842d6613c618bfed94 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 1 Dec 2013 22:52:29 -0800 Subject: [PATCH 02/48] basic building blocks for intermediate RDD calculation. untested. Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 170 +++++++++++++++++- .../apache/spark/mllib/tree/Strategy.scala | 2 +- .../apache/spark/mllib/tree/model/Bin.scala | 2 +- .../spark/mllib/tree/model/Filter.scala | 22 +++ .../apache/spark/mllib/tree/model/Split.scala | 11 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 70 ++++++++ 6 files changed, 262 insertions(+), 15 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index e3a52990bfb65..575eb4e8d825f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,9 +17,12 @@ package org.apache.spark.mllib.tree +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.Split class DecisionTree(val strategy : Strategy) { @@ -30,25 +33,180 @@ class DecisionTree(val strategy : Strategy) { input.cache() //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass + //TODO: Think about broadcasting this val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) //TODO: Level-wise training of tree and obtain Decision Tree model + val maxDepth = strategy.maxDepth + + val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 + val filters = new Array[List[Filter]](maxNumNodes) + + for (level <- 0 until maxDepth){ + //Find best split for all nodes at a level + val numNodes= scala.math.pow(2,level).toInt + val bestSplits = DecisionTree.findBestSplits(input, strategy, level, filters,splits,bins) + //TODO: update filters and decision tree model + } return new DecisionTreeModel() } } -object DecisionTree { +object DecisionTree extends Logging { + + def findBestSplits( + input : RDD[LabeledPoint], + strategy: Strategy, + level: Int, + filters : Array[List[Filter]], + splits : Array[Array[Split]], + bins : Array[Array[Bin]]) : Array[Split] = { + + def findParentFilters(nodeIndex: Int): List[Filter] = { + if (level == 0) { + List[Filter]() + } else { + val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex + val parentFilterIndex = nodeFilterIndex / 2 + filters(parentFilterIndex) + } + } + + def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + + for (filter <- parentFilters) { + val features = labeledPoint.features + val featureIndex = filter.split.feature + val threshold = filter.split.threshold + val comparison = filter.comparison + comparison match { + case(-1) => 
if (features(featureIndex) > threshold) return false + case(0) => if (features(featureIndex) != threshold) return false + case(1) => if (features(featureIndex) <= threshold) return false + } + } + true + } + + def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { + + //TODO: Do binary search + for (binIndex <- 0 until strategy.numSplits) { + val bin = bins(featureIndex)(binIndex) + //TODO: Remove this requirement post basic functional testing + require(bin.lowSplit.feature == featureIndex) + require(bin.highSplit.feature == featureIndex) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + val features = labeledPoint.features + if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) { + return binIndex + } + } + throw new UnknownError("no bin was found.") + + } + def findBinsForLevel: Array[Double] = { + + val numNodes = scala.math.pow(2, level).toInt + //Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.length + + //TODO: Bit pack more by removing redundant label storage + // calculating bin index and label per feature per node + val arr = new Array[Double](2 * numFeatures * numNodes) + for (nodeIndex <- 0 until numNodes) { + val parentFilters = findParentFilters(nodeIndex) + //Find out whether the sample qualifies for the particular node + val sampleValid = isSampleValid(parentFilters, labeledPoint) + val shift = 2 * numFeatures * nodeIndex + if (sampleValid) { + //Add to invalid bin index -1 + for (featureIndex <- shift until (shift + numFeatures) by 2) { + arr(featureIndex + 1) = -1 + arr(featureIndex + 2) = labeledPoint.label + } + } else { + for (featureIndex <- 0 until numFeatures) { + arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint) + arr(shift + (featureIndex * 2) + 2) = labeledPoint.label + } + } + + } + arr + } + + val binMappedRDD = input.map(labeledPoint => findBinsForLevel) + //calculate bin aggregates + //find best split + + + Array[Split]() + } + def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { + val numSplits = strategy.numSplits + logDebug("numSplits = " + numSplits) + + //Calculate the number of sample for approximate quantile calculation //TODO: Justify this calculation - val requiredSamples : Long = numSplits*numSplits - val count : Long = input.count() - val numSamples : Long = if (requiredSamples < count) requiredSamples else count + val requiredSamples = numSplits*numSplits + val count = input.count() + val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 + logDebug("fraction of data used for calculating quantiles = " + fraction) + + //sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, 42).collect() + val numSamples = sampledInput.length + + require(numSamples > numSplits, "length of input samples should be greater than numSplits") + + //Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.length - (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits)) + + strategy.quantileCalculationStrategy match { + case "sort" => { + val splits = Array.ofDim[Split](numFeatures,numSplits-1) + val bins = Array.ofDim[Bin](numFeatures,numSplits) + + //Find all splits + for (featureIndex <- 0 until numFeatures){ + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride : Double 
= numSamples.toDouble/numSplits + for (index <- 0 until numSplits-1) { + val sampleIndex = (index+1)*stride.toInt + val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") + splits(featureIndex)(index) = split + } + } + + //Find all bins + for (featureIndex <- 0 until numFeatures){ + bins(featureIndex)(0) + = new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous") + for (index <- 1 until numSplits - 1){ + val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous") + bins(featureIndex)(index) = bin + } + bins(featureIndex)(numSplits-1) + = new Bin(splits(featureIndex)(numSplits-3),new DummyHighSplit("continuous"),"continuous") + } + + (splits,bins) + } + case "minMax" => { + (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits+2)) + } + case "approximateHistogram" => { + throw new UnsupportedOperationException("approximate histogram not supported yet.") + } + + } } } \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala index 12f0bb37d793c..a7077f0914033 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala @@ -23,6 +23,6 @@ class Strategy ( val impurity : Impurity, val maxDepth : Int, val numSplits : Int, - val quantileCalculationStrategy : String = "sampleAndSort") { + val quantileCalculationStrategy : String = "sort") { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 638d65805a813..25d16a9a2fc2f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -16,6 +16,6 @@ */ package org.apache.spark.mllib.tree.model -case class Bin(kind : String, lowSplit : Split, highSplit : Split) { +case class Bin(lowSplit : Split, highSplit : Split, kind : String) { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala new file mode 100644 index 0000000000000..62e5006c80c1b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.tree.model + +case class Filter(split : Split, comparison : Int) { + // Comparison -1,0,1 signifies <.=,> + override def toString = " split = " + split + "comparison = " + comparison +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 4f8beb73bbb09..1b39154d42e47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -16,14 +16,11 @@ */ package org.apache.spark.mllib.tree.model -case class Split( - val feature: Int, - val threshold : Double, - val kind : String) { - +case class Split(feature: Int, threshold : Double, kind : String){ + override def toString = "Feature = " + feature + ", threshold = " + threshold + ", kind = " + kind } -class dummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind) +class DummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind) -class dummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind) +class DummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala new file mode 100644 index 0000000000000..22c6b6eca1876 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.tree + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +import org.jblas._ +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impurity.Gini + +class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("split and bin calculation"){ + val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Gini,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(bins.length==2) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + } +} + +object DecisionTreeSuite { + + def generateReverseOrderedLabeledPoints() : Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(1.0,Array(i.toDouble,1000.0-i)) + arr(i) = lp + } + arr + } + +} From 8bca1e20b703fd90bc6fcdbed5d36b42a0bdf66e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 8 Dec 2013 19:48:39 -0800 Subject: [PATCH 03/48] additional code for creating intermediate RDD Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 124 ++++++++++++++---- .../apache/spark/mllib/tree/Strategy.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 18 +++ 4 files changed, 120 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 575eb4e8d825f..883ddcf74999e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) { val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) //TODO: Level-wise training of tree and obtain Decision Tree model - val maxDepth = strategy.maxDepth val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 @@ -55,8 +54,20 @@ class DecisionTree(val strategy : Strategy) { } -object DecisionTree extends Logging { +object DecisionTree extends Serializable { + + /* + Returns an Array[Split] of optimal splits for all nodes at a given level + + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree + @param level Level of the tree + @param filters Filter for all nodes at a given level + @param splits possible splits for all features + @param bins possible bins for all features + @return Array[Split] instance for best splits for all nodes at a given level. 
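+
+    Illustrative indexing note (assuming the scheme used in findParentFilters below): at level 2
+    there are 2^2 = 4 nodes; node 1 at that level stores its filter at index 2^2 + 1 = 5, and the
+    parent's filter is found at index 5 / 2 = 2 by integer division.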
+ */ def findBestSplits( input : RDD[LabeledPoint], strategy: Strategy, @@ -65,6 +76,16 @@ object DecisionTree extends Logging { splits : Array[Array[Split]], bins : Array[Array[Bin]]) : Array[Split] = { + //TODO: Move these calculations outside + val numNodes = scala.math.pow(2, level).toInt + println("numNodes = " + numNodes) + //Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.length + println("numFeatures = " + numFeatures) + val numSplits = strategy.numSplits + println("numSplits = " + numSplits) + + /*Find the filters used before reaching the current code*/ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() @@ -75,6 +96,10 @@ object DecisionTree extends Logging { } } + /*Find whether the sample is valid input for the current node. + + In other words, does it pass through all the filters for the current node. + */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { for (filter <- parentFilters) { @@ -91,48 +116,49 @@ object DecisionTree extends Logging { true } + /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { - //TODO: Do binary search for (binIndex <- 0 until strategy.numSplits) { val bin = bins(featureIndex)(binIndex) - //TODO: Remove this requirement post basic functional testing - require(bin.lowSplit.feature == featureIndex) - require(bin.highSplit.feature == featureIndex) + //TODO: Remove this requirement post basic functional val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold val features = labeledPoint.features - if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) { + if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) { return binIndex } } throw new UnknownError("no bin was found.") } - def findBinsForLevel: Array[Double] = { - val numNodes = scala.math.pow(2, level).toInt - //Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.length + /*Finds bins for all nodes (and all features) at a given level + k features, l nodes + Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. 
,blk + Denotes invalid sample for tree by noting bin for feature 1 as -1 + */ + def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = { + - //TODO: Bit pack more by removing redundant label storage // calculating bin index and label per feature per node - val arr = new Array[Double](2 * numFeatures * numNodes) + val arr = new Array[Double](1+(numFeatures * numNodes)) + arr(0) = labeledPoint.label for (nodeIndex <- 0 until numNodes) { val parentFilters = findParentFilters(nodeIndex) //Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 2 * numFeatures * nodeIndex - if (sampleValid) { + val shift = 1 + numFeatures * nodeIndex + if (!sampleValid) { //Add to invalid bin index -1 - for (featureIndex <- shift until (shift + numFeatures) by 2) { - arr(featureIndex + 1) = -1 - arr(featureIndex + 2) = labeledPoint.label + for (featureIndex <- 0 until numFeatures) { + arr(shift+featureIndex) = -1 + //TODO: Break since marking one bin is sufficient } } else { for (featureIndex <- 0 until numFeatures) { - arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint) - arr(shift + (featureIndex * 2) + 2) = labeledPoint.label + //println("shift+featureIndex =" + (shift+featureIndex)) + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint) } } @@ -140,30 +166,80 @@ object DecisionTree extends Logging { arr } - val binMappedRDD = input.map(labeledPoint => findBinsForLevel) + /* + Performs a sequential aggreation over a partition + + @param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification + and 3*numSplits*numFeatures*numNodes for regression + @param arr Array[Double] of size 1+(numFeatures*numNodes) + @return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification + and 3*numSplits*numFeatures*numNodes for regression + */ + def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { + for (node <- 0 until numNodes) { + val validSignalIndex = 1+numFeatures*node + val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false + if(isSampleValidForNode) { + for (feature <- 0 until numFeatures){ + val arrShift = 1 + numFeatures*node + val aggShift = numSplits*numFeatures*node + val arrIndex = arrShift + feature + val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt + agg(aggIndex) = agg(aggIndex) + 1 + } + } + } + agg + } + + def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = { + par1 + } + + println("input = " + input.count) + val binMappedRDD = input.map(x => findBinsForLevel(x)) + println("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates + + val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) + //find best split + println("binAggregates.length = " + binAggregates.length) - Array[Split]() + val bestSplits = new Array[Split](numNodes) + for (node <- 0 until numNodes){ + val binsForNode = binAggregates.slice(node,numSplits*node) + } + + bestSplits } + /* + Returns split and bins for decision tree calculation. 
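+
+    With the default "sort" strategy, this samples the input, sorts each feature's sampled values,
+    and reads thresholds off at evenly spaced ranks. A minimal sketch of the per-feature step
+    (mirroring the loop in the method body; featureIndex and index are the loop variables):
+
+      val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+      val stride = numSamples.toDouble / numSplits
+      val threshold = featureSamples((index + 1) * stride.toInt)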
+ + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree + @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an + Array[Array[Bin]] of size (numFeatures,numSplits1) + */ def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { val numSplits = strategy.numSplits - logDebug("numSplits = " + numSplits) + println("numSplits = " + numSplits) //Calculate the number of sample for approximate quantile calculation //TODO: Justify this calculation val requiredSamples = numSplits*numSplits val count = input.count() val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 - logDebug("fraction of data used for calculating quantiles = " + fraction) + println("fraction of data used for calculating quantiles = " + fraction) //sampled input for RDD calculation val sampledInput = input.sample(false, fraction, 42).collect() val numSamples = sampledInput.length + //TODO: Remove this requirement require(numSamples > numSplits, "length of input samples should be greater than numSplits") //Find the number of features by looking at the first sample diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala index a7077f0914033..7f88053043e0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.mllib.tree.impurity.Impurity -class Strategy ( +case class Strategy ( val kind : String, val impurity : Impurity, val maxDepth : Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index fadebb0c203eb..4b6e679820f59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.mllib.tree.impurity -trait Impurity { +trait Impurity extends Serializable { def calculate(c0 : Double, c1 : Double): Double diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 22c6b6eca1876..0e8c9ba850e4f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -28,6 +28,7 @@ import org.jblas._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.Gini +import org.apache.spark.mllib.tree.model.Filter class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -54,6 +55,23 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) println(splits(1)(98)) } + + test("stump"){ + val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Gini,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + 
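+    // 2 features with numSplits = 100: expect numSplits - 1 = 99 candidate splits and 100 bins per feature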
assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins) + } + } object DecisionTreeSuite { From 0012a77eb02e0a6627b7e3e68ac4d0f29d0885e0 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 9 Dec 2013 21:08:44 -0800 Subject: [PATCH 04/48] basic stump working Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 170 ++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 2 +- 2 files changed, 153 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 883ddcf74999e..ddb78d3903049 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -45,7 +45,8 @@ class DecisionTree(val strategy : Strategy) { for (level <- 0 until maxDepth){ //Find best split for all nodes at a level val numNodes= scala.math.pow(2,level).toInt - val bestSplits = DecisionTree.findBestSplits(input, strategy, level, filters,splits,bins) + //TODO: Change the input parent impurities values + val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins) //TODO: update filters and decision tree model } @@ -60,6 +61,7 @@ object DecisionTree extends Serializable { Returns an Array[Split] of optimal splits for all nodes at a given level @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @param parentImpurities Impurities for all parent nodes for the current level @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree @param level Level of the tree @param filters Filter for all nodes at a given level @@ -70,13 +72,14 @@ object DecisionTree extends Serializable { */ def findBestSplits( input : RDD[LabeledPoint], + parentImpurities : Array[Double], strategy: Strategy, level: Int, filters : Array[List[Filter]], splits : Array[Array[Split]], bins : Array[Array[Bin]]) : Array[Split] = { - //TODO: Move these calculations outside + //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt println("numNodes = " + numNodes) //Find the number of features by looking at the first sample @@ -118,6 +121,7 @@ object DecisionTree extends Serializable { /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { + println("finding bin for labeled point " + labeledPoint.features(featureIndex)) //TODO: Do binary search for (binIndex <- 0 until strategy.numSplits) { val bin = bins(featureIndex)(binIndex) @@ -134,7 +138,7 @@ object DecisionTree extends Serializable { } /*Finds bins for all nodes (and all features) at a given level - k features, l nodes + k features, l nodes (level = log2(l)) Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk Denotes invalid sample for tree by noting bin for feature 1 as -1 */ @@ -167,33 +171,53 @@ object DecisionTree extends Serializable { } /* - Performs a sequential aggreation over a partition + Performs a sequential aggregation over a partition. 
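+
+    Worked index example (assuming the flat layout described next, with numFeatures = 2 and
+    numSplits = 100): a sample that reaches node 1 and falls in bin 5 of feature 1 with label 1.0
+    increments agg(2*100*2*1 + 2*1*100 + 5*2 + 1) = agg(611).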
- @param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression + for p bins, k features, l nodes (level = log2(l)) storage is of the form: + b111_left_count,b111_right_count, .... , bpk1_left_count, bpk1_right_count, .... , bpkl_left_count, bpkl_right_count + + @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { + //TODO: Requires logic for regressions for (node <- 0 until numNodes) { val validSignalIndex = 1+numFeatures*node val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false - if(isSampleValidForNode) { + if(isSampleValidForNode){ + val label = arr(0) for (feature <- 0 until numFeatures){ val arrShift = 1 + numFeatures*node - val aggShift = numSplits*numFeatures*node + val aggShift = 2*numSplits*numFeatures*node val arrIndex = arrShift + feature - val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt - agg(aggIndex) = agg(aggIndex) + 1 + val aggIndex = aggShift + 2*feature*numSplits + arr(arrIndex).toInt*2 + label match { + case(0.0) => agg(aggIndex) = agg(aggIndex) + 1 + case(1.0) => agg(aggIndex+1) = agg(aggIndex+1) + 1 + } } } } agg } - def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = { - par1 + //TODO: This length if different for regression + val binAggregateLength = 2*numSplits * numFeatures * numNodes + println("binAggregageLength = " + binAggregateLength) + + /*Combines the aggregates from partitions + @param agg1 Array containing aggregates from one or more partitions + @param agg2 Array contianing aggregates from one or more partitions + + @return Combined aggregate from agg1 and agg2 + */ + def binCombOp(agg1 : Array[Double], agg2: Array[Double]) : Array[Double] = { + val combinedAggregate = new Array[Double](binAggregateLength) + for (index <- 0 until binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + } + combinedAggregate } println("input = " + input.count) @@ -201,15 +225,125 @@ object DecisionTree extends Serializable { println("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates - val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) - - //find best split + val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) println("binAggregates.length = " + binAggregates.length) + binAggregates.foreach(x => println(x)) + + + def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = { + + val left0Count = leftNodeAgg(featureIndex)(2 * index) + val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) + val leftCount = left0Count + left1Count + println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val right0Count = rightNodeAgg(featureIndex)(2 * index) + val 
right1Count = rightNodeAgg(featureIndex)(2 * index + 1) + val rightCount = right0Count + right1Count + println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + } + + /* + Extracts left and right split aggregates + + @param binData Array[Double] of size 2*numFeatures*numSplits + @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], Array[Double]) where + each array is of size(numFeature,2*(numSplits-1)) + */ + def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + println("binData.length = " + binData.length) + println("binData.sum = " + binData.sum) + for (featureIndex <- 0 until numFeatures) { + println("featureIndex = " + featureIndex) + val shift = 2*featureIndex*numSplits + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + println("binData(shift + 0) = " + binData(shift + 0)) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + println("binData(shift + 1) = " + binData(shift + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) + println(binData(shift + (2 * (numSplits - 1)))) + rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) + println(binData(shift + (2 * (numSplits - 1)) + 1)) + for (splitIndex <- 1 until numSplits - 1) { + println("splitIndex = " + splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex) + = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) + = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) + = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) + = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + } + } + (leftNodeAgg, rightNodeAgg) + } + + def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = { + + val gains = Array.ofDim[Double](numFeatures, numSplits - 1) + + for (featureIndex <- 0 until numFeatures) { + for (index <- 0 until numSplits -1) { + println("splitIndex = " + index) + gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) + } + } + gains + } + + /* + Find the best split for a node given bin aggregate data + + @param binData Array[Double] of size 2*numSplits*numFeatures + */ + def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = { + println("node impurity = " + nodeImpurity) + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + + println("gains.size = " + gains.size) + println("gains(0).size = " + gains(0).size) + + val (bestFeatureIndex,bestSplitIndex) = 
{ + var bestFeatureIndex = 0 + var bestSplitIndex = 0 + var maxGain = Double.MinValue + for (featureIndex <- 0 until numFeatures) { + for (splitIndex <- 0 until numSplits - 1){ + val gain = gains(featureIndex)(splitIndex) + println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) + if(gain > maxGain) { + maxGain = gain + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain) + } + } + } + (bestFeatureIndex,bestSplitIndex) + } + + splits(bestFeatureIndex)(bestSplitIndex) + } + //Calculate best splits for all nodes at a given level val bestSplits = new Array[Split](numNodes) for (node <- 0 until numNodes){ - val binsForNode = binAggregates.slice(node,numSplits*node) + val shift = 2*node*numSplits*numFeatures + val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) + val parentNodeImpurity = parentImpurities(node/2) + bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } bestSplits diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 0e8c9ba850e4f..e886c40901b45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -69,7 +69,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length==99) assert(bins(0).length==100) println(splits(1)(98)) - DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins) + DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) } } From 03f534c2f9a8dd739945f92b98a58e93fa5b716a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 9 Dec 2013 22:10:46 -0800 Subject: [PATCH 05/48] some more tests Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 43 ++++++---- .../spark/mllib/tree/DecisionTreeSuite.scala | 82 +++++++++++++++++-- 2 files changed, 102 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ddb78d3903049..43ede29ef6fd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -121,7 +121,7 @@ object DecisionTree extends Serializable { /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { - println("finding bin for labeled point " + labeledPoint.features(featureIndex)) + //println("finding bin for labeled point " + labeledPoint.features(featureIndex)) //TODO: Do binary search for (binIndex <- 0 until strategy.numSplits) { val bin = bins(featureIndex)(binIndex) @@ -227,7 +227,7 @@ object DecisionTree extends Serializable { val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) println("binAggregates.length = " + binAggregates.length) - binAggregates.foreach(x => println(x)) + //binAggregates.foreach(x => println(x)) def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = { @@ -235,13 +235,19 @@ object DecisionTree extends Serializable { val left0Count = 
leftNodeAgg(featureIndex)(2 * index) val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) val leftCount = left0Count + left1Count - println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) + + if (leftCount == 0) return 0 + + //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) val right0Count = rightNodeAgg(featureIndex)(2 * index) val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count - println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) + + if (rightCount == 0) return 0 + + //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) val leftWeight = leftCount.toDouble / (leftCount + rightCount) @@ -261,21 +267,21 @@ object DecisionTree extends Serializable { def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - println("binData.length = " + binData.length) - println("binData.sum = " + binData.sum) + //println("binData.length = " + binData.length) + //println("binData.sum = " + binData.sum) for (featureIndex <- 0 until numFeatures) { - println("featureIndex = " + featureIndex) + //println("featureIndex = " + featureIndex) val shift = 2*featureIndex*numSplits leftNodeAgg(featureIndex)(0) = binData(shift + 0) - println("binData(shift + 0) = " + binData(shift + 0)) + //println("binData(shift + 0) = " + binData(shift + 0)) leftNodeAgg(featureIndex)(1) = binData(shift + 1) - println("binData(shift + 1) = " + binData(shift + 1)) + //println("binData(shift + 1) = " + binData(shift + 1)) rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - println(binData(shift + (2 * (numSplits - 1)))) + //println(binData(shift + (2 * (numSplits - 1)))) rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - println(binData(shift + (2 * (numSplits - 1)) + 1)) + //println(binData(shift + (2 * (numSplits - 1)) + 1)) for (splitIndex <- 1 until numSplits - 1) { - println("splitIndex = " + splitIndex) + //println("splitIndex = " + splitIndex) leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) leftNodeAgg(featureIndex)(2 * splitIndex + 1) @@ -295,7 +301,7 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { - println("splitIndex = " + index) + //println("splitIndex = " + index) gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) } } @@ -312,8 +318,8 @@ object DecisionTree extends Serializable { val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - println("gains.size = " + gains.size) - println("gains(0).size = " + gains(0).size) + //println("gains.size = " + gains.size) + //println("gains(0).size = " + gains(0).size) val (bestFeatureIndex,bestSplitIndex) = { var bestFeatureIndex = 0 @@ -322,7 +328,7 @@ object DecisionTree extends 
Serializable { for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gain = gains(featureIndex)(splitIndex) - println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) + //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) if(gain > maxGain) { maxGain = gain bestFeatureIndex = featureIndex @@ -335,6 +341,8 @@ object DecisionTree extends Serializable { } splits(bestFeatureIndex)(bestSplitIndex) + + //TODo: Return array of node stats with split and impurity information } //Calculate best splits for all nodes at a given level @@ -388,6 +396,9 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures){ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride : Double = numSamples.toDouble/numSplits + + println("stride = " + stride) + for (index <- 0 until numSplits-1) { val sampleIndex = (index+1)*stride.toInt val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index e886c40901b45..2c9794371eb29 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext._ import org.jblas._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impurity.Gini +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} import org.apache.spark.mllib.tree.model.Filter class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -44,7 +44,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { } test("split and bin calculation"){ - val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints() + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy("regression",Gini,3,100,"sort") @@ -56,8 +56,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { println(splits(1)(98)) } - test("stump"){ - val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints() + test("stump with fixed label 0 for Gini"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy("regression",Gini,3,100,"sort") @@ -69,17 +69,85 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length==99) assert(bins(0).length==100) println(splits(1)(98)) - DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) } + test("stump with fixed label 1 for Gini"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Gini,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + 
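+      // 100 bins per feature imply 99 candidate thresholds, which the length checks above verify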
println(splits(1)(98)) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) + } + + + test("stump with fixed label 0 for Entropy"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Entropy,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) + } + + test("stump with fixed label 1 for Entropy"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Entropy,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) + } + + } object DecisionTreeSuite { - def generateReverseOrderedLabeledPoints() : Array[LabeledPoint] = { + def generateOrderedLabeledPointsWithLabel0() : Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) + arr(i) = lp + } + arr + } + + + def generateOrderedLabeledPointsWithLabel1() : Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ - val lp = new LabeledPoint(1.0,Array(i.toDouble,1000.0-i)) + val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) arr(i) = lp } arr From dad0afc85aea64c06b4dd64504b3112c881ae4e6 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 15 Dec 2013 00:25:58 -0800 Subject: [PATCH 06/48] decision stump functionality working Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 123 +++++++++++++----- .../spark/mllib/tree/DecisionTreeSuite.scala | 28 ++-- 2 files changed, 108 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 43ede29ef6fd8..4f7324345e1d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -20,9 +20,10 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.impurity.Gini class DecisionTree(val strategy : Strategy) { @@ -46,8 +47,13 @@ class DecisionTree(val strategy : Strategy) { //Find best split for all nodes at a level val numNodes= scala.math.pow(2,level).toInt //TODO: Change the input parent impurities values - 
val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins) + val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins) + for (tmp <- splits_stats_for_level){ + println("final best split = " + tmp._1) + } //TODO: update filters and decision tree model + require(scala.math.pow(2,level)==splits_stats_for_level.length) + } return new DecisionTreeModel() @@ -77,7 +83,7 @@ object DecisionTree extends Serializable { level: Int, filters : Array[List[Filter]], splits : Array[Array[Split]], - bins : Array[Array[Bin]]) : Array[Split] = { + bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -94,8 +100,9 @@ object DecisionTree extends Serializable { List[Filter]() } else { val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex - val parentFilterIndex = nodeFilterIndex / 2 - filters(parentFilterIndex) + //val parentFilterIndex = nodeFilterIndex / 2 + //TODO: Check left or right filter + filters(nodeFilterIndex) } } @@ -230,22 +237,26 @@ object DecisionTree extends Serializable { //binAggregates.foreach(x => println(x)) - def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = { + def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + index: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double) : (Double, Long, Long) = { val left0Count = leftNodeAgg(featureIndex)(2 * index) val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) val leftCount = left0Count + left1Count - if (leftCount == 0) return 0 - - //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val right0Count = rightNodeAgg(featureIndex)(2 * index) val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count - if (rightCount == 0) return 0 + if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong) + + //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + + if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong) //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) @@ -253,7 +264,7 @@ object DecisionTree extends Serializable { val leftWeight = leftCount.toDouble / (leftCount + rightCount) val rightWeight = rightCount.toDouble / (leftCount + rightCount) - topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong) } @@ -295,9 +306,10 @@ object DecisionTree extends Serializable { (leftNodeAgg, rightNodeAgg) } - def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = { + def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) + : Array[Array[(Double,Long,Long)]] = { - val gains = Array.ofDim[Double](numFeatures, numSplits - 1) + val gains = 
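+      // one (gain, leftSampleCount, rightSampleCount) triple per feature and candidate split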
Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1) for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { @@ -313,7 +325,7 @@ object DecisionTree extends Serializable { @param binData Array[Double] of size 2*numSplits*numFeatures */ - def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = { + def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = { println("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -321,32 +333,36 @@ object DecisionTree extends Serializable { //println("gains.size = " + gains.size) //println("gains(0).size = " + gains(0).size) - val (bestFeatureIndex,bestSplitIndex) = { + val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = { var bestFeatureIndex = 0 var bestSplitIndex = 0 var maxGain = Double.MinValue + var leftSamples = Long.MinValue + var rightSamples = Long.MinValue for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gain = gains(featureIndex)(splitIndex) //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) - if(gain > maxGain) { - maxGain = gain + if(gain._1 > maxGain) { + maxGain = gain._1 + leftSamples = gain._2 + rightSamples = gain._3 bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain) + println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + + ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples) } } } - (bestFeatureIndex,bestSplitIndex) + (bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples) } - splits(bestFeatureIndex)(bestSplitIndex) - - //TODo: Return array of node stats with split and impurity information + (splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount) + //TODO: Return array of node stats with split and impurity information } //Calculate best splits for all nodes at a given level - val bestSplits = new Array[Split](numNodes) + val bestSplits = new Array[(Split, Double, Long, Long)](numNodes) for (node <- 0 until numNodes){ val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) @@ -381,9 +397,6 @@ object DecisionTree extends Serializable { val sampledInput = input.sample(false, fraction, 42).collect() val numSamples = sampledInput.length - //TODO: Remove this requirement - require(numSamples > numSplits, "length of input samples should be greater than numSplits") - //Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.length @@ -395,14 +408,22 @@ object DecisionTree extends Serializable { //Find all splits for (featureIndex <- 0 until numFeatures){ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride : Double = numSamples.toDouble/numSplits - - println("stride = " + stride) - for (index <- 0 until numSplits-1) { - val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") - splits(featureIndex)(index) = split + if (numSamples < numSplits) { + //TODO: Test this + println("numSamples = " + numSamples + ", less than numSplits = " + numSplits) + for (index <- 0 until 
numSplits-1) { + val split = new Split(featureIndex,featureSamples(index),"continuous") + splits(featureIndex)(index) = split + } + } else { + val stride : Double = numSamples.toDouble/numSplits + println("stride = " + stride) + for (index <- 0 until numSplits-1) { + val sampleIndex = (index+1)*stride.toInt + val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") + splits(featureIndex)(index) = split + } } } @@ -430,4 +451,36 @@ object DecisionTree extends Serializable { } } + def main(args: Array[String]) { + + val sc = new SparkContext(args(0), "DecisionTree") + val data = loadLabeledData(sc, args(1)) + + val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569) + val model = new DecisionTree(strategy).train(data) + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * LABEL,F1,F2,... + * where F1, F2, ... are feature values in Double and LABEL is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + + } \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2c9794371eb29..5ee61b0a5173c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -68,10 +68,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3) + assert(990==bestSplits(0)._4) } test("stump with fixed label 1 for Gini"){ @@ -86,10 +89,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3)
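+    // bestSplits entries are (split, gain, leftSamples, rightSamples): with every label fixed at 1 both children are pure,
+    // so gain = parentImpurity - (10/1000)*0 - (990/1000)*0 = 0 and the 10/990 sample split falls at the first threshold
+    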
assert(990==bestSplits(0)._4) } test("stump with fixed label 1 for Entropy"){ @@ -123,10 +132,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3) + assert(990==bestSplits(0)._4) } From 4798aae63e898fed71e6240462a163ad81ccd64b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 15 Dec 2013 00:45:23 -0800 Subject: [PATCH 07/48] added gain stats class Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 45 ++++++++++--------- .../spark/mllib/tree/DecisionTreeSuite.scala | 32 ++++++++----- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4f7324345e1d8..4fd030e3a3c05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -83,7 +83,7 @@ object DecisionTree extends Serializable { level: Int, filters : Array[List[Filter]], splits : Array[Array[Split]], - bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = { + bins : Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -241,7 +241,7 @@ object DecisionTree extends Serializable { featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], - topImpurity: Double) : (Double, Long, Long) = { + topImpurity: Double) : InformationGainStats = { val left0Count = leftNodeAgg(featureIndex)(2 * index) val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) @@ -251,12 +251,12 @@ object DecisionTree extends Serializable { val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count - if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong) //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) @@ -264,7 +264,9 @@ object DecisionTree extends Serializable { val leftWeight = leftCount.toDouble / (leftCount + rightCount) val rightWeight = rightCount.toDouble / (leftCount + rightCount) - (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong) + val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) } @@ -307,9 +309,9 @@ object DecisionTree extends Serializable { } def 
calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) - : Array[Array[(Double,Long,Long)]] = { + : Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1) + val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1) for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { @@ -325,7 +327,7 @@ object DecisionTree extends Serializable { @param binData Array[Double] of size 2*numSplits*numFeatures */ - def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = { + def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = { println("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -333,36 +335,35 @@ object DecisionTree extends Serializable { //println("gains.size = " + gains.size) //println("gains(0).size = " + gains(0).size) - val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = { + val (bestFeatureIndex,bestSplitIndex, gainStats) = { var bestFeatureIndex = 0 var bestSplitIndex = 0 - var maxGain = Double.MinValue - var leftSamples = Long.MinValue - var rightSamples = Long.MinValue + //Initialization with infeasible values + var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0) +// var maxGain = Double.MinValue +// var leftSamples = Long.MinValue +// var rightSamples = Long.MinValue for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ - val gain = gains(featureIndex)(splitIndex) + val gainStats = gains(featureIndex)(splitIndex) //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) - if(gain._1 > maxGain) { - maxGain = gain._1 - leftSamples = gain._2 - rightSamples = gain._3 + if(gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex - + ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples) + + ", gain stats = " + bestGainStats) } } } - (bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples) + (bestFeatureIndex,bestSplitIndex,bestGainStats) } - (splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount) - //TODO: Return array of node stats with split and impurity information + (splits(bestFeatureIndex)(bestSplitIndex),gainStats) } //Calculate best splits for all nodes at a given level - val bestSplits = new Array[(Split, Double, Long, Long)](numNodes) + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) for (node <- 0 until numNodes){ val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5ee61b0a5173c..2b5988bb3c6a3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -72,9 +72,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) 
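      // every candidate gain on a pure node is 0, so the strict > comparison keeps the first candidate: split 0 of feature 0, i.e. threshold 10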
assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } test("stump with fixed label 1 for Gini"){ @@ -93,9 +95,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } @@ -115,9 +119,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } test("stump with fixed label 1 for Entropy"){ @@ -136,9 +142,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } From 80e8c66dd25ad03c706f4993b10ba4caafa54c18 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 15 Dec 2013 17:41:59 -0800 Subject: [PATCH 08/48] working version of multi-level split calculation Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 75 +++++++++++++------ .../spark/mllib/tree/impurity/Gini.scala | 16 ++-- 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4fd030e3a3c05..a2a3dba213e7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.impurity.Gini -class DecisionTree(val strategy : Strategy) { +class DecisionTree(val strategy : Strategy) extends Logging { def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { @@ -42,20 +42,43 @@ class DecisionTree(val strategy : Strategy) { val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 val filters = new Array[List[Filter]](maxNumNodes) + filters(0) = List() + val parentImpurities = new Array[Double](maxNumNodes) + //Dummy value for top node (calculate from scratch during first split calculation) + parentImpurities(0) = Double.MinValue for (level <- 0 until maxDepth){ + + println("#####################################") + println("level = " + level) + println("#####################################") + //Find best split for all nodes at a level val numNodes= 
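      // level-wise training: all 2^level nodes at the current depth are evaluated in one pass over the data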
scala.math.pow(2,level).toInt - //TODO: Change the input parent impurities values - val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins) - for (tmp <- splits_stats_for_level){ - println("final best split = " + tmp._1) + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ + for (i <- 0 to 1){ + val nodeIndex = (scala.math.pow(2,level+1)).toInt - 1 + 2*index + i + if(level < maxDepth - 1){ + val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity + println("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + parentImpurities(nodeIndex) = impurity + println("updating nodeIndex = " + nodeIndex) + filters(nodeIndex) = new Filter(nodeSplitStats._1, if(i == 0) - 1 else 1) :: filters((nodeIndex-1)/2) + for (filter <- filters(nodeIndex)){ + println(filter) + } + } + } + println("final best split = " + nodeSplitStats._1) } - //TODO: update filters and decision tree model - require(scala.math.pow(2,level)==splits_stats_for_level.length) + require(scala.math.pow(2,level)==splitsStatsForLevel.length) + } + //TODO: Extract decision tree model + return new DecisionTreeModel() } @@ -99,7 +122,7 @@ object DecisionTree extends Serializable { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex //val parentFilterIndex = nodeFilterIndex / 2 //TODO: Check left or right filter filters(nodeFilterIndex) @@ -155,11 +178,11 @@ object DecisionTree extends Serializable { // calculating bin index and label per feature per node val arr = new Array[Double](1+(numFeatures * numNodes)) arr(0) = labeledPoint.label - for (nodeIndex <- 0 until numNodes) { - val parentFilters = findParentFilters(nodeIndex) + for (index <- 0 until numNodes) { + val parentFilters = findParentFilters(index) //Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * nodeIndex + val shift = 1 + numFeatures * index if (!sampleValid) { //Add to invalid bin index -1 for (featureIndex <- 0 until numFeatures) { @@ -251,22 +274,26 @@ object DecisionTree extends Serializable { val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count + val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) - //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - - - //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) val leftWeight = leftCount.toDouble / (leftCount + rightCount) val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * 
rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } - new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) } @@ -339,7 +366,7 @@ object DecisionTree extends Serializable { var bestFeatureIndex = 0 var bestSplitIndex = 0 //Initialization with infeasible values - var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0) + var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0) // var maxGain = Double.MinValue // var leftSamples = Long.MinValue // var rightSamples = Long.MinValue @@ -351,8 +378,8 @@ object DecisionTree extends Serializable { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex - + ", gain stats = " + bestGainStats) + //println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex) + //println( "gain stats = " + bestGainStats) } } } @@ -365,9 +392,12 @@ object DecisionTree extends Serializable { //Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) for (node <- 0 until numNodes){ + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) - val parentNodeImpurity = parentImpurities(node/2) + println("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + println("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } @@ -456,8 +486,9 @@ object DecisionTree extends Serializable { val sc = new SparkContext(args(0), "DecisionTree") val data = loadLabeledData(sc, args(1)) + val maxDepth = args(2).toInt - val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569) + val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569) val model = new DecisionTree(strategy).train(data) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index a95f0431c6e8f..3396a015e7858 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -18,11 +18,15 @@ package org.apache.spark.mllib.tree.impurity object Gini extends Impurity { - def calculate(c0 : Double, c1 : Double): Double = { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0*f0 - f1*f1 - } + def calculate(c0 : Double, c1 : Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0*f0 - f1*f1 + } + } } From b0eb866cfd2d98a9281127e02e0c159668ca01f4 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 15 Dec 2013 20:42:52 -0800 Subject: [PATCH 09/48] added logic to handle leaf nodes Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 5 +++ .../tree/model/InformationGainStats.scala | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala diff 
--git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a2a3dba213e7f..7749bdd687d1f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -135,6 +135,11 @@ object DecisionTree extends Serializable { */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + //Leaf + if (parentFilters.length == 0 ){ + return false + } + for (filter <- parentFilters) { val features = labeledPoint.features val featureIndex = filter.split.feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala new file mode 100644 index 0000000000000..4ca02beec03c0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.tree.model + +class InformationGainStats(val gain : Double, + val impurity: Double, + val leftImpurity : Double, + val leftSamples : Long, + val rightImpurity : Double, + val rightSamples : Long) { + + override def toString = + "gain = " + gain + ", impurity = " + impurity + ", left impurity = " + + leftImpurity + ", leftSamples = " + leftSamples + ", right impurity = " + + rightImpurity + ", rightSamples = " + rightSamples + + +} From 98ec8d57a0a0897b093ced7e3284228ee21ce5f4 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 21 Dec 2013 22:39:29 -0800 Subject: [PATCH 10/48] tree building and prediction logic Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 248 +++++++++--------- .../spark/mllib/tree/DecisionTreeRunner.scala | 74 ++++++ .../apache/spark/mllib/tree/Strategy.scala | 8 +- .../mllib/tree/model/DecisionTreeModel.scala | 6 +- .../tree/model/InformationGainStats.scala | 2 +- .../apache/spark/mllib/tree/model/Node.scala | 60 +++++ 6 files changed, 272 insertions(+), 126 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 7749bdd687d1f..d8ffa12030f8d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -24,9 +24,10 @@ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.impurity.Gini +import scala.util.control.Breaks._ -class DecisionTree(val strategy : Strategy) extends Logging { +class DecisionTree(val strategy : Strategy) extends Serializable with Logging { def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { @@ -36,6 +37,8 @@ class DecisionTree(val strategy : Strategy) extends Logging { //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass //TODO: Think about broadcasting this val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) + logDebug("numSplits = " + bins(0).length) + strategy.numBins = bins(0).length //TODO: Level-wise training of tree and obtain Decision Tree model val maxDepth = strategy.maxDepth @@ -44,47 +47,86 @@ class DecisionTree(val strategy : Strategy) extends Logging { val filters = new Array[List[Filter]](maxNumNodes) filters(0) = List() val parentImpurities = new Array[Double](maxNumNodes) - //Dummy value for top node (calculate from scratch during first split calculation) - parentImpurities(0) = Double.MinValue - - for (level <- 0 until maxDepth){ - - println("#####################################") - println("level = " + level) - println("#####################################") - - //Find best split for all nodes at a level - val numNodes= scala.math.pow(2,level).toInt - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ - for (i <- 0 to 1){ - val nodeIndex = (scala.math.pow(2,level+1)).toInt - 1 + 2*index + i - if(level < maxDepth - 1){ - val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity - println("nodeIndex = " + 
nodeIndex + ", impurity = " + impurity) - parentImpurities(nodeIndex) = impurity - println("updating nodeIndex = " + nodeIndex) - filters(nodeIndex) = new Filter(nodeSplitStats._1, if(i == 0) - 1 else 1) :: filters((nodeIndex-1)/2) - for (filter <- filters(nodeIndex)){ - println(filter) - } - } + //Dummy value for top node (updated during first split calculation) + //parentImpurities(0) = Double.MinValue + val nodes = new Array[Node](maxNumNodes) + + + breakable { + for (level <- 0 until maxDepth){ + + logDebug("#####################################") + logDebug("level = " + level) + logDebug("#####################################") + + //Find best split for all nodes at a level + val numNodes= scala.math.pow(2,level).toInt + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) + + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ + + extractNodeInfo(nodeSplitStats, level, index, nodes) + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) + logDebug("final best split = " + nodeSplitStats._1) + } - println("final best split = " + nodeSplitStats._1) + require(scala.math.pow(2,level)==splitsStatsForLevel.length) + + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 ) + logDebug("all leaf = " + allLeaf) + if (allLeaf) break + } - require(scala.math.pow(2,level)==splitsStatsForLevel.length) + } + val topNode = nodes(0) + topNode.build(nodes) + val decisionTreeModel = { + return new DecisionTreeModel(topNode) } - //TODO: Extract decision tree model + return decisionTreeModel + } + - return new DecisionTreeModel() + private def extractNodeInfo(nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, nodes: Array[Node]) { + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val predict = { + val leftSamples = nodeSplitStats._2.leftSamples.toDouble + val rightSamples = nodeSplitStats._2.rightSamples.toDouble + val totalSamples = leftSamples + rightSamples + leftSamples / totalSamples + } + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node } + private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) { + for (i <- 0 to 1) { + + val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i + + if (level < maxDepth - 1) { + + val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + parentImpurities(nodeIndex) = impurity + filters(nodeIndex) = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) :: filters((nodeIndex - 1) / 2) + for (filter <- filters(nodeIndex)) { + logDebug("Filter = " + filter) + } + + } + } + } } -object DecisionTree extends Serializable { +object DecisionTree extends Serializable with Logging { /* Returns an Array[Split] of optimal splits for all nodes at a given level @@ -110,12 +152,12 @@ object DecisionTree extends Serializable { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt - println("numNodes = " + numNodes) + logDebug("numNodes = " + numNodes) //Find the number of features by looking at the first sample val 
numFeatures = input.take(1)(0).features.length - println("numFeatures = " + numFeatures) - val numSplits = strategy.numSplits - println("numSplits = " + numSplits) + logDebug("numFeatures = " + numFeatures) + val numSplits = strategy.numBins + logDebug("numSplits = " + numSplits) /*Find the filters used before reaching the current code*/ def findParentFilters(nodeIndex: Int): List[Filter] = { @@ -136,7 +178,7 @@ object DecisionTree extends Serializable { def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { //Leaf - if (parentFilters.length == 0 ){ + if ((level > 0) & (parentFilters.length == 0) ){ return false } @@ -156,9 +198,9 @@ object DecisionTree extends Serializable { /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { - //println("finding bin for labeled point " + labeledPoint.features(featureIndex)) + //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex)) //TODO: Do binary search - for (binIndex <- 0 until strategy.numSplits) { + for (binIndex <- 0 until strategy.numBins) { val bin = bins(featureIndex)(binIndex) //TODO: Remove this requirement post basic functional val lowThreshold = bin.lowSplit.threshold @@ -196,7 +238,7 @@ object DecisionTree extends Serializable { } } else { for (featureIndex <- 0 until numFeatures) { - //println("shift+featureIndex =" + (shift+featureIndex)) + //logDebug("shift+featureIndex =" + (shift+featureIndex)) arr(shift + featureIndex) = findBin(featureIndex, labeledPoint) } } @@ -239,7 +281,7 @@ object DecisionTree extends Serializable { //TODO: This length if different for regression val binAggregateLength = 2*numSplits * numFeatures * numNodes - println("binAggregageLength = " + binAggregateLength) + logDebug("binAggregageLength = " + binAggregateLength) /*Combines the aggregates from partitions @param agg1 Array containing aggregates from one or more partitions @@ -255,14 +297,14 @@ object DecisionTree extends Serializable { combinedAggregate } - println("input = " + input.count) + logDebug("input = " + input.count) val binMappedRDD = input.map(x => findBinsForLevel(x)) - println("binMappedRDD.count = " + binMappedRDD.count) + logDebug("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) - println("binAggregates.length = " + binAggregates.length) - //binAggregates.foreach(x => println(x)) + logDebug("binAggregates.length = " + binAggregates.length) + //binAggregates.foreach(x => logDebug(x)) def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], @@ -312,21 +354,21 @@ object DecisionTree extends Serializable { def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - //println("binData.length = " + binData.length) - //println("binData.sum = " + binData.sum) + //logDebug("binData.length = " + binData.length) + //logDebug("binData.sum = " + binData.sum) for (featureIndex <- 0 until numFeatures) { - //println("featureIndex = " + featureIndex) + //logDebug("featureIndex = " + featureIndex) val shift = 2*featureIndex*numSplits leftNodeAgg(featureIndex)(0) = binData(shift + 0) - //println("binData(shift + 0) = " + binData(shift + 0)) + //logDebug("binData(shift + 0) = " + binData(shift + 0)) 
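      // left aggregates are prefix sums accumulated from the lowest bin; right aggregates are suffix sums built down from the highest bin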
leftNodeAgg(featureIndex)(1) = binData(shift + 1) - //println("binData(shift + 1) = " + binData(shift + 1)) + //logDebug("binData(shift + 1) = " + binData(shift + 1)) rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - //println(binData(shift + (2 * (numSplits - 1)))) + //logDebug(binData(shift + (2 * (numSplits - 1)))) rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - //println(binData(shift + (2 * (numSplits - 1)) + 1)) + //logDebug(binData(shift + (2 * (numSplits - 1)) + 1)) for (splitIndex <- 1 until numSplits - 1) { - //println("splitIndex = " + splitIndex) + //logDebug("splitIndex = " + splitIndex) leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) leftNodeAgg(featureIndex)(2 * splitIndex + 1) @@ -347,7 +389,7 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { - //println("splitIndex = " + index) + //logDebug("splitIndex = " + index) gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) } } @@ -360,12 +402,12 @@ object DecisionTree extends Serializable { @param binData Array[Double] of size 2*numSplits*numFeatures */ def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = { - println("node impurity = " + nodeImpurity) + logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - //println("gains.size = " + gains.size) - //println("gains(0).size = " + gains(0).size) + //logDebug("gains.size = " + gains.size) + //logDebug("gains(0).size = " + gains(0).size) val (bestFeatureIndex,bestSplitIndex, gainStats) = { var bestFeatureIndex = 0 @@ -378,13 +420,13 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gainStats = gains(featureIndex)(splitIndex) - //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) + //logDebug("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) if(gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - //println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex) - //println( "gain stats = " + bestGainStats) + //logDebug("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex) + //logDebug( "gain stats = " + bestGainStats) } } } @@ -400,9 +442,9 @@ object DecisionTree extends Serializable { val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) - println("nodeImpurityIndex = " + nodeImpurityIndex) + logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - println("node impurity = " + parentNodeImpurity) + logDebug("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } @@ -419,47 +461,42 @@ object DecisionTree extends Serializable { */ def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { - val numSplits 
= strategy.numSplits - println("numSplits = " + numSplits) + val count = input.count() + + //Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.length + + val maxBins = strategy.maxBins + val numBins = if (maxBins <= count) maxBins else count.toInt + logDebug("maxBins = " + numBins) //Calculate the number of sample for approximate quantile calculation //TODO: Justify this calculation - val requiredSamples = numSplits*numSplits - val count = input.count() + val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 - println("fraction of data used for calculating quantiles = " + fraction) - + logDebug("fraction of data used for calculating quantiles = " + fraction) //sampled input for RDD calculation val sampledInput = input.sample(false, fraction, 42).collect() val numSamples = sampledInput.length - //Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.length + val stride : Double = numSamples.toDouble/numBins + logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { case "sort" => { - val splits = Array.ofDim[Split](numFeatures,numSplits-1) - val bins = Array.ofDim[Bin](numFeatures,numSplits) + val splits = Array.ofDim[Split](numFeatures,numBins-1) + val bins = Array.ofDim[Bin](numFeatures,numBins) //Find all splits for (featureIndex <- 0 until numFeatures){ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - if (numSamples < numSplits) { - //TODO: Test this - println("numSamples = " + numSamples + ", less than numSplits = " + numSplits) - for (index <- 0 until numSplits-1) { - val split = new Split(featureIndex,featureSamples(index),"continuous") - splits(featureIndex)(index) = split - } - } else { - val stride : Double = numSamples.toDouble/numSplits - println("stride = " + stride) - for (index <- 0 until numSplits-1) { - val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") - splits(featureIndex)(index) = split - } + val stride : Double = numSamples.toDouble/numBins + logDebug("stride = " + stride) + for (index <- 0 until numBins-1) { + val sampleIndex = (index+1)*stride.toInt + val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") + splits(featureIndex)(index) = split } } @@ -467,18 +504,18 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures){ bins(featureIndex)(0) = new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous") - for (index <- 1 until numSplits - 1){ + for (index <- 1 until numBins - 1){ val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous") bins(featureIndex)(index) = bin } - bins(featureIndex)(numSplits-1) - = new Bin(splits(featureIndex)(numSplits-3),new DummyHighSplit("continuous"),"continuous") + bins(featureIndex)(numBins-1) + = new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit("continuous"),"continuous") } (splits,bins) } case "minMax" => { - (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits+2)) + (Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2)) } case "approximateHistogram" => { throw new UnsupportedOperationException("approximate histogram not supported yet.") @@ -487,37 +524,6 @@ object DecisionTree extends Serializable { } } - def main(args: Array[String]) { - - val sc = new SparkContext(args(0), 
"DecisionTree") - val data = loadLabeledData(sc, args(1)) - val maxDepth = args(2).toInt - - val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569) - val model = new DecisionTree(strategy).train(data) - - sc.stop() - } - - /** - * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.trim().split(",") - val label = parts(0).toDouble - val features = parts.slice(1,parts.length).map(_.toDouble) - LabeledPoint(label, features) - } - } - } \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala new file mode 100644 index 0000000000000..542a3d9c3b33d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.mllib.tree.impurity.Gini +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.DecisionTreeModel + +object DecisionTreeRunner extends Logging { + + + def main(args: Array[String]) { + + val sc = new SparkContext(args(0), "DecisionTree") + val data = loadLabeledData(sc, args(1)) + val maxDepth = args(2).toInt + val maxBins = args(3).toInt + + val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) + val model = new DecisionTree(strategy).train(data) + + val accuracy = accuracyScore(model, data) + logDebug("accuracy = " + accuracy) + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). 
+ */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + //TODO: Port them to a metrics package + def accuracyScore(model : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { + val correctCount = data.filter(y => model.predict(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala index 7f88053043e0a..c688a478ce0d2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala @@ -18,11 +18,13 @@ package org.apache.spark.mllib.tree import org.apache.spark.mllib.tree.impurity.Impurity -case class Strategy ( +class Strategy ( val kind : String, val impurity : Impurity, val maxDepth : Int, - val numSplits : Int, - val quantileCalculationStrategy : String = "sort") { + val maxBins : Int, + val quantileCalculationStrategy : String = "sort") extends Serializable { + + var numBins : Int = Int.MinValue } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index d0465d8c6fb6f..1d7c03289c407 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.mllib.tree.model -class DecisionTreeModel { +import org.apache.spark.mllib.regression.LabeledPoint + +class DecisionTreeModel(val topNode : Node) extends Serializable { + + def predict(features : Array[Double]) = if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 4ca02beec03c0..60a4f99b7f806 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -21,7 +21,7 @@ class InformationGainStats(val gain : Double, val leftImpurity : Double, val leftSamples : Long, val rightImpurity : Double, - val rightSamples : Long) { + val rightSamples : Long) extends Serializable { override def toString = "gain = " + gain + ", impurity = " + impurity + ", left impurity = " diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala new file mode 100644 index 0000000000000..a9210e10ae48b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
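
[Annotation] One thing to watch in the DecisionTreeModel change above: predict maps a leaf estimate >= 0.5 to 0.0 and everything else to 1.0, the reverse of the usual orientation. Whether that inversion is deliberate is not settled by this patch; a sketch of the conventional form, for contrast only:

    class DecisionTreeModel(val topNode: Node) extends Serializable {
      // Conventional orientation: leaf estimates of 0.5 or more vote for class 1.0.
      def predict(features: Array[Double]): Double =
        if (topNode.predictIfLeaf(features) >= 0.5) 1.0 else 0.0
    }
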
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +import org.apache.spark.Logging +import org.apache.spark.mllib.regression.LabeledPoint + +class Node ( val id : Int, + val predict : Double, + val isLeaf : Boolean, + val split : Option[Split], + var leftNode : Option[Node], + var rightNode : Option[Node], + val stats : Option[InformationGainStats] + ) extends Serializable with Logging{ + + override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", split = " + split + ", stats = " + stats + + def build(nodes : Array[Node]) : Unit = { + + logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("stats = " + stats) + logDebug("predict = " + predict) + if (!isLeaf) { + val leftNodeIndex = id*2 + 1 + val rightNodeIndex = id*2 + 2 + leftNode = Some(nodes(leftNodeIndex)) + rightNode = Some(nodes(rightNodeIndex)) + leftNode.get.build(nodes) + rightNode.get.build(nodes) + } + } + + def predictIfLeaf(feature : Array[Double]) : Double = { + if (isLeaf) { + predict + } else{ + if (feature(split.get.feature) <= split.get.threshold) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } + } + +} From 02c595c65f784061b1a78d4cbd5cac5990d1881d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 22 Dec 2013 12:00:17 -0800 Subject: [PATCH 11/48] added command line parsing Signed-off-by: Manish Amde --- .../classification/ClassificationTree.scala | 21 ------- .../spark/mllib/tree/DecisionTree.scala | 2 +- .../spark/mllib/tree/DecisionTreeRunner.scala | 58 ++++++++++++++++--- 3 files changed, 52 insertions(+), 29 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala deleted file mode 100644 index a6f27e7fd1111..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
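
[Annotation] Node.build above relies on the nodes array holding a complete binary tree in breadth-first order: node k's children live at indices 2k+1 and 2k+2, and the level logged for node k is floor(log2(k+1)). A quick check of that indexing (illustrative only):

    def level(id: Int): Int = (math.log(id + 1) / math.log(2)).toInt
    assert(level(0) == 0)                       // root
    assert(level(1) == 1 && level(2) == 1)      // its children
    assert(level(3) == 2 && level(6) == 2)      // grandchildren
    assert(2 * 1 + 1 == 3 && 2 * 1 + 2 == 4)    // children of node 1
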
- */ -package org.apache.spark.mllib.classification - -class ClassificationTree { - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d8ffa12030f8d..b8cfe03aa151b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -285,7 +285,7 @@ object DecisionTree extends Serializable with Logging { /*Combines the aggregates from partitions @param agg1 Array containing aggregates from one or more partitions - @param agg2 Array contianing aggregates from one or more partitions + @param agg2 Array containing aggregates from one or more partitions @return Combined aggregate from agg1 and agg2 */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index 542a3d9c3b33d..d46733336d558 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -17,25 +17,69 @@ package org.apache.spark.mllib.tree import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.mllib.tree.impurity.Gini +import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.DecisionTreeModel object DecisionTreeRunner extends Logging { + val usage = """ + Usage: DecisionTreeRunner [slices] --kind --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--maxBins num] + """ + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + val sc = new SparkContext(args(0), "DecisionTree") - val data = loadLabeledData(sc, args(1)) - val maxDepth = args(2).toInt - val maxBins = args(3).toInt - val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) - val model = new DecisionTree(strategy).train(data) - val accuracy = accuracyScore(model, data) + val arglist = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]) : OptionMap = { + def isSwitch(s : String) = (s(0) == '-') + list match { + case Nil => map + case "--kind" :: string :: tail => nextOption(map ++ Map('kind -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => println("Unknown option "+option) + exit(1) + } + } + val options = nextOption(Map(),arglist) + logDebug(options.toString()) + + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + val typeStr = options.get('type).toString + //TODO: Create enum + val impurityStr = options.getOrElse('impurity,if (typeStr == "classification") "Gini" else "Variance").toString + val impurity = { + impurityStr match { + case "Gini" => Gini + case "Entropy" => 
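
[Annotation] Two notes on the argument parsing above. First, nextOption stores the switch under the key 'kind, but the code then reads options.get('type), which is always None (so typeStr becomes the string "None"); patch 13 below renames both sides to 'algo, resolving the mismatch. Second, the recursive List-match idiom is worth seeing in isolation; a stripped-down sketch with illustrative option names:

    type OptionMap = Map[Symbol, Any]

    def parse(map: OptionMap, args: List[String]): OptionMap = args match {
      case Nil => map
      case "--maxDepth" :: value :: tail => parse(map + ('maxDepth -> value), tail)
      case unknown :: _ =>
        Console.err.println("Unknown option " + unknown)
        sys.exit(1)
    }
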
Entropy + case "Variance" => Variance + } + } + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt + val maxBins = options.getOrElse('maxBins,"100").toString.toInt + + val strategy = new Strategy(kind = typeStr, impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) + val model = new DecisionTree(strategy).train(trainData) + + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + val accuracy = accuracyScore(model, testData) logDebug("accuracy = " + accuracy) sc.stop() From 733d6ddf51ddf440efb1a17c818da6d7fd027c4b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 22 Dec 2013 12:20:50 -0800 Subject: [PATCH 12/48] fixed tests Signed-off-by: Manish Amde --- .../apache/spark/mllib/tree/DecisionTreeSuite.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2b5988bb3c6a3..6b6bb2b55e7b5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -53,7 +53,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins.length==2) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) + //println(splits(1)(98)) } test("stump with fixed label 0 for Gini"){ @@ -68,7 +68,10 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + + strategy.numBins = 100 + val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + println("here") assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -91,6 +94,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) + + strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) @@ -115,6 +120,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) + + strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) @@ -138,6 +145,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) + + strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) From 154aa77c925e44a92e8bbf2f55e43cab06e75006 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 22 Dec 2013 22:51:17 -0800 Subject: [PATCH 13/48] enums for configurations Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 24 ++++--------- .../spark/mllib/tree/DecisionTreeRunner.scala | 35 +++++++++++++------ .../spark/mllib/tree/configuration/Algo.scala | 22 ++++++++++++ .../tree/configuration/QuantileStrategy.scala | 22 
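
[Annotation] Also in the runner above: the Strategy is built with impurity = Gini even though an impurity was just parsed from the command line, so the --impurity flag has no effect yet. Patch 13 below passes the parsed value through instead:

    val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins)

The stray println("here") added to the test suite in patch 12 looks like leftover debugging as well.
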
++++++++++++ .../tree/{ => configuration}/Strategy.scala | 8 +++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 1 + 6 files changed, 81 insertions(+), 31 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala rename mllib/src/main/scala/org/apache/spark/mllib/tree/{ => configuration}/Strategy.scala (77%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b8cfe03aa151b..9cd1597e6fa18 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -23,8 +23,9 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.Split -import org.apache.spark.mllib.tree.impurity.Gini import scala.util.control.Breaks._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { @@ -34,8 +35,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { //Cache input RDD for speedup during multiple passes input.cache() - //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass - //TODO: Think about broadcasting this val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) logDebug("numSplits = " + bins(0).length) strategy.numBins = bins(0).length @@ -133,7 +132,7 @@ object DecisionTree extends Serializable with Logging { @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree @param parentImpurities Impurities for all parent nodes for the current level - @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree + @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree @param level Level of the tree @param filters Filter for all nodes at a given level @param splits possible splits for all features @@ -406,27 +405,18 @@ object DecisionTree extends Serializable with Logging { val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - //logDebug("gains.size = " + gains.size) - //logDebug("gains(0).size = " + gains(0).size) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { var bestFeatureIndex = 0 var bestSplitIndex = 0 //Initialization with infeasible values var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0) -// var maxGain = Double.MinValue -// var leftSamples = Long.MinValue -// var rightSamples = Long.MinValue for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gainStats = gains(featureIndex)(splitIndex) - //logDebug("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) if(gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - //logDebug("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex) - //logDebug( 
"gain stats = " + bestGainStats) } } } @@ -455,7 +445,7 @@ object DecisionTree extends Serializable with Logging { Returns split and bins for decision tree calculation. @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree - @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree + @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an Array[Array[Bin]] of size (numFeatures,numSplits1) */ @@ -483,7 +473,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { - case "sort" => { + case Sort => { val splits = Array.ofDim[Split](numFeatures,numBins-1) val bins = Array.ofDim[Bin](numFeatures,numBins) @@ -514,10 +504,10 @@ object DecisionTree extends Serializable with Logging { (splits,bins) } - case "minMax" => { + case MinMax => { (Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2)) } - case "approximateHistogram" => { + case ApproxHist => { throw new UnsupportedOperationException("approximate histogram not supported yet.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index d46733336d558..65b5ab1162597 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -21,11 +21,14 @@ import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ + object DecisionTreeRunner extends Logging { val usage = """ - Usage: DecisionTreeRunner [slices] --kind --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--maxBins num] + Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] """ @@ -46,39 +49,49 @@ object DecisionTreeRunner extends Logging { def isSwitch(s : String) = (s(0) == '-') list match { case Nil => map - case "--kind" :: string :: tail => nextOption(map ++ Map('kind -> string), tail) + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) - case option :: tail => println("Unknown option "+option) - exit(1) + case option :: tail => logError("Unknown option "+option) + sys.exit(1) } } val options = nextOption(Map(),arglist) logDebug(options.toString()) + //TODO: Add validation for input parameters + //Load training data val trainData = 
loadLabeledData(sc, options.get('trainDataDir).get.toString) - val typeStr = options.get('type).toString - //TODO: Create enum - val impurityStr = options.getOrElse('impurity,if (typeStr == "classification") "Gini" else "Variance").toString - val impurity = { - impurityStr match { + //Figure out the type of algorithm + val algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + + //Identify the type of impurity + val impurityStr = options.getOrElse('impurity,if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { case "Gini" => Gini case "Entropy" => Entropy case "Variance" => Variance } - } + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt val maxBins = options.getOrElse('maxBins,"100").toString.toInt - val strategy = new Strategy(kind = typeStr, impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) + val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins) val model = new DecisionTree(strategy).train(trainData) + //Load test data val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + //Measure algorithm accuracy val accuracy = accuracyScore(model, testData) logDebug("accuracy = " + accuracy) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala new file mode 100644 index 0000000000000..7cd128e381e8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.configuration + +object Algo extends Enumeration { + type Algo = Value + val Classification, Regression = Value +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala new file mode 100644 index 0000000000000..1bbd2d8c1fe92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
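
[Annotation] Because Algo is a plain scala.Enumeration, the standard library also offers withName for string conversion (throwing on unknown names) and values for enumeration; a small illustration, assuming the Algo object defined above:

    import org.apache.spark.mllib.tree.configuration.Algo

    val algo = Algo.withName("Classification")   // an Algo.Value
    assert(Algo.values.toList == List(Algo.Classification, Algo.Regression))
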
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.configuration + +object QuantileStrategy extends Enumeration { + type QuantileStrategy = Value + val Sort, MinMax, ApproxHist = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala similarity index 77% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index c688a478ce0d2..3c759bbc1c805 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -14,16 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.mllib.tree +package org.apache.spark.mllib.tree.configuration import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ class Strategy ( - val kind : String, + val algo : Algo, val impurity : Impurity, val maxDepth : Int, val maxBins : Int, - val quantileCalculationStrategy : String = "sort") extends Serializable { + val quantileCalculationStrategy : QuantileStrategy = Sort) extends Serializable { var numBins : Int = Int.MinValue diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 6b6bb2b55e7b5..86cd1f432d162 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} import org.apache.spark.mllib.tree.model.Filter +import org.apache.spark.mllib.tree.configuration.Strategy class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { From b0e3e76c47b1b449c91832aee2a6e94cee0a7c6b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 12 Jan 2014 11:45:47 -0800 Subject: [PATCH 14/48] adding enum for feature type Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 43 ++++++++++--------- .../mllib/tree/configuration/Strategy.scala | 3 +- .../apache/spark/mllib/tree/model/Bin.scala | 4 +- .../apache/spark/mllib/tree/model/Split.scala | 10 +++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 11 ++--- 5 files changed, 40 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9cd1597e6fa18..1665d0ee1ffb9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.tree.model.Split import scala.util.control.Breaks._ import 
org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { @@ -353,21 +354,13 @@ object DecisionTree extends Serializable with Logging { def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - //logDebug("binData.length = " + binData.length) - //logDebug("binData.sum = " + binData.sum) for (featureIndex <- 0 until numFeatures) { - //logDebug("featureIndex = " + featureIndex) val shift = 2*featureIndex*numSplits leftNodeAgg(featureIndex)(0) = binData(shift + 0) - //logDebug("binData(shift + 0) = " + binData(shift + 0)) leftNodeAgg(featureIndex)(1) = binData(shift + 1) - //logDebug("binData(shift + 1) = " + binData(shift + 1)) rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - //logDebug(binData(shift + (2 * (numSplits - 1)))) rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - //logDebug(binData(shift + (2 * (numSplits - 1)) + 1)) for (splitIndex <- 1 until numSplits - 1) { - //logDebug("splitIndex = " + splitIndex) leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) leftNodeAgg(featureIndex)(2 * splitIndex + 1) @@ -479,33 +472,43 @@ object DecisionTree extends Serializable with Logging { //Find all splits for (featureIndex <- 0 until numFeatures){ - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - - val stride : Double = numSamples.toDouble/numBins - logDebug("stride = " + stride) - for (index <- 0 until numBins-1) { - val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") - splits(featureIndex)(index) = split + val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinous) { + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + + val stride : Double = numSamples.toDouble/numBins + logDebug("stride = " + stride) + for (index <- 0 until numBins-1) { + val sampleIndex = (index+1)*stride.toInt + val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous) + splits(featureIndex)(index) = split + } + } else { + val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + for (index <- 0 until maxFeatureValue){ + //TODO: Sort by centriod + val split = new Split(featureIndex,index,Categorical) + splits(featureIndex)(index) = split + } } } //Find all bins for (featureIndex <- 0 until numFeatures){ bins(featureIndex)(0) - = new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous") + = new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous) for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous") + val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous) bins(featureIndex)(index) = bin } bins(featureIndex)(numBins-1) - = new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit("continuous"),"continuous") + = new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit(Continuous),Continuous) } (splits,bins) } case MinMax 
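
[Annotation] categoricalFeaturesInfo, which this patch threads through Strategy, maps a feature index to its number of categories; continuous features are simply absent from the map, which is exactly what the isFeatureContinous check above tests. A hypothetical configuration:

    // Feature 0 continuous (not in the map); feature 1 categorical with 3
    // categories (0, 1, 2); feature 2 categorical with 2 categories.
    val categoricalFeaturesInfo = Map(1 -> 3, 2 -> 2)

With this, find_splits_bins emits one Split per category value for features 1 and 2 (in index order for now; the TODO notes they should eventually be ordered by centroid) and stride-based splits for feature 0.
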
=> { - (Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2)) + throw new UnsupportedOperationException("minmax not supported yet.") } case ApproxHist => { throw new UnsupportedOperationException("approximate histogram not supported yet.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3c759bbc1c805..281dabd3364d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -25,7 +25,8 @@ class Strategy ( val impurity : Impurity, val maxDepth : Int, val maxBins : Int, - val quantileCalculationStrategy : QuantileStrategy = Sort) extends Serializable { + val quantileCalculationStrategy : QuantileStrategy = Sort, + val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable { var numBins : Int = Int.MinValue diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 25d16a9a2fc2f..13191851956ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.tree.model -case class Bin(lowSplit : Split, highSplit : Split, kind : String) { +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 1b39154d42e47..01aa349115302 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -16,11 +16,13 @@ */ package org.apache.spark.mllib.tree.model -case class Split(feature: Int, threshold : Double, kind : String){ - override def toString = "Feature = " + feature + ", threshold = " + threshold + ", kind = " + kind +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType + +case class Split(feature: Int, threshold : Double, featureType : FeatureType){ + override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType } -class DummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind) +class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind) -class DummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind) +class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 86cd1f432d162..6097c6d5ac985 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ 
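
[Annotation] The dummy splits defined above are sentinels carrying Double.MinValue and Double.MaxValue thresholds, so every bin has two boundaries. For n bins over a continuous feature the intended coverage is:

    // splits s(0) .. s(n-2), length n-1:
    //   bin 0   : (Double.MinValue, s(0)]
    //   bin i   : (s(i-1),          s(i)]            for 1 <= i <= n-2
    //   bin n-1 : (s(n-2),          Double.MaxValue)

Note, though, that the continuous-bin construction above seeds the last bin with splits(numBins-3) rather than splits(numBins-2), leaving the bins non-contiguous at the top end; it reads like an off-by-one carried over from the original numSplits-based code.
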
-48,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy("regression",Gini,3,100,"sort") + val strategy = new Strategy(Regression,Gini,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(bins.length==2) @@ -61,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy("regression",Gini,3,100,"sort") + val strategy = new Strategy(Regression,Gini,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) @@ -87,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy("regression",Gini,3,100,"sort") + val strategy = new Strategy(Regression,Gini,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) @@ -113,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy("regression",Entropy,3,100,"sort") + val strategy = new Strategy(Regression,Entropy,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) @@ -138,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy("regression",Entropy,3,100,"sort") + val strategy = new Strategy(Regression,Entropy,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) From c8f6d60c45ec7ec8cfac94b43fb22d8c294221db Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 12 Jan 2014 11:46:55 -0800 Subject: [PATCH 15/48] adding enum for feature type Signed-off-by: Manish Amde --- .../tree/configuration/FeatureType.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala new file mode 100644 index 0000000000000..a725bf388fe29 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.configuration + +object FeatureType extends Enumeration { + type FeatureType = Value + val Continuous, Categorical = Value +} From e23c2e5089a2bf2a50c5d3f52e5799bf76ca3a16 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 19 Jan 2014 13:23:45 -0800 Subject: [PATCH 16/48] added regression support Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 261 +++++++++++++----- .../spark/mllib/tree/DecisionTreeRunner.scala | 12 + .../spark/mllib/tree/impurity/Entropy.scala | 6 +- .../spark/mllib/tree/impurity/Gini.scala | 5 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 + .../spark/mllib/tree/impurity/Variance.scala | 11 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 10 +- 7 files changed, 231 insertions(+), 76 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1665d0ee1ffb9..1ff8c05bcb790 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -27,6 +27,7 @@ import scala.util.control.Breaks._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.Algo._ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { @@ -51,6 +52,9 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { //parentImpurities(0) = Double.MinValue val nodes = new Array[Node](maxNumNodes) + logDebug("algo = " + strategy.algo) + + breakable { for (level <- 0 until maxDepth){ @@ -247,8 +251,47 @@ object DecisionTree extends Serializable with Logging { arr } - /* - Performs a sequential aggregation over a partition. 
+ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + for (node <- 0 until numNodes) { + val validSignalIndex = 1 + numFeatures * node + val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false + if (isSampleValidForNode) { + val label = arr(0) + for (feature <- 0 until numFeatures) { + val arrShift = 1 + numFeatures * node + val aggShift = 2 * numSplits * numFeatures * node + val arrIndex = arrShift + feature + val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2 + label match { + case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 + case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + } + } + } + } + } + + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + for (node <- 0 until numNodes) { + val validSignalIndex = 1 + numFeatures * node + val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false + if (isSampleValidForNode) { + val label = arr(0) + for (feature <- 0 until numFeatures) { + val arrShift = 1 + numFeatures * node + val aggShift = 3 * numSplits * numFeatures * node + val arrIndex = arrShift + feature + val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3 + //count, sum, sum^2 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + } + } + } + } + + /*Performs a sequential aggregation over a partition. for p bins, k features, l nodes (level = log2(l)) storage is of the form: b111_left_count,b111_right_count, .... , bpk1_left_count, bpk1_right_count, .... , bpkl_left_count, bpkl_right_count @@ -256,32 +299,23 @@ object DecisionTree extends Serializable with Logging { @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification @param arr Array[Double] of size 1+(numFeatures*numNodes) @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification - */ + */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { //TODO: Requires logic for regressions - for (node <- 0 until numNodes) { - val validSignalIndex = 1+numFeatures*node - val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false - if(isSampleValidForNode){ - val label = arr(0) - for (feature <- 0 until numFeatures){ - val arrShift = 1 + numFeatures*node - val aggShift = 2*numSplits*numFeatures*node - val arrIndex = arrShift + feature - val aggIndex = aggShift + 2*feature*numSplits + arr(arrIndex).toInt*2 - label match { - case(0.0) => agg(aggIndex) = agg(aggIndex) + 1 - case(1.0) => agg(aggIndex+1) = agg(aggIndex+1) + 1 - } - } - } + strategy.algo match { + case Classification => classificationBinSeqOp(arr, agg) + //TODO: Implement this + case Regression => regressionBinSeqOp(arr, agg) } agg } - //TODO: This length if different for regression - val binAggregateLength = 2*numSplits * numFeatures * numNodes - logDebug("binAggregageLength = " + binAggregateLength) + //TODO: This length is different for regression + val binAggregateLength = strategy.algo match { + case Classification => 2*numSplits * numFeatures * numNodes + case Regression => 3*numSplits * numFeatures * numNodes + } + logDebug("binAggregateLength = " + binAggregateLength) /*Combines the aggregates from partitions @param agg1 Array containing aggregates from one or more partitions @@ -290,11 +324,22 @@ object DecisionTree extends Serializable with Logging { @return Combined aggregate from agg1 and agg2 */ def binCombOp(agg1 : 
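
[Annotation] To make the flat-array arithmetic in the two sequence operators above concrete: each (node, feature, bin) cell holds two counters for classification (label 0, label 1) and three for regression (count, sum, sum of squares), laid out node-major, then feature, then bin. A worked classification index under assumed small sizes:

    // agg index = 2*numSplits*numFeatures*node + 2*numSplits*feature + 2*bin + label
    val numSplits = 4; val numFeatures = 3
    val node = 1; val feature = 2; val bin = 3; val label = 1
    val index = 2*numSplits*numFeatures*node + 2*numSplits*feature + 2*bin + label
    assert(index == 24 + 16 + 7)   // = 47; regression uses a factor of 3 throughout
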
Array[Double], agg2: Array[Double]) : Array[Double] = { - val combinedAggregate = new Array[Double](binAggregateLength) - for (index <- 0 until binAggregateLength){ - combinedAggregate(index) = agg1(index) + agg2(index) + strategy.algo match { + case Classification => { + val combinedAggregate = new Array[Double](binAggregateLength) + for (index <- 0 until binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + } + combinedAggregate + } + case Regression => { + val combinedAggregate = new Array[Double](binAggregateLength) + for (index <- 0 until binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + } + combinedAggregate + } } - combinedAggregate } logDebug("input = " + input.count) @@ -302,7 +347,7 @@ object DecisionTree extends Serializable with Logging { logDebug("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates - val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) + val binAggregates = binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) logDebug("binAggregates.length = " + binAggregates.length) //binAggregates.foreach(x => logDebug(x)) @@ -312,36 +357,70 @@ object DecisionTree extends Serializable with Logging { index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double) : InformationGainStats = { + strategy.algo match { + case Classification => { - val left0Count = leftNodeAgg(featureIndex)(2 * index) - val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) - val leftCount = left0Count + left1Count + val left0Count = leftNodeAgg(featureIndex)(2 * index) + val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) + val leftCount = left0Count + left1Count - val right0Count = rightNodeAgg(featureIndex)(2 * index) - val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) - val rightCount = right0Count + right1Count + val right0Count = rightNodeAgg(featureIndex)(2 * index) + val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) + val rightCount = right0Count + right1Count - val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) - if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) - if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * 
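
[Annotation] Two observations on the hunk above: the Classification and Regression arms of binCombOp are currently identical element-wise sums (the match only matters because binAggregateLength differs), and both branches of the `if (level > 0)` in the gain expression compute the same thing. The gain itself is the standard impurity decrease; a numeric check using the usual two-class Gini, 1 - p0^2 - p1^2, and made-up counts:

    // A split sends (40 zeros, 10 ones) left and (10 zeros, 40 ones) right.
    // parent Gini = 1 - 0.5^2 - 0.5^2         = 0.5
    // left Gini   = 1 - 0.8^2 - 0.2^2         = 0.32   (right likewise, by symmetry)
    // gain        = 0.5 - 0.5*0.32 - 0.5*0.32 = 0.18
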
rightImpurity + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) } - } + case Regression => { + val leftCount = leftNodeAgg(featureIndex)(3 * index) + val leftSum = leftNodeAgg(featureIndex)(3 * index + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2) + + val rightCount = rightNodeAgg(featureIndex)(3 * index) + val rightSum = rightNodeAgg(featureIndex)(3 * index + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2) + + val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares) + + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) + + val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + + } + } } /* @@ -352,26 +431,60 @@ object DecisionTree extends Serializable with Logging { each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - for (featureIndex <- 0 until numFeatures) { - val shift = 2*featureIndex*numSplits - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - for (splitIndex <- 1 until numSplits - 1) { - leftNodeAgg(featureIndex)(2 * splitIndex) - = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) - = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) - = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) - = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + strategy.algo match { + case Classification => { + + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + for (featureIndex <- 0 until numFeatures) { + val shift = 
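
[Annotation] The regression branch above works because (count, sum, sum of squares) are sufficient statistics for variance that compose by plain addition, which is also why bin aggregates can be merged across partitions and across bins before the impurity is evaluated. A minimal sketch of that composition, using the same variance formula Variance.calculate adopts later in this patch:

    case class Stats(count: Double, sum: Double, sumSq: Double) {
      def +(o: Stats): Stats = Stats(count + o.count, sum + o.sum, sumSq + o.sumSq)
      def impurity: Double = (sumSq - sum * sum / count) / count  // population variance
    }
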
2*featureIndex*numSplits + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) + rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) + for (splitIndex <- 1 until numSplits - 1) { + leftNodeAgg(featureIndex)(2 * splitIndex) + = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) + = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) + = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) + = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + } + } + (leftNodeAgg, rightNodeAgg) + } + case Regression => { + + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1)) + for (featureIndex <- 0 until numFeatures) { + val shift = 3*featureIndex*numSplits + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(2) = binData(shift + 2) + rightNodeAgg(featureIndex)(3 * (numSplits - 2)) = binData(shift + (3 * (numSplits - 1))) + rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 1) = binData(shift + (3 * (numSplits - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 2) = binData(shift + (3 * (numSplits - 1)) + 2) + for (splitIndex <- 1 until numSplits - 1) { + leftNodeAgg(featureIndex)(3 * splitIndex) + = binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + leftNodeAgg(featureIndex)(3 * splitIndex + 1) + = binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + leftNodeAgg(featureIndex)(3 * splitIndex + 2) + = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex)) + = binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1) + = binData(shift + (3 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2) + = binData(shift + (3 * (numSplits - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2) + } + } + (leftNodeAgg, rightNodeAgg) } } - (leftNodeAgg, rightNodeAgg) } def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) @@ -421,10 +534,24 @@ object DecisionTree extends Serializable with Logging { //Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + def getBinDataForNode(node: Int): Array[Double] = { + strategy.algo match { + case Classification => { + val shift = 2 * node * numSplits * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures) + binsForNode + } + case Regression => { + val shift = 3 * node * numSplits * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 
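
[Annotation] Stripped of the 2x/3x offset bookkeeping, extractLeftRightNodeAggregates above is a prefix sum and a suffix sum over the bins: for candidate split i, the left aggregate covers bins 0..i and the right aggregate covers bins i+1 onward. For one feature and one counter it reduces to:

    def prefixSuffix(bins: Array[Double]): (Array[Double], Array[Double]) = {
      val left  = bins.scanLeft(0.0)(_ + _).tail.dropRight(1)   // sum of bins(0..i)
      val right = bins.scanRight(0.0)(_ + _).tail.dropRight(1)  // sum of bins(i+1..)
      (left, right)
    }

The real code interleaves two (classification) or three (regression) counters per bin, hence the strides.
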
3 * numSplits * numFeatures) + binsForNode + } + } + } + for (node <- 0 until numNodes){ val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node - val shift = 2*node*numSplits*numFeatures - val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) + val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("node impurity = " + parentNodeImpurity) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index 65b5ab1162597..ae18cb0aaa4e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.mllib.tree +import org.apache.spark.SparkContext._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} import org.apache.spark.rdd.RDD @@ -95,6 +96,9 @@ object DecisionTreeRunner extends Logging { val accuracy = accuracyScore(model, testData) logDebug("accuracy = " + accuracy) + val mse = meanSquaredError(model,testData) + logDebug("mean square error = " + mse) + sc.stop() } @@ -126,6 +130,14 @@ object DecisionTreeRunner extends Logging { correctCount.toDouble / count } + //TODO: Make these generic MLTable metrics + def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { + val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() + println("meanSumOfSquares = " + meanSumOfSquares) + meanSumOfSquares + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 00feb25e25322..350627e9de1dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.tree.impurity +import javax.naming.OperationNotSupportedException + object Entropy extends Impurity { def log2(x: Double) = scala.math.log(x) / scala.math.log(2) @@ -31,4 +33,6 @@ object Entropy extends Impurity { } } - } + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new OperationNotSupportedException("Entropy.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 3396a015e7858..8befeb5a475f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.tree.impurity +import javax.naming.OperationNotSupportedException + object Gini extends Impurity { def calculate(c0 : Double, c1 : Double): Double = { @@ -29,4 +31,5 @@ object Gini extends Impurity { } } - } + def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new OperationNotSupportedException("Gini.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 4b6e679820f59..cda534b462234 100644 --- 
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -20,4 +20,6 @@ trait Impurity extends Serializable {
 
   def calculate(c0 : Double, c1 : Double): Double
 
+  def calculate(count : Double, sum : Double, sumSquares : Double) : Double
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 98f332122785e..65f5b3702779a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -17,7 +17,14 @@
 package org.apache.spark.mllib.tree.impurity
 
 import javax.naming.OperationNotSupportedException
+import org.apache.spark.Logging
 
-object Variance extends Impurity {
+object Variance extends Impurity with Logging {
   def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
- }
+
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
+    val squaredLoss = sumSquares - (sum*sum)/count
+    squaredLoss/count
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 6097c6d5ac985..5f9aad0de2f65 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -49,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Regression,Gini,3,100)
+    val strategy = new Strategy(Classification,Gini,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(bins.length==2)
@@ -62,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Regression,Gini,3,100)
+    val strategy = new Strategy(Classification,Gini,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -88,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Regression,Gini,3,100)
+    val strategy = new Strategy(Classification,Gini,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -114,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Regression,Entropy,3,100)
+    val strategy = new Strategy(Classification,Entropy,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -139,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Regression,Entropy,3,100)
+    val strategy = new Strategy(Classification,Entropy,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
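A note on the regression support introduced above: the new three-argument Impurity method computes variance from streaming sufficient statistics, using Var(x) = E[x^2] - (E[x])^2, rather than a second pass over the data. A standalone sketch of the same computation (the method body mirrors the patch; the surrounding object and sample data are illustrative only):

    object VarianceSketch {
      // Population variance from (count, sum, sum of squares):
      // E[x^2] - (E[x])^2, rearranged as (sumSquares - sum^2/count) / count.
      def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
        val squaredLoss = sumSquares - (sum * sum) / count
        squaredLoss / count
      }

      def main(args: Array[String]): Unit = {
        val xs = Seq(1.0, 2.0, 3.0, 4.0)
        val variance = calculate(xs.size, xs.sum, xs.map(x => x * x).sum)
        println(variance) // 1.25, matching the direct definition
      }
    }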
From 53108ed6ad241765757c1e4c68189035505b370f Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 19 Jan 2014 16:56:15 -0800
Subject: [PATCH 17/48] fixing index for highest bin

Signed-off-by: Manish Amde
---
 .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 1ff8c05bcb790..975dd4f0cd7e7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -629,7 +629,7 @@ object DecisionTree extends Serializable with Logging {
         bins(featureIndex)(index) = bin
       }
       bins(featureIndex)(numBins-1)
-        = new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit(Continuous),Continuous)
+        = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(Continuous),Continuous)
     }
     (splits,bins)
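The one-line fix above is worth spelling out: numBins contiguous bins are bounded by numBins - 1 interior split thresholds, so the highest bin must start at splits(numBins - 2); indexing with numBins - 3 skipped a threshold. A small sketch of the intended layout (plain pairs of doubles standing in for the Split and Bin classes; the open/closed endpoint convention here is illustrative, as the patches themselves adjust it later):

    object BinLayoutSketch {
      def main(args: Array[String]): Unit = {
        val numBins = 4
        val splits = Array(10.0, 20.0, 30.0) // numBins - 1 interior thresholds
        // Bin i spans (lower(i), upper(i)); first and last bins are open-ended.
        val bins = (0 until numBins).map { i =>
          val lower = if (i == 0) Double.MinValue else splits(i - 1)
          val upper = if (i == numBins - 1) Double.MaxValue else splits(i)
          (lower, upper)
        }
        // The last bin starts at splits(numBins - 2) = 30.0.
        bins.foreach(println)
      }
    }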
From 6df35b9e70701528b13b33820b687f295bcfb3a4 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Mon, 20 Jan 2014 20:33:52 -0800
Subject: [PATCH 18/48] regression predict logic

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala        | 26 ++++++++-----------
 .../mllib/tree/model/DecisionTreeModel.scala   | 14 ++++++++--
 .../tree/model/InformationGainStats.scala      |  8 +++---
 3 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 975dd4f0cd7e7..e8adef377481c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -87,7 +87,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
     topNode.build(nodes)
 
     val decisionTreeModel = {
-      return new DecisionTreeModel(topNode)
+      return new DecisionTreeModel(topNode, strategy.algo)
     }
 
     return decisionTreeModel
@@ -98,14 +98,8 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
     val split = nodeSplitStats._1
     val stats = nodeSplitStats._2
     val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
-    val predict = {
-      val leftSamples = nodeSplitStats._2.leftSamples.toDouble
-      val rightSamples = nodeSplitStats._2.rightSamples.toDouble
-      val totalSamples = leftSamples + rightSamples
-      leftSamples / totalSamples
-    }
     val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
-    val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
+    val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
     logDebug("Node = " + node)
     nodes(nodeIndex) = node
   }
@@ -370,8 +364,8 @@ object DecisionTree extends Serializable with Logging {
 
         val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
 
-        if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
-        if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
+        if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1)
+        if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0)
 
         val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
         val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
@@ -387,7 +381,9 @@ object DecisionTree extends Serializable with Logging {
           }
         }
 
-        new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
+        val predict = leftCount / (leftCount + rightCount)
+
+        new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
       }
       case Regression => {
         val leftCount = leftNodeAgg(featureIndex)(3 * index)
@@ -400,8 +396,8 @@ object DecisionTree extends Serializable with Logging {
 
         val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
 
-        if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
-        if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
+        if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,rightSum/rightCount)
+        if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,leftSum/leftCount)
 
         val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
         val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
@@ -417,7 +413,7 @@ object DecisionTree extends Serializable with Logging {
           }
         }
 
-        new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
+        new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,(leftSum + rightSum)/(leftCount+rightCount))
       }
     }
 
@@ -515,7 +511,7 @@ object DecisionTree extends Serializable with Logging {
     var bestFeatureIndex = 0
     var bestSplitIndex = 0
     //Initialization with infeasible values
-    var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0)
+    var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
     for (featureIndex <- 0 until numFeatures) {
       for (splitIndex <- 0 until numSplits - 1){
         val gainStats = gains(featureIndex)(splitIndex)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 1d7c03289c407..587e549c34ca8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,9 +17,19 @@
 package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
 
-class DecisionTreeModel(val topNode : Node) extends Serializable {
+class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {
 
-  def predict(features : Array[Double]) = if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
+  def predict(features : Array[Double]) = {
+    algo match {
+      case Classification => {
+        if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
+      }
+      case Regression => {
+        topNode.predictIfLeaf(features)
+      }
+    }
+  }
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 60a4f99b7f806..b992684b2b05b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.tree.model
 class InformationGainStats(val gain : Double,
                            val impurity: Double,
                            val leftImpurity : Double,
-                           val leftSamples : Long,
+                           //val leftSamples : Long,
                            val rightImpurity : Double,
-                           val rightSamples : Long) extends Serializable {
+                           //val rightSamples : Long
+                           val predict : Double) extends Serializable {
 
   override def toString =
     "gain = " + gain + ", impurity = " + impurity + ", left impurity = "
-      + leftImpurity + ", leftSamples = " + leftSamples + ", right impurity = "
-      + rightImpurity + ", rightSamples = " + rightSamples
+      + leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict
 
 }
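With the predicted value now stored on InformationGainStats, prediction becomes algorithm-dependent: a classification leaf carries the fraction of positive labels and is thresholded at 0.5, while a regression leaf carries a mean label that is returned as-is. A simplified sketch of that dispatch (a bare leaf value stands in for the Node tree; the thresholding direction shown is the corrected one that a later commit in this series settles on):

    object PredictSketch {
      sealed trait Algo
      case object Classification extends Algo
      case object Regression extends Algo

      // leafValue: P(label = 1) for classification, mean label for regression.
      def predict(leafValue: Double, algo: Algo): Double = algo match {
        case Classification => if (leafValue < 0.5) 0.0 else 1.0
        case Regression => leafValue
      }

      def main(args: Array[String]): Unit = {
        println(predict(0.8, Classification)) // 1.0
        println(predict(0.8, Regression))     // 0.8
      }
    }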
From dbb7ac13d28fba0848062a7bea40c617cb5f2c80 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Wed, 22 Jan 2014 20:44:23 -0800
Subject: [PATCH 19/48] categorical feature support

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala        | 127 +++++++++++++-----
 .../apache/spark/mllib/tree/model/Bin.scala    |   2 +-
 .../apache/spark/mllib/tree/model/Node.scala   |  15 ++-
 .../apache/spark/mllib/tree/model/Split.scala  |  11 +-
 .../spark/mllib/tree/DecisionTreeSuite.scala   |  83 ++++++++++--
 5 files changed, 185 insertions(+), 53 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index e8adef377481c..f89c53a7ad70d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -37,7 +37,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
     //Cache input RDD for speedup during multiple passes
     input.cache()
 
-    val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
     logDebug("numSplits = " + bins(0).length)
     strategy.numBins = bins(0).length
 
@@ -54,8 +54,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
     logDebug("algo = " + strategy.algo)
 
-
-
     breakable {
       for (level <- 0 until maxDepth){
 
@@ -185,10 +183,21 @@ object DecisionTree extends Serializable with Logging {
       val featureIndex = filter.split.feature
       val threshold = filter.split.threshold
       val comparison = filter.comparison
-      comparison match {
-        case(-1) => if (features(featureIndex) > threshold) return false
-        case(0) => if (features(featureIndex) != threshold) return false
-        case(1) => if (features(featureIndex) <= threshold) return false
+      val categories = filter.split.categories
+      val isFeatureContinuous = filter.split.featureType == Continuous
+      val feature = features(featureIndex)
+      if (isFeatureContinuous){
+        comparison match {
+          case(-1) => if (feature > threshold) return false
+          case(1) => if (feature <= threshold) return false
+        }
+      } else {
+        val containsFeature = categories.contains(feature)
+        comparison match {
+          case(-1) => if (!containsFeature) return false
+          case(1) => if (containsFeature) return false
+        }
+      }
     }
     true
 
@@ -197,18 +206,34 @@ object DecisionTree extends Serializable with Logging {
   /*Finds the right bin for the given feature*/
   def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
     //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
-    //TODO: Do binary search
-    for (binIndex <- 0 until strategy.numBins) {
-      val bin = bins(featureIndex)(binIndex)
-      //TODO: Remove this requirement post basic functional
-      val lowThreshold = bin.lowSplit.threshold
-      val highThreshold = bin.highSplit.threshold
-      val features = labeledPoint.features
-      if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
-        return binIndex
+
+    val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+    if (isFeatureContinous){
+      //TODO: Do binary search
+      for (binIndex <- 0 until strategy.numBins) {
+        val bin = bins(featureIndex)(binIndex)
+        //TODO: Remove this requirement post basic functional
+        val lowThreshold = bin.lowSplit.threshold
+        val highThreshold = bin.highSplit.threshold
+        val features = labeledPoint.features
+        if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
+          return binIndex
+        }
+      }
+      throw new UnknownError("no bin was found for continuous variable.")
+    } else {
+      for (binIndex <- 0 until strategy.numBins) {
+        val bin = bins(featureIndex)(binIndex)
+        //TODO: Remove this requirement post basic functional
+        val category = bin.category
+        val features = labeledPoint.features
+        if (category == features(featureIndex)) {
+          return binIndex
+        }
       }
+      throw new UnknownError("no bin was found for categorical variable.")
+    }
-    throw new UnknownError("no bin was found.")
   }
 
@@ -565,7 +590,7 @@ object DecisionTree extends Serializable with Logging {
   @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1)
   and bins is an Array[Array[Bin]] of size (numFeatures,numSplits1)
   */
-  def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
+  def findSplitsBins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
 
     val count = input.count()
 
@@ -603,31 +628,71 @@ object DecisionTree extends Serializable with Logging {
           logDebug("stride = " + stride)
           for (index <- 0 until numBins-1) {
             val sampleIndex = (index+1)*stride.toInt
-            val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous)
+            val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
             splits(featureIndex)(index) = split
           }
         } else {
           val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
-          for (index <- 0 until maxFeatureValue){
-            //TODO: Sort by centriod
-            val split = new Split(featureIndex,index,Categorical)
-            splits(featureIndex)(index) = split
+
+          require(maxFeatureValue < numBins, "number of categories should be less than number of bins")
+
+          val centriodForCategories
+            = sampledInput.map(lp => (lp.features(featureIndex),lp.label))
+              .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+
+          //Checking for missing categorical variables
+          val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]()
+          for (i <- 0 until maxFeatureValue){
+            if (centriodForCategories.contains(i)){
+              fullCentriodForCategories(i) = centriodForCategories(i)
+            } else {
+              fullCentriodForCategories(i) = Double.MaxValue
+            }
+          }
+
+          val categoriesSortedByCentriod
+            = fullCentriodForCategories.toList sortBy {_._2}
+
+          logDebug("centriod for categorical variable = " + categoriesSortedByCentriod)
+
+          var categoriesForSplit = List[Double]()
+          categoriesSortedByCentriod.iterator.zipWithIndex foreach {
+            case((key, value), index) => {
+              categoriesForSplit = key :: categoriesForSplit
+              splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical,categoriesForSplit)
+              bins(featureIndex)(index) = {
+                if(index == 0) {
+                  new Bin(new DummyCategoricalSplit(featureIndex,Categorical),splits(featureIndex)(0),Categorical,key)
+                }
+                else {
+                  new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Categorical,key)
+                }
+              }
+            }
          }
        }
      }
 
      //Find all bins
      for (featureIndex <- 0 until numFeatures){
-        bins(featureIndex)(0)
-          = new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous)
-        for (index <- 1 until numBins - 1){
-          val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous)
-          bins(featureIndex)(index) = bin
+        val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+        if (isFeatureContinous) { //bins for categorical variables are already assigned
+          bins(featureIndex)(0)
+            = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue)
+          for (index <- 1 until numBins - 1){
+            val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous,Double.MinValue)
+            bins(featureIndex)(index) = bin
+          }
+          bins(featureIndex)(numBins-1)
+            = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous),Continuous,Double.MinValue)
+        } else {
+          val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+          for (i <- maxFeatureValue until numBins){
+            bins(featureIndex)(i)
+              = new Bin(new DummyCategoricalSplit(featureIndex,Categorical),new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue)
+          }
        }
      }
-        bins(featureIndex)(numBins-1)
-          = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(Continuous),Continuous)
-
      (splits,bins)
    }
    case MinMax => {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index 13191851956ad..6664f084a7d8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -18,6 +18,6 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
-case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) {
+case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index a9210e10ae48b..fb63743848cc9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.FeatureType._
 
 class Node ( val id : Int,
              val predict : Double,
@@ -49,10 +50,18 @@ class Node ( val id : Int,
     if (isLeaf) {
       predict
     } else{
-      if (feature(split.get.feature) <= split.get.threshold) {
-        leftNode.get.predictIfLeaf(feature)
+      if (split.get.featureType == Continuous) {
+        if (feature(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predictIfLeaf(feature)
+        } else {
+          rightNode.get.predictIfLeaf(feature)
+        }
       } else {
-        rightNode.get.predictIfLeaf(feature)
+        if (split.get.categories.contains(feature(split.get.feature))) {
+          leftNode.get.predictIfLeaf(feature)
+        } else {
+          rightNode.get.predictIfLeaf(feature)
+        }
       }
     }
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 01aa349115302..97f16e67c55b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -18,11 +18,14 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
 
-case class Split(feature: Int, threshold : Double, featureType : FeatureType){
-  override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType
+case class Split(feature: Int, threshold : Double, featureType : FeatureType, categories : List[Double]){
+  override def toString =
+    "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories
 }
 
-class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind)
+class DummyLowSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MinValue, kind, List())
 
-class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind)
+class DummyHighSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List())
+
+class DummyCategoricalSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List())
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5f9aad0de2f65..4e68611d2be9e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
 import org.apache.spark.mllib.tree.model.Filter
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
+import scala.collection.mutable
 
 class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
 
@@ -50,7 +51,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification,Gini,3,100)
-    val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
     assert(splits.length==2)
     assert(bins.length==2)
     assert(splits(0).length==99)
@@ -58,12 +59,58 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     //println(splits(1)(98))
   }
 
+  test("split and bin calculation for categorical variables"){
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length == 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+    assert(splits.length==2)
+    assert(bins.length==2)
+    assert(splits(0).length==99)
+    assert(bins(0).length==100)
+    println(splits(0)(0))
+    println(splits(0)(1))
+    println(bins(0)(0))
+    println(splits(1)(0))
+    println(splits(1)(1))
+    println(bins(1)(0))
+  }
+
+  test("split and bin calculations for categorical variables with no sample for one category"){
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length == 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+    assert(splits.length==2)
+    assert(bins.length==2)
+    assert(splits(0).length==99)
+    assert(bins(0).length==100)
+    println(splits(0)(0))
+    println(splits(0)(1))
+    println(splits(0)(2))
+    println(bins(0)(0))
+    println(bins(0)(1))
+    println(bins(0)(2))
+    println(splits(1)(0))
+    println(splits(1)(1))
+    println(splits(1)(2))
+    println(bins(1)(0))
+    println(bins(1)(1))
+    println(bins(0)(2))
+    println(bins(0)(3))
+  }
+
+  //TODO: Test max feature value > num bins
+
+
   test("stump with fixed label 0 for Gini"){
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification,Gini,3,100)
-    val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
     assert(bins.length==2)
@@ -73,15 +120,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     strategy.numBins = 100
     val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
 
-    println("here")
     assert(bestSplits.length == 1)
     assert(0==bestSplits(0)._1.feature)
     assert(10==bestSplits(0)._1.threshold)
     assert(0==bestSplits(0)._2.gain)
-    assert(10==bestSplits(0)._2.leftSamples)
     assert(0==bestSplits(0)._2.leftImpurity)
-    assert(990==bestSplits(0)._2.rightSamples)
     assert(0==bestSplits(0)._2.rightImpurity)
+    assert(0.01==bestSplits(0)._2.predict)
   }
 
   test("stump with fixed label 1 for Gini"){
@@ -89,7 +134,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification,Gini,3,100)
-    val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
     assert(bins.length==2)
@@ -103,10 +148,10 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(0==bestSplits(0)._1.feature)
     assert(10==bestSplits(0)._1.threshold)
     assert(0==bestSplits(0)._2.gain)
-    assert(10==bestSplits(0)._2.leftSamples)
     assert(0==bestSplits(0)._2.leftImpurity)
-    assert(990==bestSplits(0)._2.rightSamples)
     assert(0==bestSplits(0)._2.rightImpurity)
+    assert(0.01==bestSplits(0)._2.predict)
+
   }
 
 
@@ -115,7 +160,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification,Entropy,3,100)
-    val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
     assert(bins.length==2)
@@ -129,10 +174,9 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(0==bestSplits(0)._1.feature)
     assert(10==bestSplits(0)._1.threshold)
     assert(0==bestSplits(0)._2.gain)
-    assert(10==bestSplits(0)._2.leftSamples)
     assert(0==bestSplits(0)._2.leftImpurity)
-    assert(990==bestSplits(0)._2.rightSamples)
     assert(0==bestSplits(0)._2.rightImpurity)
+    assert(0.01==bestSplits(0)._2.predict)
   }
 
   test("stump with fixed label 1 for Entropy"){
@@ -140,7 +184,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification,Entropy,3,100)
-    val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -154,10 +198,9 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(0==bestSplits(0)._1.feature)
     assert(10==bestSplits(0)._1.threshold)
     assert(0==bestSplits(0)._2.gain)
-    assert(10==bestSplits(0)._2.leftSamples)
     assert(0==bestSplits(0)._2.leftImpurity)
-    assert(990==bestSplits(0)._2.rightSamples)
     assert(0==bestSplits(0)._2.rightImpurity)
+    assert(0.01==bestSplits(0)._2.predict)
   }
 
@@ -184,4 +227,16 @@ object DecisionTreeSuite {
     arr
   }
 
+  def generateCategoricalDataPoints() : Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](1000)
+    for (i <- 0 until 1000){
+      if (i < 600){
+        arr(i) = new LabeledPoint(1.0,Array(0.0,1.0))
+      } else {
+        arr(i) = new LabeledPoint(0.0,Array(1.0,0.0))
+      }
+    }
+    arr
+  }
+
 }
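The categorical handling introduced in this patch orders categories by the average label ("centroid") observed for each category and then forms candidate splits as growing prefixes of that ordering, which for binary labels is the standard way to reduce the 2^k subset search to k - 1 ordered splits. A compact sketch of the ordering step (the (categoryValue, label) pairs and names are illustrative, not the patch's own API):

    object CategoryOrderingSketch {
      def main(args: Array[String]): Unit = {
        val samples = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (2.0, 0.0))
        // Mean label per category value.
        val centroids = samples.groupBy(_._1).map { case (cat, xs) =>
          cat -> xs.map(_._2).sum / xs.size
        }
        // Categories sorted by centroid; prefixes of this ordering are the
        // candidate left-hand sides of a binary split.
        val ordered = centroids.toList.sortBy(_._2).map(_._1)
        val candidateSplits = (1 until ordered.size).map(k => ordered.take(k))
        candidateSplits.foreach(println)
      }
    }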
From d504eb1f8a3f7f06226448d42b709f2f7ec6e91c Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Wed, 22 Jan 2014 21:59:15 -0800
Subject: [PATCH 20/48] more tests for categorical features

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala        | 10 ++++------
 .../spark/mllib/tree/DecisionTreeSuite.scala   | 20 ++++++++++++++++++-
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index f89c53a7ad70d..ed0cf825b1d50 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -204,15 +204,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /*Finds the right bin for the given feature*/
-  def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
-    //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
+  def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = {
 
-    val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
     if (isFeatureContinous){
       //TODO: Do binary search
       for (binIndex <- 0 until strategy.numBins) {
         val bin = bins(featureIndex)(binIndex)
-        //TODO: Remove this requirement post basic functional
         val lowThreshold = bin.lowSplit.threshold
         val highThreshold = bin.highSplit.threshold
         val features = labeledPoint.features
@@ -222,9 +219,9 @@ object DecisionTree extends Serializable with Logging {
       }
       throw new UnknownError("no bin was found for continuous variable.")
     } else {
+
       for (binIndex <- 0 until strategy.numBins) {
         val bin = bins(featureIndex)(binIndex)
-        //TODO: Remove this requirement post basic functional
         val category = bin.category
         val features = labeledPoint.features
         if (category == features(featureIndex)) {
           return binIndex
@@ -262,7 +259,8 @@ object DecisionTree extends Serializable with Logging {
       } else {
         for (featureIndex <- 0 until numFeatures) {
           //logDebug("shift+featureIndex =" + (shift+featureIndex))
-          arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
+          val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous)
         }
       }
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 4e68611d2be9e..40bb94e6794d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -75,6 +75,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     println(splits(1)(0))
     println(splits(1)(1))
     println(bins(1)(0))
+    //TODO: Add asserts
+
   }
 
   test("split and bin calculations for categorical variables with no sample for one category"){
@@ -100,12 +102,28 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     println(bins(1)(1))
     println(bins(0)(2))
     println(bins(0)(3))
+    //TODO: Add asserts
+
   }
 
   //TODO: Test max feature value > num bins
 
-  test("stump with fixed label 0 for Gini"){
+  test("stump with all categorical variables"){
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length == 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+    strategy.numBins = 100
+    val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
+    println(bestSplits(0)._1)
+    println(bestSplits(0)._2)
+    //TODO: Add asserts
+  }
+
+
+  test("stump with fixed label 0 for Gini"){
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
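The refactored findBin above takes the feature's continuity as a parameter: continuous features are located by scanning threshold intervals, categorical ones by matching the bin's category value. A self-contained sketch of the two lookup strategies (a linear scan, as in the patch, with the binary-search TODO left aside; names and data are illustrative):

    object FindBinSketch {
      def findBin(value: Double,
                  isContinuous: Boolean,
                  thresholds: Array[Double],  // ascending interior thresholds
                  categories: Array[Double]): Int = {
        if (isContinuous) {
          // First bin whose upper threshold exceeds the value; last bin otherwise.
          val idx = thresholds.indexWhere(t => value < t)
          if (idx == -1) thresholds.length else idx
        } else {
          // Categorical: exact match on the bin's category value.
          categories.indexOf(value)
        }
      }

      def main(args: Array[String]): Unit = {
        println(findBin(15.0, isContinuous = true, Array(10.0, 20.0), Array()))   // 1
        println(findBin(2.0, isContinuous = false, Array(), Array(0.0, 1.0, 2.0))) // 2
      }
    }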
From 6b7de78e3a59bef8cbb8aff8b2aeed0cd91ab4a1 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 25 Jan 2014 17:53:41 -0800
Subject: [PATCH 21/48] minor refactoring and tests

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala        | 100 +++++++++---------
 .../spark/mllib/tree/DecisionTreeSuite.scala   |  17 ++-
 2 files changed, 63 insertions(+), 54 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index ed0cf825b1d50..1116d0c4f711e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -152,8 +152,8 @@ object DecisionTree extends Serializable with Logging {
     //Find the number of features by looking at the first sample
     val numFeatures = input.take(1)(0).features.length
     logDebug("numFeatures = " + numFeatures)
-    val numSplits = strategy.numBins
-    logDebug("numSplits = " + numSplits)
+    val numBins = strategy.numBins
+    logDebug("numBins = " + numBins)
 
     /*Find the filters used before reaching the current code*/
     def findParentFilters(nodeIndex: Int): List[Filter] = {
@@ -161,8 +161,6 @@ object DecisionTree extends Serializable with Logging {
         List[Filter]()
       } else {
         val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
-        //val parentFilterIndex = nodeFilterIndex / 2
-        //TODO: Check left or right filter
         filters(nodeFilterIndex)
       }
     }
@@ -204,9 +202,9 @@ object DecisionTree extends Serializable with Logging {
     }
 
     /*Finds the right bin for the given feature*/
-    def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = {
+    def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {
 
-      if (isFeatureContinous){
+      if (isFeatureContinuous){
         //TODO: Do binary search
         for (binIndex <- 0 until strategy.numBins) {
           val bin = bins(featureIndex)(binIndex)
@@ -245,11 +243,11 @@ object DecisionTree extends Serializable with Logging {
       // calculating bin index and label per feature per node
       val arr = new Array[Double](1+(numFeatures * numNodes))
       arr(0) = labeledPoint.label
-      for (index <- 0 until numNodes) {
-        val parentFilters = findParentFilters(index)
+      for (nodeIndex <- 0 until numNodes) {
+        val parentFilters = findParentFilters(nodeIndex)
         //Find out whether the sample qualifies for the particular node
         val sampleValid = isSampleValid(parentFilters, labeledPoint)
-        val shift = 1 + numFeatures * index
+        val shift = 1 + numFeatures * nodeIndex
         if (!sampleValid) {
           //Add to invalid bin index -1
           for (featureIndex <- 0 until numFeatures) {
@@ -274,11 +272,11 @@ object DecisionTree extends Serializable with Logging {
         val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false
         if (isSampleValidForNode) {
           val label = arr(0)
-          for (feature <- 0 until numFeatures) {
+          for (featureIndex <- 0 until numFeatures) {
             val arrShift = 1 + numFeatures * node
-            val aggShift = 2 * numSplits * numFeatures * node
+            val aggShift = 2 * numBins * numFeatures * node
-            val arrIndex = arrShift + feature
-            val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2
+            val arrIndex = arrShift + featureIndex
+            val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
             label match {
               case (0.0) => agg(aggIndex) = agg(aggIndex) + 1
               case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
@@ -296,9 +294,9 @@ object DecisionTree extends Serializable with Logging {
           val label = arr(0)
           for (feature <- 0 until numFeatures) {
             val arrShift = 1 + numFeatures * node
-            val aggShift = 3 * numSplits * numFeatures * node
+            val aggShift = 3 * numBins * numFeatures * node
             val arrIndex = arrShift + feature
-            val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3
+            val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3
             //count, sum, sum^2
             agg(aggIndex) = agg(aggIndex) + 1
             agg(aggIndex + 1) = agg(aggIndex + 1) + label
@@ -318,7 +316,6 @@ object DecisionTree extends Serializable with Logging {
     @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification
     */
     def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
-      //TODO: Requires logic for regressions
       strategy.algo match {
         case Classification => classificationBinSeqOp(arr, agg)
         //TODO: Implement this
         case Regression => regressionBinSeqOp(arr, agg)
       }
       agg
     }
 
-    //TODO: This length is different for regression
     val binAggregateLength = strategy.algo match {
-      case Classification => 2*numSplits * numFeatures * numNodes
-      case Regression => 3*numSplits * numFeatures * numNodes
+      case Classification => 2*numBins * numFeatures * numNodes
+      case Regression => 3*numBins * numFeatures * numNodes
     }
     logDebug("binAggregateLength = " + binAggregateLength)
@@ -453,52 +449,52 @@ object DecisionTree extends Serializable with Logging {
     strategy.algo match {
       case Classification => {
-        val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
-        val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
+        val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+        val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
         for (featureIndex <- 0 until numFeatures) {
-          val shift = 2*featureIndex*numSplits
+          val shift = 2*featureIndex*numBins
           leftNodeAgg(featureIndex)(0) = binData(shift + 0)
           leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-          rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
-          rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
-          for (splitIndex <- 1 until numSplits - 1) {
+          rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1)))
+          rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1)
+          for (splitIndex <- 1 until numBins - 1) {
             leftNodeAgg(featureIndex)(2 * splitIndex)
               = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
             leftNodeAgg(featureIndex)(2 * splitIndex + 1)
               = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
-            rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex))
-              = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex))
-            rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1)
-              = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1)
+            rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
+              = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
+            rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1)
+              = binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
           }
         }
         (leftNodeAgg, rightNodeAgg)
       }
       case Regression => {
 
-        val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1))
-        val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1))
+        val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+        val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
         for (featureIndex <- 0 until numFeatures) {
-          val shift = 3*featureIndex*numSplits
+          val shift = 3*featureIndex*numBins
           leftNodeAgg(featureIndex)(0) = binData(shift + 0)
           leftNodeAgg(featureIndex)(1) = binData(shift + 1)
           leftNodeAgg(featureIndex)(2) = binData(shift + 2)
-          rightNodeAgg(featureIndex)(3 * (numSplits - 2)) = binData(shift + (3 * (numSplits - 1)))
-          rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 1) = binData(shift + (3 * (numSplits - 1)) + 1)
-          rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 2) = binData(shift + (3 * (numSplits - 1)) + 2)
-          for (splitIndex <- 1 until numSplits - 1) {
+          rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1)))
+          rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1)
+          rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2)
+          for (splitIndex <- 1 until numBins - 1) {
             leftNodeAgg(featureIndex)(3 * splitIndex)
               = binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3)
             leftNodeAgg(featureIndex)(3 * splitIndex + 1)
               = binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
             leftNodeAgg(featureIndex)(3 * splitIndex + 2)
               = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
-            rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex))
-              = binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex))
-            rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1)
-              = binData(shift + (3 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1)
-            rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2)
-              = binData(shift + (3 * (numSplits - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2)
+            rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
+              = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
+            rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1)
+              = binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
+            rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2)
+              = binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
           }
         }
         (leftNodeAgg, rightNodeAgg)
@@ -509,10 +505,10 @@ object DecisionTree extends Serializable with Logging {
   def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
     : Array[Array[InformationGainStats]] = {
 
-    val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1)
+    val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
 
     for (featureIndex <- 0 until numFeatures) {
-      for (index <- 0 until numSplits -1) {
+      for (index <- 0 until numBins -1) {
         //logDebug("splitIndex = " + index)
         gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
       }
     }
     gains
   }
 
   /*
-    Find the best split for a node given bin aggregate data
+  Find the best split for a node given bin aggregate data
 
-    @param binData Array[Double] of size 2*numSplits*numFeatures
-    */
+  @param binData Array[Double] of size 2*numSplits*numFeatures
+  */
   def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = {
     logDebug("node impurity = " + nodeImpurity)
     val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
@@ -536,7 +532,7 @@ object DecisionTree extends Serializable with Logging {
     //Initialization with infeasible values
     var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
     for (featureIndex <- 0 until numFeatures) {
-      for (splitIndex <- 0 until numSplits - 1){
+      for (splitIndex <- 0 until numBins - 1){
         val gainStats = gains(featureIndex)(splitIndex)
         if(gainStats.gain > bestGainStats.gain) {
           bestGainStats = gainStats
@@ -556,13 +552,13 @@ object DecisionTree extends Serializable with Logging {
     def getBinDataForNode(node: Int): Array[Double] = {
       strategy.algo match {
         case Classification => {
-          val shift = 2 * node * numSplits * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
+          val shift = 2 * node * numBins * numFeatures
+          val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
           binsForNode
         }
         case Regression => {
-          val shift = 3 * node * numSplits * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures)
+          val shift = 3 * node * numBins * numFeatures
+          val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
           binsForNode
         }
       }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 40bb94e6794d7..8d5ed343e0eb4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -109,7 +109,20 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
 
   //TODO: Test max feature value > num bins
 
-  test("stump with all categorical variables"){
+  test("classification stump with all categorical variables"){
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length == 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+    strategy.numBins = 100
+    val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
+    println(bestSplits(0)._1)
+    println(bestSplits(0)._2)
+    //TODO: Add asserts
+  }
+
+  test("regression stump with all categorical variables"){
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
@@ -123,7 +136,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
 
   }
 
-  test("stump with fixed label 0 for Gini"){
+  test("stump with fixed label 0 for Gini"){
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
From b09dc983f4f05da61479c87617526064b0e3dde8 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 26 Jan 2014 14:54:43 -0800
Subject: [PATCH 22/48] minor refactoring

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala        | 43 ++++++++++---------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 1116d0c4f711e..ab2c9011dd93b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -367,18 +367,18 @@ object DecisionTree extends Serializable with Logging {
 
   def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
                             featureIndex: Int,
-                            index: Int,
+                            splitIndex: Int,
                             rightNodeAgg: Array[Array[Double]],
                             topImpurity: Double) : InformationGainStats = {
 
     strategy.algo match {
       case Classification => {
-        val left0Count = leftNodeAgg(featureIndex)(2 * index)
-        val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
+        val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
+        val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
         val leftCount = left0Count + left1Count
 
-        val right0Count = rightNodeAgg(featureIndex)(2 * index)
-        val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
+        val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
+        val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
         val rightCount = right0Count + right1Count
 
         val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
@@ -405,13 +405,13 @@ object DecisionTree extends Serializable with Logging {
       }
       case Regression => {
-        val leftCount = leftNodeAgg(featureIndex)(3 * index)
-        val leftSum = leftNodeAgg(featureIndex)(3 * index + 1)
-        val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2)
+        val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
+        val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
+        val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
 
-        val rightCount = rightNodeAgg(featureIndex)(3 * index)
-        val rightSum = rightNodeAgg(featureIndex)(3 * index + 1)
-        val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2)
+        val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
+        val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
+        val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
 
         val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
@@ -463,9 +463,9 @@ object DecisionTree extends Serializable with Logging {
             leftNodeAgg(featureIndex)(2 * splitIndex + 1)
               = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
             rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
-              = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
+              = binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
             rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1)
-              = binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
+              = binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
           }
         }
         (leftNodeAgg, rightNodeAgg)
@@ -490,11 +490,11 @@ object DecisionTree extends Serializable with Logging {
             leftNodeAgg(featureIndex)(3 * splitIndex + 2)
               = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
             rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
-              = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
+              = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
             rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1)
-              = binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
+              = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
             rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2)
-              = binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
+              = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
           }
         }
         (leftNodeAgg, rightNodeAgg)
@@ -508,9 +508,9 @@ object DecisionTree extends Serializable with Logging {
     for (featureIndex <- 0 until numFeatures) {
-      for (index <- 0 until numBins -1) {
+      for (splitIndex <- 0 until numBins -1) {
         //logDebug("splitIndex = " + index)
-        gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
+        gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity)
       }
     }
     gains
@@ -544,6 +544,8 @@ object DecisionTree extends Serializable with Logging {
       (bestFeatureIndex,bestSplitIndex,bestGainStats)
     }
 
+    logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+    logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
     (splits(bestFeatureIndex)(bestSplitIndex),gainStats)
   }
 
@@ -614,13 +616,14 @@ object DecisionTree extends Serializable with Logging {
 
     //Find all splits
     for (featureIndex <- 0 until numFeatures){
-      val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-      if (isFeatureContinous) {
+      val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+      if (isFeatureContinuous) {
         val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
 
         val stride : Double = numSamples.toDouble/numBins
         logDebug("stride = " + stride)
         for (index <- 0 until numBins-1) {
+          //TODO: Investigate this
           val sampleIndex = (index+1)*stride.toInt
           val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
           splits(featureIndex)(index) = split
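The index fix in this patch matters because leftNodeAgg and rightNodeAgg are running prefix and suffix sums over the per-bin counts: entry s of the left aggregate must accumulate bins 0..s, and entry s of the right aggregate bins s+1 onward, so each step has to add the current bin to the previously accumulated entry. The same computation in compact form (one feature, one statistic, plain arrays; this is an illustration, not the patch's data layout):

    object CumulativeAggSketch {
      def main(args: Array[String]): Unit = {
        val binCounts = Array(5.0, 3.0, 2.0, 4.0) // numBins = 4
        // leftAgg(s): count in bins 0..s; rightAgg(s): count in bins s+1..end.
        val leftAgg = binCounts.init.scanLeft(0.0)(_ + _).tail
        val rightAgg = binCounts.tail.scanRight(0.0)(_ + _).init
        println(leftAgg.toList)  // List(5.0, 8.0, 10.0)
        println(rightAgg.toList) // List(9.0, 6.0, 4.0)
      }
    }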
From c0e522b7d1f5e27c81d682e5c8c97543fb4242be Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 26 Jan 2014 19:11:43 -0800
Subject: [PATCH 23/48] updated predict and split threshold logic

Signed-off-by: Manish Amde
---
 .../scala/org/apache/spark/mllib/tree/DecisionTree.scala     | 9 +++++----
 .../org/apache/spark/mllib/tree/DecisionTreeRunner.scala     | 1 -
 .../spark/mllib/tree/model/DecisionTreeModel.scala           | 2 +-
 .../spark/mllib/tree/model/InformationGainStats.scala        | 7 ++++---
 .../scala/org/apache/spark/mllib/tree/model/Node.scala       | 1 +
 .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala      | 8 ++++----
 6 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index ab2c9011dd93b..865a95c5025fc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -211,7 +211,7 @@ object DecisionTree extends Serializable with Logging {
           val lowThreshold = bin.lowSplit.threshold
           val highThreshold = bin.highSplit.threshold
           val features = labeledPoint.features
-          if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
+          if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) {
            return binIndex
          }
        }
@@ -400,7 +400,8 @@ object DecisionTree extends Serializable with Logging {
          }
        }
 
-        val predict = leftCount / (leftCount + rightCount)
+        //val predict = leftCount / (leftCount + rightCount)
+        val predict = (left1Count + right1Count) / (leftCount + rightCount)
 
        new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
      }
      case Regression => {
@@ -672,8 +673,8 @@ object DecisionTree extends Serializable with Logging {
 
     //Find all bins
     for (featureIndex <- 0 until numFeatures){
-      val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-      if (isFeatureContinous) { //bins for categorical variables are already assigned
+      val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+      if (isFeatureContinuous) { //bins for categorical variables are already assigned
         bins(featureIndex)(0)
           = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue)
         for (index <- 1 until numBins - 1){
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
index ae18cb0aaa4e7..4e6ed768d55d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
@@ -133,7 +133,6 @@ object DecisionTreeRunner extends Logging {
   //TODO: Make these generic MLTable metrics
   def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = {
     val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
-    println("meanSumOfSquares = " + meanSumOfSquares)
     meanSumOfSquares
   }
 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 587e549c34ca8..0da42e826984c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,7 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl def predict(features : Array[Double]) = { algo match { case Classification => { - if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0 + if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 } case Regression => { topNode.predictIfLeaf(features) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index b992684b2b05b..55d5893ee93c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -24,9 +24,10 @@ class InformationGainStats(val gain : Double, //val rightSamples : Long val predict : Double) extends Serializable { - override def toString = - "gain = " + gain + ", impurity = " + impurity + ", left impurity = " - + leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict + override def toString = { + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index fb63743848cc9..508b7b31d83b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -34,6 +34,7 @@ class Node ( val id : Int, def build(nodes : Array[Node]) : Unit = { logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) if (!isLeaf) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8d5ed343e0eb4..15b5b40b06532 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -157,7 +157,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + println(bestSplits(0)._2.predict) } test("stump with fixed label 1 for Gini"){ @@ -181,7 +181,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + assert(1==bestSplits(0)._2.predict) } @@ -207,7 +207,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + assert(0==bestSplits(0)._2.predict) } test("stump with fixed label 1 for 
Entropy"){ @@ -231,7 +231,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + assert(1==bestSplits(0)._2.predict) } From f067d68f0d951e7f0f089419c506fbd5ce2c2fc1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 26 Jan 2014 19:36:21 -0800 Subject: [PATCH 24/48] minor cleanup Signed-off-by: Manish Amde --- .../apache/spark/mllib/tree/DecisionTree.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 865a95c5025fc..a9a578c4ac262 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -41,7 +41,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { logDebug("numSplits = " + bins(0).length) strategy.numBins = bins(0).length - //TODO: Level-wise training of tree and obtain Decision Tree model val maxDepth = strategy.maxDepth val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 @@ -62,7 +61,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { logDebug("#####################################") //Find best split for all nodes at a level - val numNodes= scala.math.pow(2,level).toInt val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ @@ -105,7 +103,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) { for (i <- 0 to 1) { - val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i + val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { @@ -205,7 +203,6 @@ object DecisionTree extends Serializable with Logging { def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = { if (isFeatureContinuous){ - //TODO: Do binary search for (binIndex <- 0 until strategy.numBins) { val bin = bins(featureIndex)(binIndex) val lowThreshold = bin.lowSplit.threshold @@ -250,9 +247,12 @@ object DecisionTree extends Serializable with Logging { val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { //Add to invalid bin index -1 - for (featureIndex <- 0 until numFeatures) { - arr(shift+featureIndex) = -1 - //TODO: Break since marking one bin is sufficient + breakable { + for (featureIndex <- 0 until numFeatures) { + arr(shift+featureIndex) = -1 + //Breaking since marking one bin is sufficient + break() + } } } else { for (featureIndex <- 0 until numFeatures) { @@ -318,7 +318,6 @@ object DecisionTree extends Serializable with Logging { def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { strategy.algo match { case Classification => classificationBinSeqOp(arr, agg) - //TODO: Implement this case Regression => regressionBinSeqOp(arr, agg) } agg @@ -599,7 +598,6 @@ object DecisionTree extends Serializable with Logging { logDebug("maxBins = " + numBins) //Calculate the number of sample for approximate quantile calculation - //TODO: Justify 
this calculation val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 logDebug("fraction of data used for calculating quantiles = " + fraction) @@ -624,7 +622,6 @@ object DecisionTree extends Serializable with Logging { val stride : Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { - //TODO: Investigate this val sampleIndex = (index+1)*stride.toInt val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List()) splits(featureIndex)(index) = split From 5841c2838e6834fc8c767f3c83dba7ef99375fa4 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 26 Jan 2014 22:34:49 -0800 Subject: [PATCH 25/48] unit tests for categorical features Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTreeSuite.scala | 228 +++++++++++++++--- 1 file changed, 191 insertions(+), 37 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 15b5b40b06532..39635a7e654a2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -27,11 +27,12 @@ import org.apache.spark.SparkContext._ import org.jblas._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import scala.collection.mutable +import org.apache.spark.mllib.tree.configuration.FeatureType._ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -56,7 +57,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins.length==2) assert(splits(0).length==99) assert(bins(0).length==100) - //println(splits(1)(98)) } test("split and bin calculation for categorical variables"){ @@ -69,13 +69,71 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins.length==2) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(0)(0)) - println(splits(0)(1)) - println(bins(0)(0)) - println(splits(1)(0)) - println(splits(1)(1)) - println(bins(1)(0)) - //TODO: Add asserts + + //Checking splits + + assert(splits(0)(0).feature == 0) + assert(splits(0)(0).threshold == Double.MinValue) + assert(splits(0)(0).featureType == Categorical) + assert(splits(0)(0).categories.length == 1) + assert(splits(0)(0).categories.contains(1.0)) + + + assert(splits(0)(1).feature == 0) + assert(splits(0)(1).threshold == Double.MinValue) + assert(splits(0)(1).featureType == Categorical) + assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2) == null) + + assert(splits(1)(0).feature == 1) + assert(splits(1)(0).threshold == Double.MinValue) + assert(splits(1)(0).featureType == Categorical) + assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).categories.contains(0.0)) + + + assert(splits(1)(1).feature == 1) + assert(splits(1)(1).threshold == Double.MinValue) + assert(splits(1)(1).featureType == Categorical) + assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).categories.contains(1.0)) + 
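
A note on the structure these assertions pin down: for a categorical feature the candidate splits are nested category sets, so split k accepts exactly the first k+1 categories in the order the binning code ranks them. A minimal sketch of that accumulation (the ordering List(1.0, 0.0) is read off the assertions; the accumulator variable is mine):

    // Sketch: categorical splits are cumulative category sets.
    val sortedCategories = List(1.0, 0.0)
    var accepted = List.empty[Double]
    val splitCategorySets = sortedCategories.map { category =>
      accepted = category :: accepted
      accepted
    }
    // splitCategorySets(0) == List(1.0)
    // splitCategorySets(1) contains both 1.0 and 0.0
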
assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2) == null) + + + // Checks bins + + assert(bins(0)(0).category == 1.0) + assert(bins(0)(0).lowSplit.categories.length == 0) + assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category == 0.0) + assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category == Double.MaxValue) + + assert(bins(1)(0).category == 0.0) + assert(bins(1)(0).lowSplit.categories.length == 0) + assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category == 1.0) + assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category == Double.MaxValue) } @@ -85,29 +143,106 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(bins.length==2) - assert(splits(0).length==99) - assert(bins(0).length==100) - println(splits(0)(0)) - println(splits(0)(1)) - println(splits(0)(2)) - println(bins(0)(0)) - println(bins(0)(1)) - println(bins(0)(2)) - println(splits(1)(0)) - println(splits(1)(1)) - println(splits(1)(2)) - println(bins(1)(0)) - println(bins(1)(1)) - println(bins(0)(2)) - println(bins(0)(3)) - //TODO: Add asserts - } + //Checking splits + + assert(splits(0)(0).feature == 0) + assert(splits(0)(0).threshold == Double.MinValue) + assert(splits(0)(0).featureType == Categorical) + assert(splits(0)(0).categories.length == 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature == 0) + assert(splits(0)(1).threshold == Double.MinValue) + assert(splits(0)(1).featureType == Categorical) + assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2).feature == 0) + assert(splits(0)(2).threshold == Double.MinValue) + assert(splits(0)(2).featureType == Categorical) + assert(splits(0)(2).categories.length == 3) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(2.0)) + + assert(splits(0)(3) == null) + + assert(splits(1)(0).feature == 1) + assert(splits(1)(0).threshold == Double.MinValue) + assert(splits(1)(0).featureType == Categorical) + assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature == 1) + assert(splits(1)(1).threshold == Double.MinValue) + assert(splits(1)(1).featureType == Categorical) + assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2).feature == 1) + assert(splits(1)(2).threshold == Double.MinValue) + assert(splits(1)(2).featureType == Categorical) + assert(splits(1)(2).categories.length == 3) + 
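
The category order that both the split and bin assertions depend on comes from ranking categories by the mean label ("centroid") of their samples. A rough sketch of that ranking on made-up (category, label) pairs; the sort direction and tie-breaking here are illustrative, not the exact code:

    // Sketch: rank categories by mean label per category.
    val samples = Seq((1.0, 0.0), (1.0, 0.0), (0.0, 1.0), (2.0, 1.0)) // (category, label)
    val centroids = samples.groupBy(_._1).mapValues(ps => ps.map(_._2).sum / ps.size)
    val sortedCategories = centroids.toList.sortBy(_._2).map(_._1)
    // Category 1.0 has the lowest mean label here, so it ranks first.
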
assert(splits(1)(2).categories.contains(1.0)) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(2.0)) + + assert(splits(1)(3) == null) + + + // Checks bins + + assert(bins(0)(0).category == 1.0) + assert(bins(0)(0).lowSplit.categories.length == 0) + assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category == 0.0) + assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category == 2.0) + assert(bins(0)(2).lowSplit.categories.length == 2) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).lowSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.length == 3) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.contains(2.0)) + + assert(bins(0)(3).category == Double.MaxValue) + + assert(bins(1)(0).category == 0.0) + assert(bins(1)(0).lowSplit.categories.length == 0) + assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category == 1.0) + assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category == 2.0) + assert(bins(1)(2).lowSplit.categories.length == 2) + assert(bins(1)(2).lowSplit.categories.contains(0.0)) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length == 3) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(2.0)) + + assert(bins(1)(3).category == Double.MaxValue) - //TODO: Test max feature value > num bins + } test("classification stump with all categorical variables"){ val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -117,22 +252,41 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) - println(bestSplits(0)._1) - println(bestSplits(0)._2) - //TODO: Add asserts + + val split = bestSplits(0)._1 + assert(split.categories.length == 1) + assert(split.categories.contains(1.0)) + assert(split.featureType == Categorical) + assert(split.threshold == Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } test("regression stump with all categorical variables"){ val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 val bestSplits = 
DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) - println(bestSplits(0)._1) - println(bestSplits(0)._2) - //TODO: Add asserts + + val split = bestSplits(0)._1 + assert(split.categories.length == 1) + assert(split.categories.contains(1.0)) + assert(split.featureType == Categorical) + assert(split.threshold == Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) } @@ -157,7 +311,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - println(bestSplits(0)._2.predict) + } test("stump with fixed label 1 for Gini"){ From 0dd7659055879be9fbb3280964f87b14c735f225 Mon Sep 17 00:00:00 2001 From: manishamde Date: Sun, 26 Jan 2014 22:42:06 -0800 Subject: [PATCH 26/48] basic doc Signed-off-by: Manish Amde --- .../scala/org/apache/spark/mllib/tree/README.md | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md index 61cba281cccfc..0fd71aa9735bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -1,15 +1,17 @@ This package contains the default implementation of the decision tree algorithm. The decision tree algorithm supports: -+ information loss calculation with entropy and gini for classification and variance for regression -+ node model pruning -+ printing to dot files -+ unit tests ++ Binary classification ++ Regression ++ Information loss calculation with entropy and gini for classification and variance for regression ++ Both continuous and categorical features -#Performance testing +# Tree improvements ++ Node model pruning ++ Printing to dot files -#Future Extensions +# Future Ensemble Extensions + Random forests + Boosting -+ Extremely randomized trees \ No newline at end of file ++ Extremely randomized trees From dd0c0d799d42c94da3f930065a6c2973143bfd75 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 27 Jan 2014 00:01:43 -0800 Subject: [PATCH 27/48] minor: some docs Signed-off-by: Manish Amde --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a9a578c4ac262..89a3f6de4fcb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -29,9 +29,21 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ +/* +A class that implements a decision tree algorithm for classification and regression. +It supports both continuous and categorical features. +@param strategy The configuration parameters for the tree algorithm which specify the type of algorithm (classification, +regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc. 
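
For reference, the API documented in this patch is driven the same way the test suite drives it; a usage sketch (the SparkContext `sc` and the data generator are stand-ins taken from the tests):

    // Usage sketch for the documented train method, mirroring the test suite.
    val data = sc.parallelize(DecisionTreeSuite.generateCategoricalDataPoints())
    val strategy = new Strategy(Classification, Gini, 3, 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
    val model = new DecisionTree(strategy).train(data)
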
+ */
 class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
+  /*
+  Method to train a decision tree model over an RDD
+
+  @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
+  @return a DecisionTreeModel that can be used for prediction
+  */
   def train(input : RDD[LabeledPoint]) : DecisionTreeModel = {
 
     //Cache input RDD for speedup during multiple passes

From 937277990e80f9a97070c63d39552579f0320fd7 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 16 Feb 2014 19:42:48 -0800
Subject: [PATCH 28/48] code style: max line length <= 100

Signed-off-by: Manish Amde
---
 .../mllib/regression/RegressionTree.scala     |  21 --
 .../spark/mllib/tree/DecisionTree.scala       | 216 +++++++++++++-----
 .../spark/mllib/tree/DecisionTreeRunner.scala |   5 +-
 .../spark/mllib/tree/impurity/Gini.scala      |   4 +-
 .../tree/model/InformationGainStats.scala     |  13 +-
 .../apache/spark/mllib/tree/model/Node.scala  |   6 +-
 .../apache/spark/mllib/tree/model/Split.scala |  19 +-
 7 files changed, 184 insertions(+), 100 deletions(-)
 delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala
deleted file mode 100644
index fd9beb79ab88f..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.mllib.regression
-
-class RegressionTree {
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 89a3f6de4fcb5..28e2f992e65b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -33,15 +33,18 @@
 A class that implements a decision tree algorithm for classification and regression.
 It supports both continuous and categorical features.
-@param strategy The configuration parameters for the tree algorithm which specify the type of algorithm (classification,
-regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.
+@param strategy The configuration parameters for the tree algorithm which specify the type of
+algorithm (classification,
+regression, etc.), feature type (continuous, categorical), depth of the tree,
+quantile calculation strategy, etc.
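
The helpers reflowed in this patch all lean on the same implicit binary-heap numbering of nodes; a small sketch of that indexing (function names are mine, not the patch's):

    // Sketch: node `index` at `level` occupies slot 2^level - 1 + index,
    // matching extractNodeInfo and extractInfoForLowerLevels.
    def nodeSlot(level: Int, index: Int): Int = (1 << level) - 1 + index
    def childSlot(level: Int, index: Int, i: Int): Int =
      (1 << (level + 1)) - 1 + 2 * index + i // i = 0 for left child, 1 for right
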
*/ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { /* Method to train a decision tree model over an RDD - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + for DecisionTree @return a DecisionTreeModel that can be used for prediction */ def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { @@ -73,12 +76,14 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { logDebug("#####################################") //Find best split for all nodes at a level - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + level, filters,splits,bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ - extractNodeInfo(nodeSplitStats, level, index, nodes) - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) + extractNodeInfo(nodeSplitStats, level, index, nodes) + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, + filters) logDebug("final best split = " + nodeSplitStats._1) } @@ -102,7 +107,11 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { } - private def extractNodeInfo(nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, nodes: Array[Node]) { + private def extractNodeInfo( + nodeSplitStats: (Split, InformationGainStats), + level: Int, index: Int, + nodes: Array[Node]) { + val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = scala.math.pow(2, level).toInt - 1 + index @@ -112,17 +121,31 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { nodes(nodeIndex) = node } - private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) { + private def extractInfoForLowerLevels( + level: Int, + index: Int, + maxDepth: Int, + nodeSplitStats: (Split, InformationGainStats), + parentImpurities: Array[Double], + filters: Array[List[Filter]]) { + for (i <- 0 to 1) { val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { - val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity + val impurity = if (i == 0) { + nodeSplitStats._2.leftImpurity + } else { + nodeSplitStats._2.rightImpurity + } + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) parentImpurities(nodeIndex) = impurity - filters(nodeIndex) = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) :: filters((nodeIndex - 1) / 2) + val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) + filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + for (filter <- filters(nodeIndex)) { logDebug("Filter = " + filter) } @@ -137,9 +160,11 @@ object DecisionTree extends Serializable with Logging { /* Returns an Array[Split] of optimal splits for all nodes at a given level - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + for DecisionTree @param parentImpurities Impurities for all parent nodes for the 
current level - @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree + @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + parameters for construction the DecisionTree @param level Level of the tree @param filters Filter for all nodes at a given level @param splits possible splits for all features @@ -175,8 +200,8 @@ object DecisionTree extends Serializable with Logging { } } - /*Find whether the sample is valid input for the current node. - + /* + Find whether the sample is valid input for the current node. In other words, does it pass through all the filters for the current node. */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { @@ -321,11 +346,15 @@ object DecisionTree extends Serializable with Logging { /*Performs a sequential aggregation over a partition. for p bins, k features, l nodes (level = log2(l)) storage is of the form: - b111_left_count,b111_right_count, .... , bpk1_left_count, bpk1_right_count, .... , bpkl_left_count, bpkl_right_count + b111_left_count,b111_right_count, .... , .. + .. bpk1_left_count, bpk1_right_count, .... , .. + .. bpkl_left_count, bpkl_right_count - @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification + @param agg Array[Double] storing aggregate calculation of size + 2*numSplits*numFeatures*numNodes for classification @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes + for classification */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { strategy.algo match { @@ -371,7 +400,9 @@ object DecisionTree extends Serializable with Logging { logDebug("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates - val binAggregates = binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + val binAggregates = { + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + } logDebug("binAggregates.length = " + binAggregates.length) //binAggregates.foreach(x => logDebug(x)) @@ -392,10 +423,20 @@ object DecisionTree extends Serializable with Logging { val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) val rightCount = right0Count + right1Count - val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + val impurity = { + if (level > 0) { + topImpurity + } else { + strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + } + } - if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1) - if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0) + if (leftCount == 0) { + return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1) + } + if (rightCount == 0) { + return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0) + } val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) @@ -425,10 +466,25 @@ object DecisionTree extends Serializable with Logging { val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) val 
rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) - val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares) + val impurity = { + if (level > 0) { + topImpurity + } else { + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + strategy.impurity.calculate(count, sum, sumSquares) + } + } - if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,rightSum/rightCount) - if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,leftSum/leftCount) + if (leftCount == 0) { + return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity, + rightSum/rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0,topImpurity,topImpurity, + Double.MinValue,leftSum/leftCount) + } val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) @@ -444,7 +500,8 @@ object DecisionTree extends Serializable with Logging { } } - new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,(leftSum + rightSum)/(leftCount+rightCount)) + val predict = (leftSum + rightSum)/(leftCount+rightCount) + new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) } } @@ -457,8 +514,10 @@ object DecisionTree extends Serializable with Logging { @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ - def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { - strategy.algo match { + def extractLeftRightNodeAggregates( + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + + strategy.algo match { case Classification => { val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) @@ -467,21 +526,26 @@ object DecisionTree extends Serializable with Logging { val shift = 2*featureIndex*numBins leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) - rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(2 * (numBins - 2)) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) + = binData(shift + (2 * (numBins - 1)) + 1) for (splitIndex <- 1 until numBins - 1) { - leftNodeAgg(featureIndex)(2 * splitIndex) - = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) - = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) - = binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) - = binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2*splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2 
+ 1) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) } } (leftNodeAgg, rightNodeAgg) } + case Regression => { val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) @@ -491,22 +555,31 @@ object DecisionTree extends Serializable with Logging { leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) leftNodeAgg(featureIndex)(2) = binData(shift + 2) - rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2) + rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + binData(shift + (3 * (numBins - 1)) + 2) for (splitIndex <- 1 until numBins - 1) { leftNodeAgg(featureIndex)(3 * splitIndex) - = binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + = binData(shift + 3*splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3) leftNodeAgg(featureIndex)(3 * splitIndex + 1) - = binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + = binData(shift + 3*splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) leftNodeAgg(featureIndex)(3 * splitIndex + 2) - = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + = binData(shift + 3*splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) - = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + = binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) - = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) - = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) } } (leftNodeAgg, rightNodeAgg) @@ -514,15 +587,18 @@ object DecisionTree extends Serializable with Logging { } } - def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) + def calculateGainsForAllNodeSplits( + leftNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Double]], + nodeImpurity: Double) : Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numBins -1) { - //logDebug("splitIndex = " + index) - 
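
The aggregate extraction reindexed above is, conceptually, a pair of running sums: for each feature, the left-node aggregate at split s accumulates bins 0..s while the right-node aggregate accumulates bins s+1 onward. A one-feature sketch with made-up counts:

    // Sketch: left/right node aggregates as prefix/suffix sums over bin counts.
    val binCounts = Array(5.0, 3.0, 2.0, 4.0) // one feature, numBins = 4
    val leftAgg = binCounts.scanLeft(0.0)(_ + _).drop(1).dropRight(1)   // Array(5.0, 8.0, 10.0)
    val rightAgg = binCounts.scanRight(0.0)(_ + _).drop(1).dropRight(1) // Array(9.0, 6.0, 4.0)
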
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) + gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, + splitIndex, rightNodeAgg, nodeImpurity) } } gains @@ -533,7 +609,10 @@ object DecisionTree extends Serializable with Logging { @param binData Array[Double] of size 2*numSplits*numFeatures */ - def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = { + def binsToBestSplit( + binData : Array[Double], + nodeImpurity : Double) : (Split, InformationGainStats) = { + logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -593,12 +672,17 @@ object DecisionTree extends Serializable with Logging { /* Returns split and bins for decision tree calculation. - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree - @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree - @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + for DecisionTree + @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + parameters for construction the DecisionTree + @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures, + numSplits-1) and bins is an Array[Array[Bin]] of size (numFeatures,numSplits1) */ - def findSplitsBins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { + def findSplitsBins( + input : RDD[LabeledPoint], + strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() @@ -641,7 +725,8 @@ object DecisionTree extends Serializable with Logging { } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - require(maxFeatureValue < numBins, "number of categories should be less than number of bins") + require(maxFeatureValue < numBins, "number of categories should be less than number " + + "of bins") val centriodForCategories = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) @@ -666,13 +751,16 @@ object DecisionTree extends Serializable with Logging { categoriesSortedByCentriod.iterator.zipWithIndex foreach { case((key, value), index) => { categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical,categoriesForSplit) + splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical, + categoriesForSplit) bins(featureIndex)(index) = { if(index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex,Categorical),splits(featureIndex)(0),Categorical,key) + new Bin(new DummyCategoricalSplit(featureIndex,Categorical), + splits(featureIndex)(0),Categorical,key) } else { - new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Categorical,key) + new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index), + Categorical,key) } } } @@ -685,18 +773,22 @@ object DecisionTree extends Serializable with Logging { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { //bins for categorical 
variables are already assigned bins(featureIndex)(0) - = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue) + = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0), + Continuous,Double.MinValue) for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous,Double.MinValue) + val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index), + Continuous,Double.MinValue) bins(featureIndex)(index) = bin } bins(featureIndex)(numBins-1) - = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous),Continuous,Double.MinValue) + = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, + Continuous),Continuous,Double.MinValue) } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) for (i <- maxFeatureValue until numBins){ bins(featureIndex)(i) - = new Bin(new DummyCategoricalSplit(featureIndex,Categorical),new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue) + = new Bin(new DummyCategoricalSplit(featureIndex,Categorical), + new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index 4e6ed768d55d3..05d000f3a3ddc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ object DecisionTreeRunner extends Logging { val usage = """ - Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] + Usage: DecisionTreeRunner[slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] """ @@ -132,7 +132,8 @@ object DecisionTreeRunner extends Logging { //TODO: Make these generic MLTable metrics def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { - val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() + val meanSumOfSquares = + data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() meanSumOfSquares } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 8befeb5a475f6..3c7615f684525 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -31,5 +31,7 @@ object Gini extends Impurity { } } - def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new OperationNotSupportedException("Gini.calculate") + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new OperationNotSupportedException("Gini.calculate") + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 55d5893ee93c2..f410a5a2cf812 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -16,13 
+16,12 @@ */ package org.apache.spark.mllib.tree.model -class InformationGainStats(val gain : Double, - val impurity: Double, - val leftImpurity : Double, - //val leftSamples : Long, - val rightImpurity : Double, - //val rightSamples : Long - val predict : Double) extends Serializable { +class InformationGainStats( + val gain : Double, + val impurity: Double, + val leftImpurity : Double, + val rightImpurity : Double, + val predict : Double) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 508b7b31d83b6..fb7e0db9c9dd2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -29,11 +29,13 @@ class Node ( val id : Int, val stats : Option[InformationGainStats] ) extends Serializable with Logging{ - override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", split = " + split + ", stats = " + stats + override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "split = " + split + ", stats = " + stats def build(nodes : Array[Node]) : Unit = { - logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("building node " + id + " at level " + + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 97f16e67c55b5..1604996091597 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -18,14 +18,23 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType -case class Split(feature: Int, threshold : Double, featureType : FeatureType, categories : List[Double]){ +case class Split( + feature: Int, + threshold : Double, + featureType : FeatureType, + categories : List[Double]){ + override def toString = - "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories + "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + + ", categories = " + categories } -class DummyLowSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MinValue, kind, List()) +class DummyLowSplit(feature: Int, kind : FeatureType) + extends Split(feature, Double.MinValue, kind, List()) -class DummyHighSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List()) +class DummyHighSplit(feature: Int, kind : FeatureType) + extends Split(feature, Double.MaxValue, kind, List()) -class DummyCategoricalSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List()) +class DummyCategoricalSplit(feature: Int, kind : FeatureType) + extends Split(feature, Double.MaxValue, kind, List()) From 84f85d6d0a1fe7ed60149cc6b29a9ff76ef09abd Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 27 Feb 2014 20:57:56 -0800 Subject: [PATCH 29/48] code documentation Signed-off-by: Manish Amde --- 
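
Behind the InformationGainStats values tidied up above sits a single formula: the gain of a split is the parent impurity minus the count-weighted impurities of the two children. A worked sketch (a plain function, not the patch's exact code):

    // Sketch: information gain for one candidate split.
    def informationGain(parentImpurity: Double,
        leftCount: Double, leftImpurity: Double,
        rightCount: Double, rightImpurity: Double): Double = {
      val total = leftCount + rightCount
      parentImpurity - (leftCount / total) * leftImpurity -
        (rightCount / total) * rightImpurity
    }
    // e.g. informationGain(0.5, 40, 0.0, 60, 0.48) == 0.212 (approximately)
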
.../spark/mllib/tree/DecisionTree.scala | 31 +++++++++------- .../spark/mllib/tree/configuration/Algo.scala | 3 ++ .../tree/configuration/FeatureType.scala | 3 ++ .../tree/configuration/QuantileStrategy.scala | 3 ++ .../mllib/tree/configuration/Strategy.scala | 13 +++++++ .../spark/mllib/tree/impurity/Entropy.scala | 10 ++++++ .../spark/mllib/tree/impurity/Gini.scala | 10 ++++++ .../spark/mllib/tree/impurity/Variance.scala | 10 ++++++ .../apache/spark/mllib/tree/model/Bin.scala | 11 ++++++ .../mllib/tree/model/DecisionTreeModel.scala | 26 ++++++++++++-- .../spark/mllib/tree/model/Filter.scala | 5 +++ .../tree/model/InformationGainStats.scala | 8 +++++ .../apache/spark/mllib/tree/model/Node.scala | 10 ++++++ .../apache/spark/mllib/tree/model/Split.scala | 35 +++++++++++++++---- 14 files changed, 157 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 28e2f992e65b8..b8164f64a7b04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ -/* +/** A class that implements a decision tree algorithm for classification and regression. It supports both continuous and categorical features. @@ -40,7 +40,7 @@ quantile calculation strategy, etc. */ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { - /* + /** Method to train a decision tree model over an RDD @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data @@ -157,14 +157,14 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { object DecisionTree extends Serializable with Logging { - /* + /** Returns an Array[Split] of optimal splits for all nodes at a given level @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree + for DecisionTree @param parentImpurities Impurities for all parent nodes for the current level @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - parameters for construction the DecisionTree + parameters for construction the DecisionTree @param level Level of the tree @param filters Filter for all nodes at a given level @param splits possible splits for all features @@ -200,7 +200,7 @@ object DecisionTree extends Serializable with Logging { } } - /* + /** Find whether the sample is valid input for the current node. In other words, does it pass through all the filters for the current node. */ @@ -236,7 +236,9 @@ object DecisionTree extends Serializable with Logging { true } - /*Finds the right bin for the given feature*/ + /** + Finds the right bin for the given feature + */ def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = { if (isFeatureContinuous){ @@ -266,7 +268,8 @@ object DecisionTree extends Serializable with Logging { } - /*Finds bins for all nodes (and all features) at a given level + /** + Finds bins for all nodes (and all features) at a given level k features, l nodes (level = log2(l)) Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. 
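
To make the flat storage described in this comment concrete: each sample is laid out as one label slot followed by one bin index per (node, feature) pair. A sketch of the addressing (the helper name is mine):

    // Sketch: flat per-sample layout used during bin mapping.
    def binSlot(numFeatures: Int, nodeIndex: Int, featureIndex: Int): Int =
      1 + numFeatures * nodeIndex + featureIndex // slot 0 holds the label
    val numFeatures = 2
    val arr = new Array[Double](1 + numFeatures * 4) // e.g. 4 nodes at this level
    arr(0) = 1.0 // label
    arr(binSlot(numFeatures, 2, 1)) = 7.0 // feature 1 of node 2 fell into bin 7
    // A -1 written into a feature slot marks the sample invalid for that node.
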
,blk Denotes invalid sample for tree by noting bin for feature 1 as -1 @@ -343,7 +346,8 @@ object DecisionTree extends Serializable with Logging { } } - /*Performs a sequential aggregation over a partition. + /** + Performs a sequential aggregation over a partition. for p bins, k features, l nodes (level = log2(l)) storage is of the form: b111_left_count,b111_right_count, .... , .. @@ -370,7 +374,8 @@ object DecisionTree extends Serializable with Logging { } logDebug("binAggregateLength = " + binAggregateLength) - /*Combines the aggregates from partitions + /** + Combines the aggregates from partitions @param agg1 Array containing aggregates from one or more partitions @param agg2 Array containing aggregates from one or more partitions @@ -507,7 +512,7 @@ object DecisionTree extends Serializable with Logging { } } - /* + /** Extracts left and right split aggregates @param binData Array[Double] of size 2*numFeatures*numSplits @@ -604,7 +609,7 @@ object DecisionTree extends Serializable with Logging { gains } - /* + /** Find the best split for a node given bin aggregate data @param binData Array[Double] of size 2*numSplits*numFeatures @@ -669,7 +674,7 @@ object DecisionTree extends Serializable with Logging { bestSplits } - /* + /** Returns split and bins for decision tree calculation. @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 7cd128e381e8f..14ec47ce014e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -16,6 +16,9 @@ */ package org.apache.spark.mllib.tree.configuration +/** + * Enum to select the algorithm for the decision tree + */ object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index a725bf388fe29..b4e8ae4ac39dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -16,6 +16,9 @@ */ package org.apache.spark.mllib.tree.configuration +/** + * Enum to describe whether a feature is "continuous" or "categorical" + */ object FeatureType extends Enumeration { type FeatureType = Value val Continuous, Categorical = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 1bbd2d8c1fe92..dae73ab52dea7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -16,6 +16,9 @@ */ package org.apache.spark.mllib.tree.configuration +/** + * Enum for selecting the quantile calculation strategy + */ object QuantileStrategy extends Enumeration { type QuantileStrategy = Value val Sort, MinMax, ApproxHist = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 281dabd3364d8..973aaee49e5fb 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,6 +20,19 @@ import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +/** + * Stores all the configuration options for tree construction + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ class Strategy ( val algo : Algo, val impurity : Impurity, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 350627e9de1dd..c1b2972f9c25b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -18,10 +18,20 @@ package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException +/** + * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during + * binary classification. + */ object Entropy extends Impurity { def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + /** + * entropy calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return entropy value + */ def calculate(c0: Double, c1: Double): Double = { if (c0 == 0 || c1 == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 3c7615f684525..099c7e33dd39a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -18,8 +18,18 @@ package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException +/** + * Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini + * coefficent]] during binary classification + */ object Gini extends Impurity { + /** + * gini coefficient calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return gini coefficient value + */ def calculate(c0 : Double, c1 : Double): Double = { if (c0 == 0 || c1 == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 65f5b3702779a..b313b8d48eadf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -19,9 +19,19 @@ package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException import org.apache.spark.Logging +/** + * Class for calculating variance during regression + */ object Variance extends Impurity with Logging { def 
calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") + /** + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return variance value + */ def calculate(count: Double, sum: Double, sumSquares: Double): Double = { val squaredLoss = sumSquares - (sum*sum)/count squaredLoss/count diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 6664f084a7d8d..0b4b7d2e5b2df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -18,6 +18,17 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ +/** + * Used for "binning" the feature values for faster best split calculation. For a continuous + * feature, a bin is determined by a low and a high "split". For a categorical feature, + * a bin is determined using a single label value (category). + * @param lowSplit split signifying the lower threshold for the continuous feature to be + * accepted in the bin + * @param highSplit split signifying the upper threshold for the continuous feature to be + * accepted in the bin + * @param featureType type of feature -- categorical or continuous + * @param category categorical label value accepted in the bin + */ case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0da42e826984c..0e94827b0af70 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -16,12 +16,23 @@ */ package org.apache.spark.mllib.tree.model -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD +/** + * Model to store the decision tree parameters + * @param topNode root node + * @param algo algorithm type -- classification or regression + */ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable { - def predict(features : Array[Double]) = { + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features : Array[Double]) : Double = { algo match { case Classification => { if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 @@ -32,4 +43,15 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl } } + /** + * Predict values for the given data set using the model trained. 
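For reference, the impurity computations documented in Entropy, Gini and Variance above reduce to the standard binary formulas. A minimal standalone sketch follows; the object and method names are illustrative, not the patch code, and the Gini body (truncated above) is assumed to be the usual binary Gini impurity:

object ImpuritySketch {
  private def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

  // Binary entropy: -p0*log2(p0) - p1*log2(p1), taken as 0 when a class is empty.
  def entropy(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) 0.0
    else {
      val total = c0 + c1
      val p0 = c0 / total
      val p1 = c1 / total
      -p0 * log2(p0) - p1 * log2(p1)
    }

  // Binary Gini impurity: 1 - p0^2 - p1^2 (assumed form of the truncated calculate above).
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    val p0 = c0 / total
    val p1 = c1 / total
    1.0 - p0 * p0 - p1 * p1
  }

  // Variance exactly as in Variance.calculate above: (sumSquares - sum*sum/count) / count.
  def variance(count: Double, sum: Double, sumSquares: Double): Double =
    (sumSquares - (sum * sum) / count) / count
}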
+ * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Array[Double]]): RDD[Double] = { + features.map(x => predict(x)) + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala index 62e5006c80c1b..9fc794c87398d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -16,6 +16,11 @@ */ package org.apache.spark.mllib.tree.model +/** + * Filter specifying a split and type of comparison to be applied on features + * @param split split specifying the feature index, type and threshold + * @param comparison integer specifying <,=,> + */ case class Filter(split : Split, comparison : Int) { // Comparison -1,0,1 signifies <.=,> override def toString = " split = " + split + "comparison = " + comparison diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f410a5a2cf812..0f8d7a36d208f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -16,6 +16,14 @@ */ package org.apache.spark.mllib.tree.model +/** + * Information gain statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param leftImpurity left node impurity + * @param rightImpurity right node impurity + * @param predict predicted value + */ class InformationGainStats( val gain : Double, val impurity: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index fb7e0db9c9dd2..374f065a09032 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -20,6 +20,16 @@ import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.FeatureType._ +/** + * Node in a decision tree + * @param id integer node id + * @param predict predicted value at the node + * @param isLeaf whether the node is a leaf + * @param split split to calculate left and right nodes + * @param leftNode left child + * @param rightNode right child + * @param stats information gain stats + */ class Node ( val id : Int, val predict : Double, val isLeaf : Boolean, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 1604996091597..81e57dbf5e521 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -18,6 +18,13 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType +/** + * Split applied to a feature + * @param feature feature index + * @param threshold threshold for continuous feature + * @param featureType type of feature -- categorical or continuous + * @param categories accepted values for categorical variables + */ case class Split( feature: Int, threshold : Double, @@ -29,12 +36,28 @@ case class Split( ", categories = " 
+ categories } -class DummyLowSplit(feature: Int, kind : FeatureType) - extends Split(feature, Double.MinValue, kind, List()) +/** + * Split with minimum threshold for continuous features. Helps with the smallest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyLowSplit(feature: Int, featureType : FeatureType) + extends Split(feature, Double.MinValue, featureType, List()) -class DummyHighSplit(feature: Int, kind : FeatureType) - extends Split(feature, Double.MaxValue, kind, List()) +/** + * Split with maximum threshold for continuous features. Helps with the highest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyHighSplit(feature: Int, featureType : FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) -class DummyCategoricalSplit(feature: Int, kind : FeatureType) - extends Split(feature, Double.MaxValue, kind, List()) +/** + * Split with no acceptable feature values for categorical features. Helps with the first bin + * creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyCategoricalSplit(feature: Int, featureType : FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) From d3023b37f8954f8d79b2f0b0d081d9a4eb51415b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 5 Mar 2014 22:45:53 -0800 Subject: [PATCH 30/48] adding more docs for nested methods --- .../spark/mllib/tree/DecisionTree.scala | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b8164f64a7b04..aaa5a4fef6697 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -306,6 +306,20 @@ object DecisionTree extends Serializable with Logging { arr } + /** + Performs a sequential aggregation over a partition for classification. + + for p bins, k features, l nodes (level = log2(l)) storage is of the form: + b111_left_count,b111_right_count, .... , .. + .. bpk1_left_count, bpk1_right_count, .... , .. + .. bpkl_left_count, bpkl_right_count + + @param agg Array[Double] storing aggregate calculation of size + 2*numSplits*numFeatures*numNodes for classification + @param arr Array[Double] of size 1+(numFeatures*numNodes) + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes + for classification + */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { for (node <- 0 until numNodes) { val validSignalIndex = 1 + numFeatures * node @@ -326,6 +340,20 @@ object DecisionTree extends Serializable with Logging { } } + /** + Performs a sequential aggregation over a partition for regression. + + for p bins, k features, l nodes (level = log2(l)) storage is of the form: + b111_count,b111_sum, b111_sum_squares .... , .. + .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , .. + .. 
bpkl_count, bpkl_sum, bpkl_sum_squares + + @param agg Array[Double] storing aggregate calculation of size + 3*numSplits*numFeatures*numNodes for regression + @param arr Array[Double] of size 1+(numFeatures*numNodes) + @return Array[Double] storing aggregate calculation of size 3*numSplits*numFeatures*numNodes + for regression + */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { for (node <- 0 until numNodes) { val validSignalIndex = 1 + numFeatures * node @@ -354,11 +382,11 @@ object DecisionTree extends Serializable with Logging { .. bpk1_left_count, bpk1_right_count, .... , .. .. bpkl_left_count, bpkl_right_count - @param agg Array[Double] storing aggregate calculation of size - 2*numSplits*numFeatures*numNodes for classification + @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification + and 3*numSplits*numFeatures*numNodes for regression @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes - for classification + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification + and 3*numSplits*numFeatures*numNodes for regression */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { strategy.algo match { @@ -411,7 +439,15 @@ object DecisionTree extends Serializable with Logging { } logDebug("binAggregates.length = " + binAggregates.length) //binAggregates.foreach(x => logDebug(x)) - + /** + * Calculates the information gain for all splits + * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, splitIndex: Int, From 63e786bf796f77679a46060f123984758bffc585 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 5 Mar 2014 23:12:38 -0800 Subject: [PATCH 31/48] added multiple train methods for java compatibility --- .../spark/mllib/tree/DecisionTree.scala | 75 +++++++++++++++++-- .../spark/mllib/tree/DecisionTreeRunner.scala | 2 +- .../mllib/tree/configuration/Strategy.scala | 2 +- 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index aaa5a4fef6697..1c813244e5630 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Impurity /** A class that implements a decision tree algorithm for classification and regression. It supports both continuous and categorical features. @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm (classification, regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc. 
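To make the flat bin-aggregate layout documented in the binSeqOp methods above concrete, here is a sketch of the index arithmetic that layout implies; the helper name is illustrative and does not appear in the patch:

// For node n, feature f and bin b, the regression aggregate keeps three adjacent
// slots (count, sum, sum of squares), nested with node outermost, then feature,
// then bin; classification uses the same nesting with a stride of 2
// (left count, right count).
def regressionAggIndex(n: Int, f: Int, b: Int, numFeatures: Int, numBins: Int): Int =
  3 * (numBins * (numFeatures * n + f) + b)

// A sample with label y falling into bin b of feature f at node n would update:
//   val i = regressionAggIndex(n, f, b, numFeatures, numBins)
//   agg(i) += 1; agg(i + 1) += y; agg(i + 2) += y * y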
*/ -class DecisionTree(val strategy : Strategy) extends Serializable with Logging { +class DecisionTree private (val strategy : Strategy) extends Serializable with Logging { /** Method to train a decision tree model over an RDD @@ -157,6 +158,70 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { object DecisionTree extends Serializable with Logging { + /** + Method to train a decision tree model over an RDD + + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + for DecisionTree + @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm + (classification, regression, etc.), feature type (continuous, categorical), + depth of the tree, quantile calculation strategy, etc. + @return a DecisionTreeModel that can be used for prediction + */ + def train(input : RDD[LabeledPoint], strategy : Strategy) : DecisionTreeModel = { + new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + } + + /** + Method to train a decision tree model over an RDD + + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + for DecisionTree + @param algo classification or regression + @param impurity criterion used for information gain calculation + @param maxDepth maximum depth of the tree + @return a DecisionTreeModel that can be used for prediction + */ + def train( + input : RDD[LabeledPoint], + algo : Algo, + impurity : Impurity, + maxDepth : Int + ) : DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth) + new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + } + + + /** + Method to train a decision tree model over an RDD + + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + for DecisionTree + @param algo classification or regression + @param impurity criterion used for information gain calculation + @param maxDepth maximum depth of the tree + @param maxBins maximum number of bins used for splitting features + @param quantileCalculationStrategy algorithm for calculating quantiles + @param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete + values they take. For example, an entry (n -> k) implies the feature n is + categorical with k categories 0, 1, 2, ... , k-1. It's important to note that + features are zero-indexed. 
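As a usage sketch of the overloads documented above, with illustrative data and parameter values that are not part of the patch:

// Hypothetical: feature 3 is categorical with 4 categories 0.0, 1.0, 2.0 and 3.0;
// all other features are continuous. trainData is an assumed RDD[LabeledPoint].
val categoricalFeaturesInfo = Map(3 -> 4)
val model = DecisionTree.train(
  trainData,
  Classification,
  Gini,
  3,    // maxDepth
  100,  // maxBins
  Sort, // quantileCalculationStrategy
  categoricalFeaturesInfo)
val prediction = model.predict(Array(0.5, 1.0, 0.2, 2.0))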
+ @return a DecisionTreeModel that can be used for prediction + */ + def train( + input : RDD[LabeledPoint], + algo : Algo, + impurity : Impurity, + maxDepth : Int, + maxBins : Int, + quantileCalculationStrategy : QuantileStrategy, + categoricalFeaturesInfo : Map[Int,Int] + ) : DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo) + new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + } + /** Returns an Array[Split] of optimal splits for all nodes at a given level @@ -717,13 +782,13 @@ object DecisionTree extends Serializable with Logging { for DecisionTree @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree - @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures, - numSplits-1) and bins is an - Array[Array[Bin]] of size (numFeatures,numSplits1) + @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model.Split] of + size (numFeatures,numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of + size (numFeatures,numSplits1) */ def findSplitsBins( input : RDD[LabeledPoint], - strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { + strategy : Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index 05d000f3a3ddc..d93633d26228d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -87,7 +87,7 @@ object DecisionTreeRunner extends Logging { val maxBins = options.getOrElse('maxBins,"100").toString.toInt val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins) - val model = new DecisionTree(strategy).train(trainData) + val model = DecisionTree.train(trainData,strategy) //Load test data val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 973aaee49e5fb..88dfa76fc284f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -37,7 +37,7 @@ class Strategy ( val algo : Algo, val impurity : Impurity, val maxDepth : Int, - val maxBins : Int, + val maxBins : Int = 100, val quantileCalculationStrategy : QuantileStrategy = Sort, val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable { From cd2c2b436bfac172cbfeb115220d988042080915 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Fri, 7 Mar 2014 00:38:00 -0800 Subject: [PATCH 32/48] fixing code style based on feedback --- .../spark/mllib/tree/DecisionTree.scala | 467 +++++++++++------- .../spark/mllib/tree/DecisionTreeRunner.scala | 143 ------ .../spark/mllib/tree/configuration/Algo.scala | 3 +- .../tree/configuration/FeatureType.scala | 1 + .../tree/configuration/QuantileStrategy.scala | 1 + .../mllib/tree/configuration/Strategy.scala | 15 +- .../spark/mllib/tree/impurity/Entropy.scala | 1 + .../spark/mllib/tree/impurity/Gini.scala | 1 + .../spark/mllib/tree/impurity/Impurity.scala | 19 +- 
.../spark/mllib/tree/impurity/Variance.scala | 4 +- .../apache/spark/mllib/tree/model/Bin.scala | 3 +- .../mllib/tree/model/DecisionTreeModel.scala | 5 +- .../spark/mllib/tree/model/Filter.scala | 3 +- .../tree/model/InformationGainStats.scala | 9 +- .../apache/spark/mllib/tree/model/Node.scala | 28 +- .../apache/spark/mllib/tree/model/Split.scala | 7 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 7 +- 17 files changed, 365 insertions(+), 352 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1c813244e5630..d57cb6dc4c91d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,37 +18,35 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkContext._ +import scala.util.control.Breaks._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.Split -import scala.util.control.Breaks._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} /** -A class that implements a decision tree algorithm for classification and regression. -It supports both continuous and categorical features. - -@param strategy The configuration parameters for the tree algorithm which specify the type of -algorithm (classification, -regression, etc.), feature type (continuous, categorical), depth of the tree, -quantile calculation strategy, etc. - */ -class DecisionTree private (val strategy : Strategy) extends Serializable with Logging { + * A class that implements a decision tree algorithm for classification and regression. It + * supports both continuous and categorical features. + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), + * depth of the tree, quantile calculation strategy, etc. 
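A configuration sketch consistent with the Strategy parameters referenced above; the values are illustrative, and with the defaults introduced in the previous commit only the first three arguments are required:

val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 4,
  maxBins = 100,
  quantileCalculationStrategy = Sort,
  categoricalFeaturesInfo = Map[Int, Int]())
val model = DecisionTree.train(trainData, strategy) // trainData: an assumed RDD[LabeledPoint]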
+ */ +class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @return a DecisionTreeModel that can be used for prediction + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @return a DecisionTreeModel that can be used for prediction */ - def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { //Cache input RDD for speedup during multiple passes input.cache() @@ -59,7 +57,7 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L val maxDepth = strategy.maxDepth - val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 + val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 val filters = new Array[List[Filter]](maxNumNodes) filters(0) = List() val parentImpurities = new Array[Double](maxNumNodes) @@ -70,7 +68,7 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L logDebug("algo = " + strategy.algo) breakable { - for (level <- 0 until maxDepth){ + for (level <- 0 until maxDepth) { logDebug("#####################################") logDebug("level = " + level) @@ -78,19 +76,19 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L //Find best split for all nodes at a level val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters,splits,bins) + level, filters, splits, bins) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - extractNodeInfo(nodeSplitStats, level, index, nodes) + extractNodeInfo(nodeSplitStats, level, index, nodes) extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2,level)==splitsStatsForLevel.length) + require(scala.math.pow(2, level) == splitsStatsForLevel.length) - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 ) + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) if (allLeaf) break @@ -159,92 +157,89 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L object DecisionTree extends Serializable with Logging { /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm - (classification, regression, etc.), feature type (continuous, categorical), - depth of the tree, quantile calculation strategy, etc. - @return a DecisionTreeModel that can be used for prediction + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. 
+ * @return a DecisionTreeModel that can be used for prediction */ - def train(input : RDD[LabeledPoint], strategy : Strategy) : DecisionTreeModel = { - new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param algo classification or regression - @param impurity criterion used for information gain calculation - @param maxDepth maximum depth of the tree - @return a DecisionTreeModel that can be used for prediction - */ + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @return a DecisionTreeModel that can be used for prediction + */ def train( - input : RDD[LabeledPoint], - algo : Algo, - impurity : Impurity, - maxDepth : Int - ) : DecisionTreeModel = { + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int + ): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param algo classification or regression - @param impurity criterion used for information gain calculation - @param maxDepth maximum depth of the tree - @param maxBins maximum number of bins used for splitting features - @param quantileCalculationStrategy algorithm for calculating quantiles - @param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete - values they take. For example, an entry (n -> k) implies the feature n is - categorical with k categories 0, 1, 2, ... , k-1. It's important to note that - features are zero-indexed. - @return a DecisionTreeModel that can be used for prediction - */ + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. 
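For the four-argument overload documented earlier, a minimal call sketch for regression (input is an assumed RDD[LabeledPoint]):

val regressionModel = DecisionTree.train(input, Regression, Variance, 3) // maxDepth = 3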
+ * @return a DecisionTreeModel that can be used for prediction + */ def train( - input : RDD[LabeledPoint], - algo : Algo, - impurity : Impurity, - maxDepth : Int, - maxBins : Int, - quantileCalculationStrategy : QuantileStrategy, - categoricalFeaturesInfo : Map[Int,Int] - ) : DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo) - new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int] + ): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } /** - Returns an Array[Split] of optimal splits for all nodes at a given level - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param parentImpurities Impurities for all parent nodes for the current level - @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - parameters for construction the DecisionTree - @param level Level of the tree - @param filters Filter for all nodes at a given level - @param splits possible splits for all features - @param bins possible bins for all features - - @return Array[Split] instance for best splits for all nodes at a given level. - */ + * Returns an array of optimal splits for all nodes at a given level + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for constructing the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ def findBestSplits( - input : RDD[LabeledPoint], - parentImpurities : Array[Double], - strategy: Strategy, - level: Int, - filters : Array[List[Filter]], - splits : Array[Array[Split]], - bins : Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -266,8 +261,8 @@ object DecisionTree extends Serializable with Logging { } /** - Find whether the sample is valid input for the current node. - In other words, does it pass through all the filters for the current node. + * Find whether the sample is valid input for the current node. In other words, + * does it pass through all the filters for the current node. 
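The following is a hedged sketch, not the patch code, of the comparison encoding that isSampleValid applies when replaying a node's filters, following the Filter documentation where -1, 0 and 1 stand for <, = and >:

// Continuous-feature case only; a categorical filter would instead test whether
// the feature value belongs to split.categories.
def passes(filter: Filter, featureValue: Double): Boolean =
  filter.comparison match {
    case -1 => featureValue < filter.split.threshold
    case 0  => featureValue == filter.split.threshold
    case 1  => featureValue > filter.split.threshold
  }
// A sample is valid for a node when every Filter accumulated on the node's path
// passes for the sample's corresponding feature value.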
*/ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { @@ -302,9 +297,12 @@ object DecisionTree extends Serializable with Logging { } /** - Finds the right bin for the given feature + * Finds the right bin for the given feature */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean): Int = { if (isFeatureContinuous){ for (binIndex <- 0 until strategy.numBins) { @@ -334,16 +332,18 @@ object DecisionTree extends Serializable with Logging { } /** - Finds bins for all nodes (and all features) at a given level - k features, l nodes (level = log2(l)) - Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk - Denotes invalid sample for tree by noting bin for feature 1 as -1 + * Finds bins for all nodes (and all features) at a given level k features, + * l nodes (level = log2(l)). + * Storage label, b11, b12, b13, .., b1k, + * b21, b22, .. , b2k, + * bl1, bl2, .. , blk + * Denotes invalid sample for tree by noting bin for feature 1 as -1 */ - def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = { + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // calculating bin index and label per feature per node - val arr = new Array[Double](1+(numFeatures * numNodes)) + val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label for (nodeIndex <- 0 until numNodes) { val parentFilters = findParentFilters(nodeIndex) @@ -354,7 +354,7 @@ object DecisionTree extends Serializable with Logging { //Add to invalid bin index -1 breakable { for (featureIndex <- 0 until numFeatures) { - arr(shift+featureIndex) = -1 + arr(shift + featureIndex) = -1 //Breaking since marking one bin is sufficient break() } @@ -440,20 +440,19 @@ object DecisionTree extends Serializable with Logging { } /** - Performs a sequential aggregation over a partition. - - for p bins, k features, l nodes (level = log2(l)) storage is of the form: - b111_left_count,b111_right_count, .... , .. - .. bpk1_left_count, bpk1_right_count, .... , .. - .. bpkl_left_count, bpkl_right_count - - @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression - @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression - */ - def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { + * Performs a sequential aggregation over a partition. + * for p bins, k features, l nodes (level = log2(l)) storage is of the form: + * b111_left_count,b111_right_count, .... , .... + * bpk1_left_count, bpk1_right_count, .... 
, ...., bpkl_left_count, bpkl_right_count + * @param agg Array[Double] storing aggregate calculation of size + * 2*numSplits*numFeatures*numNodes for classification and + * 3*numSplits*numFeatures*numNodes for regression + * @param arr Array[Double] of size 1+(numFeatures*numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2*numSplits*numFeatures*numNodes for classification and + * 3*numSplits*numFeatures*numNodes for regression + */ + def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => classificationBinSeqOp(arr, agg) case Regression => regressionBinSeqOp(arr, agg) @@ -468,13 +467,12 @@ object DecisionTree extends Serializable with Logging { logDebug("binAggregateLength = " + binAggregateLength) /** - Combines the aggregates from partitions - @param agg1 Array containing aggregates from one or more partitions - @param agg2 Array containing aggregates from one or more partitions - - @return Combined aggregate from agg1 and agg2 + * Combines the aggregates from partitions + * @param agg1 Array containing aggregates from one or more partitions + * @param agg2 Array containing aggregates from one or more partitions + * @return Combined aggregate from agg1 and agg2 */ - def binCombOp(agg1 : Array[Double], agg2: Array[Double]) : Array[Double] = { + def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { strategy.algo match { case Classification => { val combinedAggregate = new Array[Double](binAggregateLength) @@ -513,11 +511,13 @@ object DecisionTree extends Serializable with Logging { * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ - def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Double]], - topImpurity: Double) : InformationGainStats = { + def calculateGainForSplit( + leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double): InformationGainStats = { + strategy.algo match { case Classification => { @@ -606,19 +606,18 @@ object DecisionTree extends Serializable with Logging { } } - val predict = (leftSum + rightSum)/(leftCount+rightCount) - new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) + val predict = (leftSum + rightSum)/(leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } } - /** - Extracts left and right split aggregates - - @param binData Array[Double] of size 2*numFeatures*numSplits - @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], Array[Double]) where - each array is of size(numFeature,2*(numSplits-1)) + /** + * Extracts left and right split aggregates + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], + * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { @@ -696,8 +695,7 @@ object DecisionTree extends Serializable with Logging { def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], - nodeImpurity: Double) - : Array[Array[InformationGainStats]] = { + nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) @@ -710,14 
+708,15 @@ object DecisionTree extends Serializable with Logging { gains } - /** - Find the best split for a node given bin aggregate data - - @param binData Array[Double] of size 2*numSplits*numFeatures - */ + /** + * Find the best split for a node given bin aggregate data + * @param binData Array[Double] of size 2*numSplits*numFeatures + * @param nodeImpurity impurity of the top node + * @return tuple of the best split and its information gain statistics + */ def binsToBestSplit( - binData : Array[Double], - nodeImpurity : Double) : (Split, InformationGainStats) = { + binData: Array[Double], + nodeImpurity: Double): (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) @@ -771,24 +770,24 @@ object DecisionTree extends Serializable with Logging { logDebug("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } - bestSplits } + + /** - Returns split and bins for decision tree calculation. - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - parameters for construction the DecisionTree - @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model.Split] of - size (numFeatures,numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of - size (numFeatures,numSplits1) + * Returns split and bins for decision tree calculation. + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for constructing the DecisionTree + * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree + * .model.Split] of size (numFeatures,numSplits-1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures,numSplits) */ def findSplitsBins( - input : RDD[LabeledPoint], - strategy : Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + input: RDD[LabeledPoint], + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() @@ -807,7 +806,7 @@ object DecisionTree extends Serializable with Logging { val sampledInput = input.sample(false, fraction, 42).collect() val numSamples = sampledInput.length - val stride : Double = numSamples.toDouble/numBins + val stride: Double = numSamples.toDouble/numBins logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { @@ -821,11 +820,11 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride : Double = numSamples.toDouble/numBins + val stride: Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { - val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List()) + val sampleIndex = (index + 1)*stride.toInt + val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } } else { @@ -856,17 +855,17 @@ object DecisionTree extends Serializable with Logging { var categoriesForSplit = List[Double]() categoriesSortedByCentriod.iterator.zipWithIndex foreach { case((key, value), index) => { - categoriesForSplit = key :: 
categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical, + categoriesForSplit = key:: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) bins(featureIndex)(index) = { if(index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex,Categorical), - splits(featureIndex)(0),Categorical,key) + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) } else { - new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index), - Categorical,key) + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) } } } @@ -882,19 +881,19 @@ object DecisionTree extends Serializable with Logging { = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0), Continuous,Double.MinValue) for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index), - Continuous,Double.MinValue) + val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Continuous, Double.MinValue) bins(featureIndex)(index) = bin } bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, - Continuous),Continuous,Double.MinValue) + Continuous), Continuous, Double.MinValue) } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) for (i <- maxFeatureValue until numBins){ bins(featureIndex)(i) - = new Bin(new DummyCategoricalSplit(featureIndex,Categorical), - new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue) + = new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + new DummyCategoricalSplit(featureIndex, Categorical), Categorical, Double.MaxValue) } } } @@ -906,10 +905,126 @@ object DecisionTree extends Serializable with Logging { case ApproxHist => { throw new UnsupportedOperationException("approximate histogram not supported yet.") } + } + } + + + val usage = """ + Usage: DecisionTreeRunner <master>[slices] --algo <Classification,Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,Variance>] [--maxBins num] + """ + + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + val sc = new SparkContext(args(0), "DecisionTree") + + + val arglist = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]): OptionMap = { + def isSwitch(s : String) = (s(0) == '-') + list match { + case Nil => map + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) + , tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), + tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => logError("Unknown option " + option) + sys.exit(1) + } + } + val options = nextOption(Map(),arglist) + logDebug(options.toString()) + //TODO: Add validation for input parameters + + //Load training data + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + //Figure out the type of algorithm + val 
algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + //Identify the type of impurity + val impurityStr = options.getOrElse('impurity, + if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance } + + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt + val maxBins = options.getOrElse('maxBins,"100").toString.toInt + + val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, + maxBins = maxBins) + val model = DecisionTree.train(trainData,strategy) + + //Load test data + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + //Measure algorithm accuracy + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + + val mse = meanSquaredError(model,testData) + logDebug("mean square error = " + mse) + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * <L>, <f1> <f2> ... + * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + //TODO: Port them to a metrics package + def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + val correctCount = data.filter(y => model.predict(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + //TODO: Make these generic MLTable metrics + def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + val meanSumOfSquares = + data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)) + .mean() + meanSumOfSquares } -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala deleted file mode 100644 index d93633d26228d..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.mllib.tree - -import org.apache.spark.SparkContext._ -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.configuration.Algo._ - - -object DecisionTreeRunner extends Logging { - - val usage = """ - Usage: DecisionTreeRunner <master>[slices] --algo <Classification,Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,Variance>] [--maxBins num] - """ - - - def main(args: Array[String]) { - - if (args.length < 2) { - System.err.println(usage) - System.exit(1) - } - - val sc = new SparkContext(args(0), "DecisionTree") - - - val arglist = args.toList.drop(1) - type OptionMap = Map[Symbol, Any] - - def nextOption(map : OptionMap, list: List[String]) : OptionMap = { - def isSwitch(s : String) = (s(0) == '-') - list match { - case Nil => map - case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) - case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) - case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) - case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) - case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) - case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) - case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) - case option :: tail => logError("Unknown option "+option) - sys.exit(1) - } - } - val options = nextOption(Map(),arglist) - logDebug(options.toString()) - //TODO: Add validation for input parameters - - //Load training data - val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) - - //Figure out the type of algorithm - val algoStr = options.get('algo).get.toString - val algo = algoStr match { - case "Classification" => Classification - case "Regression" => Regression - } - - //Identify the type of impurity - val impurityStr = options.getOrElse('impurity,if (algo == Classification) "Gini" else "Variance").toString - val impurity = impurityStr match { - case "Gini" => Gini - case "Entropy" => Entropy - case "Variance" => Variance - } - - val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt - val maxBins = options.getOrElse('maxBins,"100").toString.toInt - - val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins) - val model = DecisionTree.train(trainData,strategy) - - //Load test data - val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) - - //Measure algorithm accuracy - val accuracy = accuracyScore(model, testData) - logDebug("accuracy = " + accuracy) - - val mse = meanSquaredError(model,testData) - logDebug("mean square error = " + mse) - - sc.stop() - } - - /** - * Load labeled data from a file. The data format used here is - * <L>, <f1> <f2> ... - * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. 
Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.trim().split(",") - val label = parts(0).toDouble - val features = parts.slice(1,parts.length).map(_.toDouble) - LabeledPoint(label, features) - } - } - - //TODO: Port them to a metrics package - def accuracyScore(model : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() - val count = data.count() - logDebug("correct prediction count = " + correctCount) - logDebug("data count = " + count) - correctCount.toDouble / count - } - - //TODO: Make these generic MLTable metrics - def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { - val meanSumOfSquares = - data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() - meanSumOfSquares - } - - - - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 14ec47ce014e7..2dd1f0f27b8f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.configuration /** @@ -22,4 +23,4 @@ package org.apache.spark.mllib.tree.configuration object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index b4e8ae4ac39dd..09ee0586c58fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.configuration /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index dae73ab52dea7..2457a480c2a14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.configuration /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 88dfa76fc284f..9e461cfdbbd08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.mllib.tree.configuration import org.apache.spark.mllib.tree.impurity.Impurity @@ -34,13 +35,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. */ class Strategy ( - val algo : Algo, - val impurity : Impurity, - val maxDepth : Int, - val maxBins : Int = 100, - val quantileCalculationStrategy : QuantileStrategy = Sort, - val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable { + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { - var numBins : Int = Int.MinValue + var numBins: Int = Int.MinValue } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index c1b2972f9c25b..9018821abc875 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 099c7e33dd39a..20af8f6c1c2cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index cda534b462234..97092c85aea61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -14,12 +14,29 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.mllib.tree.impurity +/** + * Trait for calculating information gain + */ trait Impurity extends Serializable { + /** + * Information calculation for binary classification + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return information value + */ def calculate(c0 : Double, c1 : Double): Double - def calculate(count : Double, sum : Double, sumSquares : Double) : Double + /** + * Information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + def calculate(count: Double, sum: Double, sumSquares: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index b313b8d48eadf..85b7be560fecb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException @@ -23,7 +24,8 @@ import org.apache.spark.Logging * Class for calculating variance during regression */ object Variance extends Impurity with Logging { - def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") + def calculate(c0: Double, c1: Double): Double + = throw new OperationNotSupportedException("Variance.calculate") /** * variance calculation diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 0b4b7d2e5b2df..47afe3aed2b1b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -29,6 +30,6 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin */ -case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) { +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0e94827b0af70..94d77571dc22f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ + package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.Algo._ @@ -24,7 +25,7 @@ import org.apache.spark.rdd.RDD * @param topNode root node * @param algo algorithm type -- classification or regression */ -class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { /** * Predict values for a single data point using the model trained. @@ -32,7 +33,7 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl * @param features array representing a single data point * @return Double prediction from the trained model */ - def predict(features : Array[Double]) : Double = { + def predict(features: Array[Double]): Double = { algo match { case Classification => { if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala index 9fc794c87398d..ebc9595eafef3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model /** * @param split split specifying the feature index, type and threshold * @param comparison integer specifying <,=,> */ -case class Filter(split : Split, comparison : Int) { +case class Filter(split: Split, comparison: Int) { // Comparison -1,0,1 signifies <,=,> override def toString = " split = " + split + ", comparison = " + comparison } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 0f8d7a36d208f..64ff826486f5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model /** @@ -25,11 +26,11 @@ package org.apache.spark.mllib.tree.model * @param predict predicted value */ class InformationGainStats( - val gain : Double, + val gain: Double, val impurity: Double, - val leftImpurity : Double, - val rightImpurity : Double, - val predict : Double) extends Serializable { + val leftImpurity: Double, + val rightImpurity: Double, + val predict: Double) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 374f065a09032..4a2c876a51b54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ + package org.apache.spark.mllib.tree.model import org.apache.spark.Logging @@ -30,19 +31,23 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param rightNode right child * @param stats information gain stats */ -class Node ( val id : Int, - val predict : Double, - val isLeaf : Boolean, - val split : Option[Split], - var leftNode : Option[Node], - var rightNode : Option[Node], - val stats : Option[InformationGainStats] - ) extends Serializable with Logging{ +class Node ( + val id: Int, + val predict: Double, + val isLeaf: Boolean, + val split: Option[Split], + var leftNode: Option[Node], + var rightNode: Option[Node], + val stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats - def build(nodes : Array[Node]) : Unit = { + /** + * Build the left and right child nodes if this node is not a leaf + * @param nodes array of nodes + */ + def build(nodes : Array[Node]): Unit = { logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) @@ -59,6 +64,11 @@ class Node ( val id : Int, } } + /** + * Predict the value for the given feature vector by traversing down to a leaf + * @param feature array of feature values for a single data point + * @return predicted value + */ def predictIfLeaf(feature : Array[Double]) : Double = { if (isLeaf) { predict diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 81e57dbf5e521..fffd68d7a64b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @@ -27,9 +28,9 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType */ case class Split( feature: Int, - threshold : Double, - featureType : FeatureType, - categories : List[Double]){ + threshold: Double, + featureType: FeatureType, + categories: List[Double]){ override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 39635a7e654a2..a299b087dfda8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ + package org.apache.spark.mllib.tree import scala.util.Random @@ -393,7 +394,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { object DecisionTreeSuite { - def generateOrderedLabeledPointsWithLabel0() : Array[LabeledPoint] = { + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) @@ -403,7 +404,7 @@ object DecisionTreeSuite { } - def generateOrderedLabeledPointsWithLabel1() : Array[LabeledPoint] = { + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) @@ -412,7 +413,7 @@ object DecisionTreeSuite { arr } - def generateCategoricalDataPoints() : Array[LabeledPoint] = { + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ if (i < 600){ From eb8fcbeb8131681f0d0667086305a01d0b17f61d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Fri, 7 Mar 2014 00:45:57 -0800 Subject: [PATCH 33/48] minor code style updates --- .../apache/spark/mllib/tree/DecisionTree.scala | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d57cb6dc4c91d..9cc7c494f9d64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.tree -import org.apache.spark.SparkContext._ import scala.util.control.Breaks._ +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ import org.apache.spark.{SparkContext, Logging} @@ -101,7 +101,6 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log val decisionTreeModel = { return new DecisionTreeModel(topNode, strategy.algo) } - return decisionTreeModel } @@ -538,10 +537,10 @@ object DecisionTree extends Serializable with Logging { } if (leftCount == 0) { - return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1) + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) } if (rightCount == 0) { - return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0) + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) } val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) @@ -561,7 +560,7 @@ object DecisionTree extends Serializable with Logging { //val predict = leftCount / (leftCount + rightCount) val predict = (left1Count + right1Count) / (leftCount + rightCount) - new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } case Regression => { val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) @@ -584,12 +583,12 @@ object DecisionTree extends Serializable with Logging { } if (leftCount == 0) { - return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity, + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, rightSum/rightCount) } if (rightCount == 0) { - return new InformationGainStats(0,topImpurity,topImpurity, - Double.MinValue,leftSum/leftCount) + return new 
InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum/leftCount) } val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) @@ -1024,7 +1023,4 @@ object DecisionTree extends Serializable with Logging { .mean() meanSumOfSquares } - - - } From 794ff4d90bca2b2c12235439a8603eba76ddd463 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 9 Mar 2014 23:12:37 -0700 Subject: [PATCH 34/48] minor improvements to docs and style --- .../spark/mllib/tree/DecisionTree.scala | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9cc7c494f9d64..bba67bb5894a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.tree -import scala.util.control.Breaks._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ @@ -29,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} +import scala.util.control.Breaks._ /** * A class that implements a decision tree algorithm for classification and regression. It @@ -181,8 +181,8 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], algo: Algo, impurity: Impurity, - maxDepth: Int - ): DecisionTreeModel = { + maxDepth: Int) + : DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } @@ -211,8 +211,8 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int] - ): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int,Int]) + : DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) @@ -238,7 +238,8 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + bins: Array[Array[Bin]]) + : Array[(Split, InformationGainStats)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -301,7 +302,8 @@ object DecisionTree extends Serializable with Logging { def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean): Int = { + isFeatureContinuous: Boolean) + : Int = { if (isFeatureContinuous){ for (binIndex <- 0 until strategy.numBins) { @@ -515,7 +517,8 @@ object DecisionTree extends Serializable with Logging { featureIndex: Int, splitIndex: Int, rightNodeAgg: Array[Array[Double]], - topImpurity: Double): InformationGainStats = { + topImpurity: Double) + : InformationGainStats = { strategy.algo match { case Classification => { @@ -694,7 +697,8 @@ object DecisionTree extends Serializable with Logging { def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], - nodeImpurity: Double): 
Array[Array[InformationGainStats]] = { + nodeImpurity: Double) + : Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) @@ -715,7 +719,8 @@ object DecisionTree extends Serializable with Logging { */ def binsToBestSplit( binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { + nodeImpurity: Double) + : (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) @@ -786,7 +791,8 @@ object DecisionTree extends Serializable with Logging { */ def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + strategy: Strategy) + : (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() @@ -947,12 +953,11 @@ object DecisionTree extends Serializable with Logging { } val options = nextOption(Map(),arglist) logDebug(options.toString()) - //TODO: Add validation for input parameters //Load training data val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) - //Figure out the type of algorithm + //Identify the type of algorithm val algoStr = options.get('algo).get.toString val algo = algoStr match { case "Classification" => Classification @@ -1007,7 +1012,10 @@ object DecisionTree extends Serializable with Logging { } } - //TODO: Port them to a metrics package + //TODO: Port this method to a generic metrics package + /** + * Calculates the classifier accuracy. + */ def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() @@ -1016,7 +1024,10 @@ object DecisionTree extends Serializable with Logging { correctCount.toDouble / count } - //TODO: Make these generic MLTable metrics + //TODO: Port this method to a generic metrics package + /** + * Calculates the mean squared error for regression + */ def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)) From d1ef4f68c3b194fc96989860d539d5ffd502877d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 10 Mar 2014 00:04:40 -0700 Subject: [PATCH 35/48] more documentation --- .../spark/mllib/tree/DecisionTree.scala | 60 ++++++++++++------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bba67bb5894a9..059a9336b5f9e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -50,23 +50,34 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log //Cache input RDD for speedup during multiple passes input.cache() + logDebug("algo = " + strategy.algo) + //Finding the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. 
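// Editorial note (not part of the original patch): findSplitsBins returns a pair of
// per-feature tables -- splits(featureIndex) holds the ordered candidate splits and
// bins(featureIndex) the corresponding bins for every feature -- which is why the
// bins(0).length read just below serves as the common per-feature bin count that is
// then recorded in strategy.numBins.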
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) + + //Noting numBins for the input data strategy.numBins = bins(0).length + //The depth of the decision tree val maxDepth = strategy.maxDepth - + //The max number of nodes possible given the depth of the tree val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + //Initializing an array to hold filters applied to points for each node val filters = new Array[List[Filter]](maxNumNodes) + //The filter at the top node is an empty list filters(0) = List() + //Initializing an array to hold parent impurity calculations for each node val parentImpurities = new Array[Double](maxNumNodes) //Dummy value for top node (updated during first split calculation) - //parentImpurities(0) = Double.MinValue val nodes = new Array[Node](maxNumNodes) - logDebug("algo = " + strategy.algo) - + //The main idea here is to perform level-wise training of the decision tree nodes thus + // reducing the passes over the data from l to log2(l) where l is the total number of nodes. + // Each data sample is checked for validity w.r.t. each node at a given level -- i.e., + // the sample is only used for the split calculation at the node if the sample would have + // still survived the filters of the parent nodes. breakable { for (level <- 0 until maxDepth) { @@ -79,36 +90,41 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log level, filters, splits, bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - + //Extract info for nodes at the current level extractNodeInfo(nodeSplitStats, level, index, nodes) + //Extract info for nodes at the next lower level extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) } require(scala.math.pow(2, level) == splitsStatsForLevel.length) - + //Check whether all the nodes at the current level are leaves val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) - if (allLeaf) break + if (allLeaf) break //no more tree construction } } + //Initialize the top or root node of the tree val topNode = nodes(0) + //Build the full tree using the node info calculated in the level-wise best split calculations topNode.build(nodes) - val decisionTreeModel = { - return new DecisionTreeModel(topNode, strategy.algo) - } - return decisionTreeModel + //Return a decision tree model + return new DecisionTreeModel(topNode, strategy.algo) } - + /** + * Extract the decision tree node information for th given tree level and node index + */ private def extractNodeInfo( nodeSplitStats: (Split, InformationGainStats), - level: Int, index: Int, - nodes: Array[Node]) { + level: Int, + index: Int, + nodes: Array[Node]) + : Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 @@ -119,35 +135,37 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log nodes(nodeIndex) = node } + /** + * Extract the decision tree node information for the children of the node + */ private def extractInfoForLowerLevels( level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], - filters: Array[List[Filter]]) { + filters: Array[List[Filter]]) + : Unit = { + // 0 corresponds to the left child node and 1 corresponds to the right child node.
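// Editorial worked example (not part of the original patch): the loop below lays the tree
// out in a flat breadth-first array, placing the children of the node at position `index`
// of `level` at nodeIndex = 2^(level + 1) - 1 + 2 * index + i. For level = 1 and index = 1
// (the right child of the root), that yields slots 5 (left child, i = 0) and
// 6 (right child, i = 1).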
for (i <- 0 to 1) { - + //Calculating the index of the node from the node level and the index at the current level val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { - val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { nodeSplitStats._2.rightImpurity } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + //noting the parent impurities parentImpurities(nodeIndex) = impurity + //noting the parents filters for the child nodes val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - for (filter <- filters(nodeIndex)) { logDebug("Filter = " + filter) } - } } } From ad1fc214e0dff277531408506478594985c8f6af Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 10 Mar 2014 21:23:50 -0700 Subject: [PATCH 36/48] incorporated mengxr's code style suggestions --- .../spark/mllib/tree/DecisionTree.scala | 238 +++++++++--------- 1 file changed, 117 insertions(+), 121 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 059a9336b5f9e..085832cf12070 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import scala.util.control.Breaks._ + import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ @@ -28,7 +30,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} -import scala.util.control.Breaks._ +import java.util.Random +import org.apache.spark.util.random.XORShiftRandom /** * A class that implements a decision tree algorithm for classification and regression. It @@ -48,32 +51,32 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - //Cache input RDD for speedup during multiple passes + // Cache input RDD for speedup during multiple passes input.cache() logDebug("algo = " + strategy.algo) - //Finding the splits and the corresponding bins (interval between the splits) using a sample + // Finding the splits and the corresponding bins (interval between the splits) using a sample // of the input data. 
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) - //Noting numBins for the input data + // Noting numBins for the input data strategy.numBins = bins(0).length - //The depth of the decision tree + // The depth of the decision tree val maxDepth = strategy.maxDepth - //The max number of nodes possible given the depth of the tree + // The max number of nodes possible given the depth of the tree val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 - //Initializing an array to hold filters applied to points for each node + // Initializing an array to hold filters applied to points for each node val filters = new Array[List[Filter]](maxNumNodes) - //The filter at the top node is an empty list + // The filter at the top node is an empty list filters(0) = List() - //Initializing an array to hold parent impurity calculations for each node + // Initializing an array to hold parent impurity calculations for each node val parentImpurities = new Array[Double](maxNumNodes) - //Dummy value for top node (updated during first split calculation) + // Dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - //The main idea here is to perform level-wise training of the decision tree nodes thus + // The main idea here is to perform level-wise training of the decision tree nodes thus // reducing the passes over the data from l to log2(l) where l is the total number of nodes. // Each data sample is checked for validity w.r.t. each node at a given level -- i.e., // the sample is only used for the split calculation at the node if the sample would have // still survived the filters of the parent nodes. @@ -85,21 +88,21 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log logDebug("level = " + level) logDebug("#####################################") - //Find best split for all nodes at a level + // Find best split for all nodes at a level val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - //Extract info for nodes at the current level + // Extract info for nodes at the current level extractNodeInfo(nodeSplitStats, level, index, nodes) - //Extract info for nodes at the next lower level + // Extract info for nodes at the next lower level extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) } require(scala.math.pow(2, level) == splitsStatsForLevel.length) - //Check whether all the nodes at the current level are leaves + // Check whether all the nodes at the current level are leaves val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) if (allLeaf) break //no more tree construction } } - //Initialize the top or root node of the tree + // Initialize the top or root node of the tree val topNode = nodes(0) - //Build the full tree using the node info calculated in the level-wise best split calculations + // Build the full tree using the node info calculated in the level-wise best split calculations topNode.build(nodes) - //Return a decision tree model + // Return a decision tree model return new DecisionTreeModel(topNode, strategy.algo) } @@ -152,7 +155,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log : Unit = { // 0 corresponds to the left
child node and 1 corresponds to the right child node. for (i <- 0 to 1) { - //Calculating the index of the node from the node level and the index at the current level + // Calculating the index of the node from the node level and the index at the current level val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { @@ -158,9 +161,9 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log nodeSplitStats._2.rightImpurity } logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - //noting the parent impurities + // noting the parent impurities parentImpurities(nodeIndex) = impurity - //noting the parents filters for the child nodes + // noting the parents filters for the child nodes val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) for (filter <- filters(nodeIndex)) { @@ -236,6 +239,8 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } + val InvalidBinIndex = -1 + /** * Returns an array of optimal splits for all nodes at a given level * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data @@ -259,16 +264,16 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { - //Common calculations for multiple nested methods + // Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt logDebug("numNodes = " + numNodes) - //Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.length + // Find the number of features by looking at the first sample + val numFeatures = input.first().features.length logDebug("numFeatures = " + numFeatures) val numBins = strategy.numBins logDebug("numBins = " + numBins) - /*Find the filters used before reaching the current code*/ + /** Find the filters used before reaching the current code */ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() @@ -284,7 +289,7 @@ object DecisionTree extends Serializable with Logging { */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { - //Leaf + // Leaf if ((level > 0) & (parentFilters.length == 0) ){ return false } @@ -360,59 +365,52 @@ object DecisionTree extends Serializable with Logging { */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { - // calculating bin index and label per feature per node val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label for (nodeIndex <- 0 until numNodes) { val parentFilters = findParentFilters(nodeIndex) - //Find out whether the sample qualifies for the particular node + // Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { - //Add to invalid bin index -1 - breakable { - for (featureIndex <- 0 until numFeatures) { - arr(shift + featureIndex) = -1 - //Breaking since marking one bin is sufficient - break() - } - } + // marking one bin as -1 is sufficient + arr(shift) = InvalidBinIndex } else { - for (featureIndex <- 0 until numFeatures) { - //logDebug("shift+featureIndex =" + (shift+featureIndex)) - val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - arr(shift + 
featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous) + var featureIndex = 0 + while (featureIndex < numFeatures){ + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + featureIndex += 1 } } - } arr } /** - Performs a sequential aggregation over a partition for classification. - - for p bins, k features, l nodes (level = log2(l)) storage is of the form: - b111_left_count,b111_right_count, .... , .. - .. bpk1_left_count, bpk1_right_count, .... , .. - .. bpkl_left_count, bpkl_right_count - - @param agg Array[Double] storing aggregate calculation of size - 2*numSplits*numFeatures*numNodes for classification - @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes - for classification - */ + * Performs a sequential aggregation over a partition for classification. + * + * for p bins, k features, l nodes (level = log2(l)) storage is of the form: + * b111_left_count,b111_right_count, .... , .. + * .. bpk1_left_count, bpk1_right_count, .... , .. + * .. bpkl_left_count, bpkl_right_count + * + * @param agg Array[Double] storing aggregate calculation of size + * 2*numSplits*numFeatures*numNodes for classification + * @param arr Array[Double] of size 1+(numFeatures*numNodes) + * @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes + * for classification + */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - for (node <- 0 until numNodes) { - val validSignalIndex = 1 + numFeatures * node - val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false + for (nodeIndex <- 0 until numNodes) { + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { val label = arr(0) for (featureIndex <- 0 until numFeatures) { - val arrShift = 1 + numFeatures * node - val aggShift = 2 * numBins * numFeatures * node + val arrShift = 1 + numFeatures * nodeIndex + val aggShift = 2 * numBins * numFeatures * nodeIndex val arrIndex = arrShift + featureIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label match { @@ -425,28 +423,28 @@ object DecisionTree extends Serializable with Logging { } /** - Performs a sequential aggregation over a partition for regression. - - for p bins, k features, l nodes (level = log2(l)) storage is of the form: - b111_count,b111_sum, b111_sum_squares .... , .. - .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , .. - .. bpkl_count, bpkl_sum, bpkl_sum_squares - - @param agg Array[Double] storing aggregate calculation of size - 3*numSplits*numFeatures*numNodes for classification - @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 3*numSplits*numFeatures*numNodes - for regression - */ + * Performs a sequential aggregation over a partition for regression. + * + * for p bins, k features, l nodes (level = log2(l)) storage is of the form: + * b111_count,b111_sum, b111_sum_squares .... , .. + * .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , .. + * .. 
bpkl_count, bpkl_sum, bpkl_sum_squares + * + * @param agg Array[Double] storing aggregate calculation of size + * 3*numSplits*numFeatures*numNodes for regression + * @param arr Array[Double] of size 1+(numFeatures*numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3*numSplits*numFeatures*numNodes for regression + */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { - for (node <- 0 until numNodes) { - val validSignalIndex = 1 + numFeatures * node - val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false + for (nodeIndex <- 0 until numNodes) { + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { val label = arr(0) for (feature <- 0 until numFeatures) { - val arrShift = 1 + numFeatures * node - val aggShift = 3 * numBins * numFeatures * node + val arrShift = 1 + numFeatures * nodeIndex + val aggShift = 3 * numBins * numFeatures * nodeIndex val arrIndex = arrShift + feature val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3 //count, sum, sum^2 @@ -513,13 +511,12 @@ object DecisionTree extends Serializable with Logging { logDebug("input = " + input.count) val binMappedRDD = input.map(x => findBinsForLevel(x)) logDebug("binMappedRDD.count = " + binMappedRDD.count) - //calculate bin aggregates + // calculate bin aggregates val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) - //binAggregates.foreach(x => logDebug(x)) /** * Calculates the information gain for all splits @@ -578,7 +575,6 @@ } } - //val predict = leftCount / (leftCount + rightCount) val predict = (left1Count + right1Count) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) @@ -640,7 +636,8 @@ * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + binData: Array[Double]) + : (Array[Array[Double]], Array[Array[Double]]) = { strategy.algo match { case Classification => { @@ -747,7 +744,7 @@ val (bestFeatureIndex,bestSplitIndex, gainStats) = { var bestFeatureIndex = 0 var bestSplitIndex = 0 - //Initialization with infeasible values + // Initialization with infeasible values var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numBins - 1){ @@ -767,7 +764,7 @@ (splits(bestFeatureIndex)(bestSplitIndex),gainStats) } - //Calculate best splits for all nodes at a given level + // Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { @@ -814,28 +811,29 @@ val count = input.count() - //Find the number of features by looking at the first sample + // Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.length val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt -
logDebug("maxBins = " + numBins) - //Calculate the number of sample for approximate quantile calculation + + // Calculate the number of sample for approximate quantile calculation val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 logDebug("fraction of data used for calculating quantiles = " + fraction) - //sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, 42).collect() + + // sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length - val stride: Double = numSamples.toDouble/numBins + val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { case Sort => { - val splits = Array.ofDim[Split](numFeatures,numBins-1) - val bins = Array.ofDim[Bin](numFeatures,numBins) + val splits = Array.ofDim[Split](numFeatures, numBins-1) + val bins = Array.ofDim[Bin](numFeatures, numBins) //Find all splits for (featureIndex <- 0 until numFeatures){ @@ -860,10 +858,10 @@ object DecisionTree extends Serializable with Logging { = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) - //Checking for missing categorical variables + // Checking for missing categorical variables val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue){ - if (centriodForCategories.contains(i)){ + for (i <- 0 until maxFeatureValue) { + if (centriodForCategories.contains(i)) { fullCentriodForCategories(i) = centriodForCategories(i) } else { fullCentriodForCategories(i) = Double.MaxValue @@ -871,14 +869,14 @@ object DecisionTree extends Serializable with Logging { } val categoriesSortedByCentriod - = fullCentriodForCategories.toList sortBy {_._2} + = fullCentriodForCategories.toList.sortBy{_._2} logDebug("centriod for categorical variable = " + categoriesSortedByCentriod) var categoriesForSplit = List[Double]() categoriesSortedByCentriod.iterator.zipWithIndex foreach { case((key, value), index) => { - categoriesForSplit = key:: categoriesForSplit + categoriesForSplit = key :: categoriesForSplit splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) bins(featureIndex)(index) = { @@ -896,10 +894,10 @@ object DecisionTree extends Serializable with Logging { } } - //Find all bins + // Find all bins for (featureIndex <- 0 until numFeatures){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { //bins for categorical variables are already assigned + if (isFeatureContinuous) { // bins for categorical variables are already assigned bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0), Continuous,Double.MinValue) @@ -933,7 +931,7 @@ object DecisionTree extends Serializable with Logging { val usage = """ - Usage: DecisionTreeRunner[slices] --algo [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] """ @@ -948,12 +946,10 @@ object DecisionTree extends Serializable with Logging { val sc = new SparkContext(args(0), "DecisionTree") - - val arglist = args.toList.drop(1) + val argList = args.toList.drop(1) type OptionMap = Map[Symbol, Any] def nextOption(map : OptionMap, list: List[String]): OptionMap = { - def isSwitch(s : String) = (s(0) 
== '-') list match { case Nil => map case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) @@ -969,20 +965,20 @@ object DecisionTree extends Serializable with Logging { sys.exit(1) } } - val options = nextOption(Map(),arglist) + val options = nextOption(Map(),argList) logDebug(options.toString()) - //Load training data + // Load training data val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) - //Identify the type of algorithm + // Identify the type of algorithm val algoStr = options.get('algo).get.toString val algo = algoStr match { case "Classification" => Classification case "Regression" => Regression } - //Identify the type of impurity + // Identify the type of impurity val impurityStr = options.getOrElse('impurity, if (algo == Classification) "Gini" else "Variance").toString val impurity = impurityStr match { @@ -994,19 +990,22 @@ object DecisionTree extends Serializable with Logging { val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt val maxBins = options.getOrElse('maxBins,"100").toString.toInt - val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, - maxBins = maxBins) - val model = DecisionTree.train(trainData,strategy) + val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val model = DecisionTree.train(trainData, strategy) - //Load test data + // Load test data val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) - //Measure algorithm accuracy - val accuracy = accuracyScore(model, testData) - logDebug("accuracy = " + accuracy) + // Measure algorithm accuracy + if (algo == Classification){ + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + } - val mse = meanSquaredError(model,testData) - logDebug("mean square error = " + mse) + if (algo == Regression){ + val mse = meanSquaredError(model, testData) + logDebug("mean square error = " + mse) + } sc.stop() } @@ -1030,7 +1029,7 @@ object DecisionTree extends Serializable with Logging { } } - //TODO: Port this method to a generic metrics package + // TODO: Port this method to a generic metrics package /** * Calculates the classifier accuracy. 
*/ @@ -1042,14 +1041,11 @@ object DecisionTree extends Serializable with Logging { correctCount.toDouble / count } - //TODO: Port this method to a generic metrics package + // TODO: Port this method to a generic metrics package /** * Calculates the mean squared error for regression */ def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - val meanSumOfSquares = - data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)) - .mean() - meanSumOfSquares + data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() } } From 62c25620d3e18f3624caf9d0fdfdf2f2d957fd88 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 10 Mar 2014 21:40:30 -0700 Subject: [PATCH 37/48] fixing comment indentation --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 085832cf12070..1184e985a1faf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -321,7 +321,7 @@ object DecisionTree extends Serializable with Logging { /** * Finds the right bin for the given feature - */ + */ def findBin( featureIndex: Int, labeledPoint: LabeledPoint, @@ -362,7 +362,7 @@ object DecisionTree extends Serializable with Logging { * b21, b22, .. , b2k, * bl1, bl2, .. , blk * Denotes invalid sample for tree by noting bin for feature 1 as -1 - */ + */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // calculating bin index and label per feature per node From 6068356067a4376bd1fa105a59d1f533019469b1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 12 Mar 2014 11:48:50 -0700 Subject: [PATCH 38/48] ensuring num bins is always greater than max number of categories --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1184e985a1faf..5f04b369a6a60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -816,7 +816,15 @@ object DecisionTree extends Serializable with Logging { val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt - logDebug("maxBins = " + numBins) + logDebug("numBins = " + numBins) + + // I will also add a require statement ensuring #bins is always greater than the categories + // It's a limitation of the current implementation but a reasonable tradeoff since features + // with large number of categories get favored over continuous features. 
+ if (strategy.categoricalFeaturesInfo.size > 0){ + val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + require(numBins >= maxCategoriesForFeatures) + } // Calculate the number of sample for approximate quantile calculation val requiredSamples = numBins*numBins From 21163601c4eb9922edbbba312430d837031614e9 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 12 Mar 2014 12:13:12 -0700 Subject: [PATCH 39/48] removing dummy bin calculation for categorical variables --- .../spark/mllib/tree/DecisionTree.scala | 11 +----- .../spark/mllib/tree/DecisionTreeSuite.scala | 38 ++++++++++++------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5f04b369a6a60..5e88109b5ffb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -340,8 +340,8 @@ object DecisionTree extends Serializable with Logging { } throw new UnknownError("no bin was found for continuous variable.") } else { - - for (binIndex <- 0 until strategy.numBins) { + val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + for (binIndex <- 0 until numCategoricalBins) { val bin = bins(featureIndex)(binIndex) val category = bin.category val features = labeledPoint.features @@ -917,13 +917,6 @@ object DecisionTree extends Serializable with Logging { bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } else { - val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - for (i <- maxFeatureValue until numBins){ - bins(featureIndex)(i) - = new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - new DummyCategoricalSplit(featureIndex, Categorical), Categorical, Double.MaxValue) - } } } (splits,bins) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a299b087dfda8..f8914e03bd12f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, + 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(bins.length==2) @@ -120,7 +121,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(2).category == Double.MaxValue) + assert(bins(0)(2) == null) assert(bins(1)(0).category == 0.0) assert(bins(1)(0).lowSplit.categories.length == 0) @@ -134,7 +135,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(2).category == Double.MaxValue) + assert(bins(1)(2) == null) } @@ -142,7 +143,8 @@ class DecisionTreeSuite 
extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, + 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) //Checking splits @@ -217,7 +219,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0)(2).highSplit.categories.contains(0.0)) assert(bins(0)(2).highSplit.categories.contains(2.0)) - assert(bins(0)(3).category == Double.MaxValue) + assert(bins(0)(3) == null) assert(bins(1)(0).category == 0.0) assert(bins(1)(0).lowSplit.categories.length == 0) @@ -240,7 +242,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(2.0)) - assert(bins(1)(3).category == Double.MaxValue) + assert(bins(1)(3) == null) } @@ -249,10 +251,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, + 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 assert(split.categories.length == 1) @@ -272,10 +276,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, + 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 assert(split.categories.length == 1) @@ -305,7 +311,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -329,7 +336,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) 
assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -355,7 +363,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -379,7 +388,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) From 632818f821959d0182591d9f6b6d03f7a78a754e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 12 Mar 2014 22:01:51 -0700 Subject: [PATCH 40/48] removing threshold for classification predict method --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 8 ++++++-- .../spark/mllib/tree/model/DecisionTreeModel.scala | 9 +-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5e88109b5ffb5..a16bff2b5f4d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1034,8 +1034,12 @@ object DecisionTree extends Serializable with Logging { /** * Calculates the classifier accuracy. 
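 * (Editorial note, not part of the original patch: with this change the thresholding
 * happens here rather than in the model -- raw predictions below the threshold argument,
 * default 0.5, count as 0.0 and otherwise as 1.0 -- since DecisionTreeModel.predict now
 * returns the raw value at the leaf.)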
*/ - def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() + def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + threshold: Double = 0.5): Double = { + def predictedValue(features: Array[Double]) = { + if (model.predict(features) < threshold) 0.0 else 1.0 + } + val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() val count = data.count() logDebug("correct prediction count = " + correctCount) logDebug("data count = " + count) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 94d77571dc22f..a056da77641ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -34,14 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Array[Double]): Double = { - algo match { - case Classification => { - if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 - } - case Regression => { - topNode.predictIfLeaf(features) - } - } + topNode.predictIfLeaf(features) } /** From ff363a7353b28e9bcf16944deb376e075555dfd1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 16 Mar 2014 23:28:45 -0700 Subject: [PATCH 41/48] binary search for bins and while loop for categorical feature bins --- .../spark/mllib/tree/DecisionTree.scala | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a16bff2b5f4d7..b7492038445cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -319,6 +319,7 @@ object DecisionTree extends Serializable with Logging { true } + // TODO: Unit test this /** * Finds the right bin for the given feature */ @@ -328,26 +329,47 @@ object DecisionTree extends Serializable with Logging { isFeatureContinuous: Boolean) : Int = { - if (isFeatureContinuous){ - for (binIndex <- 0 until strategy.numBins) { - val bin = bins(featureIndex)(binIndex) + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + + def binarySearchForBins(): Int = { + var left = 0 + var right = binForFeatures.length-1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold - val features = labeledPoint.features - if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) { - return binIndex + if ((lowThreshold < feature) & (highThreshold >= feature)){ + return mid + } + else if ((lowThreshold >= feature)){ + right = mid - 1 } + else { + left = mid + 1 + } + } + -1 + } + + if (isFeatureContinuous){ + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") } - throw new UnknownError("no bin was found for continuous variable.") + binIndex } else { val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) - for (binIndex <- 0 until numCategoricalBins) { + var binIndex = 0 + 
while (binIndex < numCategoricalBins) { val bin = bins(featureIndex)(binIndex) val category = bin.category val features = labeledPoint.features if (category == features(featureIndex)) { return binIndex } + binIndex += 1 } throw new UnknownError("no bin was found for categorical variable.") From 4576b64f7e2d758122f44ad3062fbd54a98e6023 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 22 Mar 2014 20:25:34 -0700 Subject: [PATCH 42/48] documentation and for to while loop conversion --- .../spark/mllib/tree/DecisionTree.scala | 303 ++++++++++++------ 1 file changed, 204 insertions(+), 99 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b7492038445cc..3d5eb0fcf263b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -81,6 +81,8 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log // Each data sample is checked for validity w.r.t to each node at a given level -- i.e., // the sample is only used for the split calculation at the node if the sampled would have // still survived the filters of the parent nodes. + + // TODO: Convert for loop to while loop breakable { for (level <- 0 until maxDepth) { @@ -120,7 +122,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log } /** - * Extract the decision tree node information for th given tree level and node index + * Extract the decision tree node information for the given tree level and node index */ private def extractNodeInfo( nodeSplitStats: (Split, InformationGainStats), @@ -151,6 +153,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log : Unit = { // 0 corresponds to the left child node and 1 corresponds to the right child node. + // TODO: Convert to while loop for (i <- 0 to 1) { // Calculating the index of the node from the node level and the index at the current level val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i @@ -264,6 +267,31 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { + + // The high-level description for the best split optimizations are noted here. + // + // *Level-wise training* + // We perform bin calculations for all nodes at the given level to avoid making multiple + // passes over the data. Thus, for a slightly increased computation and storage cost we save + // several iterations over the data especially at higher levels of the decision tree. + // + // *Bin-wise computation* + // We use a bin-wise best split computation strategy instead of a straightforward best split + // computation strategy. Instead of analyzing each sample for contribution to the left/right + // child node impurity of every split, we first categorize each feature of a sample into a + // bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + // is ordered (read ordering for categorical variables in the findSplitsBins method), + // we exploit this structure to calculate aggregates for bins and then use these aggregates + // to calculate information gain for each split. + // + // *Aggregation over partitions* + // Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + // the number of splits in advance. 
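The bin lookup introduced in PATCH 41 above can be exercised on its own. A minimal sketch under the assumption that each bin is the half-open interval (lowSplit.threshold, highSplit.threshold]; the (low, high) pairs below are stand-ins for the patch's Bin objects:

```scala
// Standalone sketch of the binary search from PATCH 41 over bins represented
// as (lowThreshold, highThreshold] intervals.
object BinarySearchBinSketch {
  def findBin(bins: Array[(Double, Double)], feature: Double): Int = {
    var left = 0
    var right = bins.length - 1
    while (left <= right) {
      val mid = left + (right - left) / 2
      val (low, high) = bins(mid)
      if (low < feature && feature <= high) {
        return mid
      } else if (low >= feature) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    -1 // no bin found; the caller in the patch raises an error in this case
  }

  def main(args: Array[String]): Unit = {
    // Dummy low/high thresholds play the role of the patch's sentinel splits.
    val bins = Array((Double.MinValue, 1.0), (1.0, 2.0), (2.0, Double.MaxValue))
    assert(findBin(bins, 1.5) == 1)
    assert(findBin(bins, -3.0) == 0)
  }
}
```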
Thus, we store the aggregates (at the appropriate + // indices) in a single array for all bins and rely upon the RDD aggregate method to + // drastically reduce the communication overhead. + + // Implementation below + // Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt logDebug("numNodes = " + numNodes) @@ -294,6 +322,7 @@ object DecisionTree extends Serializable with Logging { return false } + // Apply each filter and check sample validity. Return false when invalid condition found. for (filter <- parentFilters) { val features = labeledPoint.features val featureIndex = filter.split.feature @@ -316,12 +345,13 @@ object DecisionTree extends Serializable with Logging { } } + + //Return true when the sample is valid for all filters true } - // TODO: Unit test this /** - * Finds the right bin for the given feature + * Find bin for one feature */ def findBin( featureIndex: Int, @@ -332,9 +362,12 @@ object DecisionTree extends Serializable with Logging { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) + /** + * Binary search helper method for continuous feature + */ def binarySearchForBins(): Int = { var left = 0 - var right = binForFeatures.length-1 + var right = binForFeatures.length - 1 while (left <= right) { val mid = left + (right - left) / 2 val bin = binForFeatures(mid) @@ -353,13 +386,10 @@ object DecisionTree extends Serializable with Logging { -1 } - if (isFeatureContinuous){ - val binIndex = binarySearchForBins() - if (binIndex == -1){ - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex - } else { + /** + * Sequential search helper method to find bin for categorical feature + */ + def sequentialBinSearchForCategoricalFeature() : Int = { val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) var binIndex = 0 while (binIndex < numCategoricalBins) { @@ -371,26 +401,40 @@ object DecisionTree extends Serializable with Logging { } binIndex += 1 } - throw new UnknownError("no bin was found for categorical variable.") - + -1 } + if (isFeatureContinuous){ + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = sequentialBinSearchForCategoricalFeature() + if (binIndex == -1){ + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } } /** - * Finds bins for all nodes (and all features) at a given level k features, - * l nodes (level = log2(l)). - * Storage label, b11, b12, b13, .., b1k, - * b21, b22, .. , b2k, - * bl1, bl2, .. , blk - * Denotes invalid sample for tree by noting bin for feature 1 as -1 + * Finds bins for all nodes (and all features) at a given level. + * For l nodes, k features the storage is as follows: + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk + * where b_ij is an integer between 0 and numBins - 1. 
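To make the flattened per-sample layout documented here concrete, a small self-contained sketch of the index arithmetic: one label slot followed by numFeatures bin indices per node (all sizes illustrative):

```scala
// Sketch of the per-sample layout from findBinsForLevel: arr(0) is the label,
// and the bin index of (node j, feature i) lives at 1 + numFeatures * j + i.
object LevelLayoutSketch {
  def slot(numFeatures: Int, nodeIndex: Int, featureIndex: Int): Int =
    1 + numFeatures * nodeIndex + featureIndex

  def main(args: Array[String]): Unit = {
    val numFeatures = 2
    val numNodes = 4 // all nodes at one tree level
    val arr = new Array[Double](1 + numFeatures * numNodes)
    arr(0) = 1.0 // label
    arr(slot(numFeatures, nodeIndex = 2, featureIndex = 1)) = 7.0 // bin index 7
    assert(arr(1 + 2 * 2 + 1) == 7.0)
  }
}
```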
+ * Invalid sample is denoted by noting bin for feature 1 as -1 */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // calculating bin index and label per feature per node val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label - for (nodeIndex <- 0 until numNodes) { + var nodeIndex = 0 + while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) @@ -406,17 +450,15 @@ object DecisionTree extends Serializable with Logging { featureIndex += 1 } } + nodeIndex += 1 } arr } /** - * Performs a sequential aggregation over a partition for classification. - * - * for p bins, k features, l nodes (level = log2(l)) storage is of the form: - * b111_left_count,b111_right_count, .... , .. - * .. bpk1_left_count, bpk1_right_count, .... , .. - * .. bpkl_left_count, bpkl_right_count + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. * * @param agg Array[Double] storing aggregate calculation of size * 2*numSplits*numFeatures*numNodes for classification @@ -425,32 +467,38 @@ object DecisionTree extends Serializable with Logging { * for classification */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - for (nodeIndex <- 0 until numNodes) { + // Iterating over all nodes + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Checking whether the instance was valid for this nodeIndex val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { + // Actual class label val label = arr(0) - for (featureIndex <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Finding the bin index for this feature val arrShift = 1 + numFeatures * nodeIndex - val aggShift = 2 * numBins * numFeatures * nodeIndex val arrIndex = arrShift + featureIndex + // Updating the left or right count for one bin + val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label match { case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 } + featureIndex += 1 } } + nodeIndex += 1 } } /** - * Performs a sequential aggregation over a partition for regression. - * - * for p bins, k features, l nodes (level = log2(l)) storage is of the form: - * b111_count,b111_sum, b111_sum_squares .... , .. - * .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , .. - * .. bpkl_count, bpkl_sum, bpkl_sum_squares + * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. 
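The index arithmetic in classificationBinSeqOp above is easier to see in isolation. A self-contained sketch, with all sizes illustrative: each bin owns two adjacent counters, one per class label, grouped per feature and then per node:

```scala
// Sketch of the classification aggregate layout: 2 counters (label 0 /
// label 1) per bin, per feature, per node, matching the aggShift / aggIndex
// arithmetic in classificationBinSeqOp.
object ClassificationAggSketch {
  def aggIndex(numBins: Int, numFeatures: Int,
      nodeIndex: Int, featureIndex: Int, binIndex: Int): Int = {
    val aggShift = 2 * numBins * numFeatures * nodeIndex
    aggShift + 2 * featureIndex * numBins + binIndex * 2
  }

  def main(args: Array[String]): Unit = {
    val (numBins, numFeatures, numNodes) = (4, 3, 2)
    val agg = new Array[Double](2 * numBins * numFeatures * numNodes)
    // A sample with label 1.0 that fell into bin 2 of feature 1 at node 0:
    val i = aggIndex(numBins, numFeatures, nodeIndex = 0, featureIndex = 1, binIndex = 2)
    agg(i + 1) += 1 // label-1 counter; agg(i) would be the label-0 counter
    assert(agg(2 * 1 * 4 + 2 * 2 + 1) == 1.0)
  }
}
```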
* * @param agg Array[Double] storing aggregate calculation of size * 3*numSplits*numFeatures*numNodes for classification @@ -459,37 +507,37 @@ object DecisionTree extends Serializable with Logging { * 3*numSplits*numFeatures*numNodes for regression */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { - for (nodeIndex <- 0 until numNodes) { + // Iterating over all nodes + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Checking whether the instance was valid for this nodeIndex val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { + // Actual class label val label = arr(0) - for (feature <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Finding the bin index for this feature val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // updating count, sum, sum^2 for one bin val aggShift = 3 * numBins * numFeatures * nodeIndex - val arrIndex = arrShift + feature - val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3 - //count, sum, sum^2 + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + // increment featureIndex + featureIndex += 1 } } + nodeIndex += 1 } } /** * Performs a sequential aggregation over a partition. - * for p bins, k features, l nodes (level = log2(l)) storage is of the form: - * b111_left_count,b111_right_count, .... , .... - * bpk1_left_count, bpk1_right_count, .... , ...., bpkl_left_count, bpkl_right_count - * @param agg Array[Double] storing aggregate calculation of size - * 2*numSplits*numFeatures*numNodes for classification and - * 3*numSplits*numFeatures*numNodes for regression - * @param arr Array[Double] of size 1+(numFeatures*numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2*numSplits*numFeatures*numNodes for classification and - * 3*numSplits*numFeatures*numNodes for regression */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { @@ -499,6 +547,7 @@ object DecisionTree extends Serializable with Logging { agg } + // Calculating bin aggregate length for classification or regression val binAggregateLength = strategy.algo match { case Classification => 2*numBins * numFeatures * numNodes case Regression => 3*numBins * numFeatures * numNodes @@ -512,27 +561,17 @@ object DecisionTree extends Serializable with Logging { * @return Combined aggregate from agg1 and agg2 */ def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => { - val combinedAggregate = new Array[Double](binAggregateLength) - for (index <- 0 until binAggregateLength){ - combinedAggregate(index) = agg1(index) + agg2(index) - } - combinedAggregate - } - case Regression => { - val combinedAggregate = new Array[Double](binAggregateLength) - for (index <- 0 until binAggregateLength){ - combinedAggregate(index) = agg1(index) + agg2(index) - } - combinedAggregate - } + var index = 0 + val combinedAggregate = new Array[Double](binAggregateLength) + while (index < binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + index += 1 } + combinedAggregate } - logDebug("input = " + input.count) + // find feature bins for all nodes at a level val binMappedRDD = 
input.map(x => findBinsForLevel(x)) - logDebug("binMappedRDD.count = " + binMappedRDD.count) // calculate bin aggregates val binAggregates = { @@ -541,7 +580,7 @@ object DecisionTree extends Serializable with Logging { logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits + * Calculates the information gain for all splits based upon left/right split aggregates * @param leftNodeAgg left node aggregates * @param featureIndex feature index * @param splitIndex split index @@ -572,6 +611,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { + // Calculating impurity for root node strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) } } @@ -614,6 +654,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { + // Calculating impurity for root node val count = leftCount + rightCount val sum = leftSum + rightSum val sumSquares = leftSumSquares + rightSumSquares @@ -623,11 +664,11 @@ object DecisionTree extends Serializable with Logging { if (leftCount == 0) { return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum/rightCount) + rightSum / rightCount) } if (rightCount == 0) { return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum/leftCount) + Double.MinValue, leftSum / leftCount) } val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) @@ -644,16 +685,16 @@ object DecisionTree extends Serializable with Logging { } } - val predict = (leftSum + rightSum)/(leftCount + rightCount) + val predict = (leftSum + rightSum) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } } - /** + /** * Extracts left and right split aggregates - * @param binData Array[Double] of size 2*numFeatures*numSplits + * @param binData Array[Double] of size 2*numFeatures*numSplits * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ @@ -663,58 +704,90 @@ object DecisionTree extends Serializable with Logging { strategy.algo match { case Classification => { - + // Initializing left and right split aggregates val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - for (featureIndex <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex val shift = 2*featureIndex*numBins + + // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) + + // right node aggregate for the highest split rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1))) rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1) - for (splitIndex <- 1 until numBins - 1) { - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + + + // Iterating over all splits + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) - 
leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2*splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = binData(shift + (2 *(numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + + splitIndex += 1 } + featureIndex += 1 } (leftNodeAgg, rightNodeAgg) } case Regression => { - + // Initializing left and right split aggregates val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - for (featureIndex <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex val shift = 3*featureIndex*numBins + // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) leftNodeAgg(featureIndex)(2) = binData(shift + 2) + + // right node aggregate for the highest split rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1))) rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1) rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2) - for (splitIndex <- 1 until numBins - 1) { + + // Iterating over all splits + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split leftNodeAgg(featureIndex)(3 * splitIndex) - = binData(shift + 3*splitIndex) + + = binData(shift + 3 * splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) leftNodeAgg(featureIndex)(3 * splitIndex + 1) - = binData(shift + 3*splitIndex + 1) + + = binData(shift + 3 * splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) leftNodeAgg(featureIndex)(3 * splitIndex + 2) - = binData(shift + 3*splitIndex + 2) + + = binData(shift + 3 * splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) @@ -724,13 +797,19 @@ object DecisionTree extends Serializable with Logging { rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + + splitIndex += 1 } + featureIndex += 1 } (leftNodeAgg, rightNodeAgg) } } } + /** + * Calculates information gain for all nodes splits + */ def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], @@ -749,10 +828,10 @@ object DecisionTree extends Serializable with Logging { } /** - * Find the best 
split for a node given bin aggregate data + * Find the best split for a node * @param binData Array[Double] of size 2*numSplits*numFeatures * @param nodeImpurity impurity of the top node - * @return + * @return tuple of split and information gain */ def binsToBestSplit( binData: Array[Double], @@ -760,23 +839,33 @@ object DecisionTree extends Serializable with Logging { : (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) + + //extract left right node aggregates val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + + // calculate gains for all splits val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) val (bestFeatureIndex,bestSplitIndex, gainStats) = { - var bestFeatureIndex = 0 - var bestSplitIndex = 0 // Initialization with infeasible values + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1){ + // Iterating over features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Iterating over all splits + var splitIndex = 0 + while (splitIndex < numBins - 1){ val gainStats = gains(featureIndex)(splitIndex) if(gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex } + splitIndex += 1 } + featureIndex += 1 } (bestFeatureIndex,bestSplitIndex,bestGainStats) } @@ -786,8 +875,9 @@ object DecisionTree extends Serializable with Logging { (splits(bestFeatureIndex)(bestSplitIndex),gainStats) } - // Calculate best splits for all nodes at a given level - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + /** + * get bin data for one node + */ def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => { @@ -803,14 +893,21 @@ object DecisionTree extends Serializable with Logging { } } - for (node <- 0 until numNodes){ + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var node = 0 + while (node < numNodes){ val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + node += 1 } + + //Return best splits bestSplits } @@ -865,12 +962,15 @@ object DecisionTree extends Serializable with Logging { val splits = Array.ofDim[Split](numFeatures, numBins-1) val bins = Array.ofDim[Bin](numFeatures, numBins) - //Find all splits - for (featureIndex <- 0 until numFeatures){ + // Find all splits + + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures){ + // Checking whether the feature is continuous val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { @@ -880,15 +980,16 @@ object DecisionTree extends Serializable with Logging { } } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - 
require(maxFeatureValue < numBins, "number of categories should be less than number " + "of bins") + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centriod of their corresponding labels. val centriodForCategories = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) - // Checking for missing categorical variables + // Checking for missing categorical variables and putting them last in the sorted list val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() for (i <- 0 until maxFeatureValue) { if (centriodForCategories.contains(i)) { @@ -898,6 +999,7 @@ object DecisionTree extends Serializable with Logging { } } + //bins sorted by centriods val categoriesSortedByCentriod = fullCentriodForCategories.toList.sortBy{_._2} @@ -922,10 +1024,12 @@ object DecisionTree extends Serializable with Logging { } } } + featureIndex += 1 } // Find all bins - for (featureIndex <- 0 until numFeatures){ + featureIndex = 0 + while (featureIndex < numFeatures){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { // bins for categorical variables are already assigned bins(featureIndex)(0) @@ -940,6 +1044,7 @@ object DecisionTree extends Serializable with Logging { = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } + featureIndex += 1 } (splits,bins) } From 24500c541de04febc79dd7508a2d84cb764856c1 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 22 Mar 2014 22:38:47 -0700 Subject: [PATCH 43/48] minor style updates --- .../spark/mllib/tree/DecisionTree.scala | 518 ++++++++---------- 1 file changed, 237 insertions(+), 281 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3d5eb0fcf263b..db91d90a7eaa4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,18 +19,16 @@ package org.apache.spark.mllib.tree import scala.util.control.Breaks._ +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.tree.model._ -import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} -import java.util.Random +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom /** @@ -38,49 +36,50 @@ import org.apache.spark.util.random.XORShiftRandom * supports both continuous and categorical features. 
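The centroid ordering described in the comments above, where each category is mapped to the mean label of its samples and categories are then sorted by that value, can be sketched without Spark; the data below is purely illustrative:

```scala
// Sketch of the centroid-based ordering of categorical bins: categories are
// ranked by the mean label of their samples so they can be treated like an
// ordered feature when splits are enumerated.
object CategoryOrderingSketch {
  def main(args: Array[String]): Unit = {
    // (categoryValue, label) pairs for one categorical feature.
    val samples = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (2.0, 0.0))
    val centroidForCategories = samples
      .groupBy(_._1)
      .mapValues(xs => xs.map(_._2).sum / xs.length)
    val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
    // category 1 (mean 0.0) < category 2 (mean 0.5) < category 0 (mean 1.0)
    assert(categoriesSortedByCentroid.map(_._1) == List(1.0, 2.0, 0.0))
  }
}
```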
* @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), - * depth of the tree, quantile calculation strategy, etc. - */ + * categorical), depth of the tree, quantile calculation strategy, etc. + */ class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree * @return a DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes + // Cache input RDD for speedup during multiple passes. input.cache() logDebug("algo = " + strategy.algo) - // Finding the splits and the corresponding bins (interval between the splits) using a sample + // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) - // Noting numBins for the input data + // Set number of bins for the input data. strategy.numBins = bins(0).length - // The depth of the decision tree + // depth of the decision tree val maxDepth = strategy.maxDepth - // The max number of nodes possible given the depth of the tree + // the max number of nodes possible given the depth of the tree val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 - // Initalizing an array to hold filters applied to points for each node + // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) - // The filter at the top node is an empty list + // The filter at the top node is an empty list. filters(0) = List() - // Initializing an array to hold parent impurity calculations for each node + // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) - // Dummy value for top node (updated during first split calculation) + // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // The main-idea here is to perform level-wise training of the decision tree nodes thus - // reducing the passes over the data from l to log2(l) where l is the total number of nodes. - // Each data sample is checked for validity w.r.t to each node at a given level -- i.e., - // the sample is only used for the split calculation at the node if the sampled would have - // still survived the filters of the parent nodes. + + /* + * The main idea here is to perform level-wise training of the decision tree nodes thus + * reducing the passes over the data from l to log2(l) where l is the total number of nodes. + * Each data sample is checked for validity w.r.t. each node at a given level -- i.e., + * the sample is only used for the split calculation at the node if the sample would have + * still survived the filters of the parent nodes. + */ // TODO: Convert for loop to while loop breakable { for (level <- 0 until maxDepth) { logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") - // Find best split for all nodes at a level + // Find best split for all nodes at a level.
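The level-wise loop above, and the extractNodeInfo / extractInfoForLowerLevels helpers that follow, rely on breadth-first node numbering with closed-form index arithmetic. A small self-contained sketch of just that arithmetic, using the formulas from the patch:

```scala
// Sketch of the flat, breadth-first node numbering: the node at position
// `index` within `level` sits at 2^level - 1 + index, and its children sit at
// 2^(level+1) - 1 + 2 * index and the slot right after it.
object NodeIndexSketch {
  def nodeIndex(level: Int, index: Int): Int =
    math.pow(2, level).toInt - 1 + index

  def leftChildIndex(level: Int, index: Int): Int =
    math.pow(2, level + 1).toInt - 1 + 2 * index

  def main(args: Array[String]): Unit = {
    assert(nodeIndex(0, 0) == 0) // root
    assert(nodeIndex(2, 3) == 6) // last node at level 2
    assert(leftChildIndex(1, 1) == 5) // children of node 2 are 5 and 6
    assert(leftChildIndex(1, 1) + 1 == 6)
  }
}
```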
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - // Extract info for nodes at the current level + // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) - // Extract info for nodes at the next lower level + // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) - } require(scala.math.pow(2, level) == splitsStatsForLevel.length) - // Check whether all the nodes at the current level at leaves + // Check whether all the nodes at the current level are leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) - if (allLeaf) break //no more tree construction - + if (allLeaf) break // no more tree construction } } - // Initialize the top or root node of the tree + // Initialize the top or root node of the tree. val topNode = nodes(0) - // Build the full tree using the node info calculated in the level-wise best split calculations + // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) - // Return a decision tree model - return new DecisionTreeModel(topNode, strategy.algo) + new DecisionTreeModel(topNode, strategy.algo) } /** @@ -128,9 +124,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, - nodes: Array[Node]) - : Unit = { - + nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = scala.math.pow(2, level).toInt - 1 + index @@ -149,13 +143,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], - filters: Array[List[Filter]]) - : Unit = { - + filters: Array[List[Filter]]): Unit = { // 0 corresponds to the left child node and 1 corresponds to the right child node. // TODO: Convert to while loop for (i <- 0 to 1) { - // Calculating the index of the node from the node level and the index at the current level + // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { @@ -184,7 +176,7 @@ object DecisionTree extends Serializable with Logging { * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algoritm (classification, regression, etc.), feature type (continuous, + * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc.
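For readers following the public API rather than the internals, a hypothetical end-to-end call of the four-argument train overload documented in this hunk; it assumes a live SparkContext named sc and this patch's classes on the classpath, and the two-point dataset is purely illustrative:

```scala
// Hypothetical usage sketch of DecisionTree.train(input, algo, impurity,
// maxDepth) as declared in this patch; not part of the patch itself.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini

object TrainSketch {
  def run(sc: SparkContext): Unit = {
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Array(0.0, 1.0)),
      LabeledPoint(1.0, Array(1.0, 0.0))))
    val model = DecisionTree.train(data, Classification, Gini, maxDepth = 3)
    println(model.predict(Array(1.0, 0.0)))
  }
}
```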
* @return a DecisionTreeModel that can be used for prediction */ @@ -196,7 +188,7 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model over an RDD * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as * training data - * @param algo algo classification or regression + * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation * @param maxDepth maxDepth maximum depth of the tree * @return a DecisionTreeModel that can be used for prediction @@ -205,8 +197,7 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], algo: Algo, impurity: Impurity, - maxDepth: Int) - : DecisionTreeModel = { + maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } @@ -235,8 +226,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]) - : DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) @@ -264,44 +254,42 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]) - : Array[(Split, InformationGainStats)] = { - - - // The high-level description for the best split optimizations are noted here. - // - // *Level-wise training* - // We perform bin calculations for all nodes at the given level to avoid making multiple - // passes over the data. Thus, for a slightly increased computation and storage cost we save - // several iterations over the data especially at higher levels of the decision tree. - // - // *Bin-wise computation* - // We use a bin-wise best split computation strategy instead of a straightforward best split - // computation strategy. Instead of analyzing each sample for contribution to the left/right - // child node impurity of every split, we first categorize each feature of a sample into a - // bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, - // is ordered (read ordering for categorical variables in the findSplitsBins method), - // we exploit this structure to calculate aggregates for bins and then use these aggregates - // to calculate information gain for each split. - // - // *Aggregation over partitions* - // Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know - // the number of splits in advance. Thus, we store the aggregates (at the appropriate - // indices) in a single array for all bins and rely upon the RDD aggregate method to - // drastically reduce the communication overhead. - - // Implementation below - - // Common calculations for multiple nested methods + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + + /* + * The high-level description for the best split optimizations are noted here. + * + * *Level-wise training* + * We perform bin calculations for all nodes at the given level to avoid making multiple + * passes over the data. Thus, for a slightly increased computation and storage cost we save + * several iterations over the data especially at higher levels of the decision tree. 
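To put a number on the level-wise saving described above: with maxDepth = 10 the tree can hold up to 2^10 - 1 = 1023 nodes, yet level-wise training scans the data once per level, so 10 passes replace the roughly one-pass-per-node cost of naive node-at-a-time training.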
+ * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. Each bin is an interval between a low and high split. Since each split, and thus each + * bin, is ordered (see the ordering of categorical variables in the findSplitsBins method), + * we exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample + // Find the number of features by looking at the first sample. val numFeatures = input.first().features.length logDebug("numFeatures = " + numFeatures) val numBins = strategy.numBins logDebug("numBins = " + numBins) - /** Find the filters used before reaching the current code */ + /** Find the filters used before reaching the current code. */ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() } else { @@ -312,13 +300,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Find whether the sample is valid input for the current node. In other words, - * does it pass through all the filters for the current node. - */ + * Find whether the sample is valid input for the current node, i.e., whether it passes through + * all the filters for the current node. + */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { - - // Leaf - if ((level > 0) & (parentFilters.length == 0) ){ + // leaf + if ((level > 0) & (parentFilters.length == 0)) { return false } @@ -331,39 +318,37 @@ object DecisionTree extends Serializable with Logging { val categories = filter.split.categories val isFeatureContinuous = filter.split.featureType == Continuous val feature = features(featureIndex) - if (isFeatureContinuous){ + if (isFeatureContinuous) { comparison match { - case(-1) => if (feature > threshold) return false - case(1) => if (feature <= threshold) return false + case -1 => if (feature > threshold) return false + case 1 => if (feature <= threshold) return false } } else { val containsFeature = categories.contains(feature) comparison match { - case(-1) => if (!containsFeature) return false - case(1) => if (containsFeature) return false + case -1 => if (!containsFeature) return false + case 1 => if (containsFeature) return false } } } - //Return true when the sample is valid for all filters + // Return true when the sample is valid for all filters. true } /** - * Find bin for one feature + * Find bin for one feature.
*/ def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean) - : Int = { - + isFeatureContinuous: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) /** - * Binary search helper method for continuous feature + * Binary search helper method for continuous feature. */ def binarySearchForBins(): Int = { var left = 0 @@ -376,7 +361,7 @@ object DecisionTree extends Serializable with Logging { if ((lowThreshold < feature) & (highThreshold >= feature)){ return mid } - else if ((lowThreshold >= feature)){ + else if (lowThreshold >= feature) { right = mid - 1 } else { @@ -387,9 +372,9 @@ object DecisionTree extends Serializable with Logging { } /** - * Sequential search helper method to find bin for categorical feature + * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeature() : Int = { + def sequentialBinSearchForCategoricalFeature(): Int = { val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) var binIndex = 0 while (binIndex < numCategoricalBins) { @@ -404,7 +389,7 @@ object DecisionTree extends Serializable with Logging { -1 } - if (isFeatureContinuous){ + if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1){ @@ -424,27 +409,26 @@ object DecisionTree extends Serializable with Logging { /** * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, * where b_ij is an integer between 0 and numBins - 1. - * Invalid sample is denoted by noting bin for feature 1 as -1 + * Invalid sample is denoted by noting bin for feature 1 as -1. */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { - - // calculating bin index and label per feature per node + // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node + // Find out whether the sample qualifies for the particular node. val sampleValid = isSampleValid(parentFilters, labeledPoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { - // marking one bin as -1 is sufficient + // Mark one bin as -1 is sufficient. arr(shift) = InvalidBinIndex } else { var featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) featureIndex += 1 @@ -461,33 +445,33 @@ object DecisionTree extends Serializable with Logging { * incremented based upon whether the feature is classified as 0 or 1. 
* * @param agg Array[Double] storing aggregate calculation of size - * 2*numSplits*numFeatures*numNodes for classification - * @param arr Array[Double] of size 1+(numFeatures*numNodes) - * @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes - * for classification + * 2 * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures * numNodes for classification */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterating over all nodes + // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { - // Checking whether the instance was valid for this nodeIndex + // Check whether the instance was valid for this nodeIndex. val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - // Actual class label + // actual class label val label = arr(0) - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Finding the bin index for this feature + // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex - // Updating the left or right count for one bin + // Update the left or right count for one bin. val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label match { - case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 - case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 + case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 } featureIndex += 1 } @@ -501,34 +485,33 @@ object DecisionTree extends Serializable with Logging { * the count, sum, sum of squares of one of the p bins is incremented. * * @param agg Array[Double] storing aggregate calculation of size - * 3*numSplits*numFeatures*numNodes for classification - * @param arr Array[Double] of size 1+(numFeatures*numNodes) + * 3 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * 3*numSplits*numFeatures*numNodes for regression + * 3 * numSplits * numFeatures * numNodes for regression */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterating over all nodes + // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { - // Checking whether the instance was valid for this nodeIndex + // Check whether the instance was valid for this nodeIndex. val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - // Actual class label + // actual class label val label = arr(0) - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Finding the bin index for this feature + // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex - // updating count, sum, sum^2 for one bin + // Update count, sum, and sum^2 for one bin. 
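The three counters per bin that regressionBinSeqOp maintains are exactly the sufficient statistics the Variance impurity needs. A self-contained sketch, assuming the impurity is the population variance E[y^2] - E[y]^2 computed from (count, sum, sumSquares):

```scala
// Sketch: count, sum and sum-of-squares per bin are sufficient statistics
// for the variance impurity, so bins can be merged by plain addition.
object RegressionStatsSketch {
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    val mean = sum / count
    sumSquares / count - mean * mean
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(1.0, 2.0, 3.0)
    val (count, sum, sumSq) = labels.foldLeft((0.0, 0.0, 0.0)) {
      case ((c, s, ss), y) => (c + 1, s + y, ss + y * y)
    }
    assert(math.abs(variance(count, sum, sumSq) - 2.0 / 3.0) < 1e-12)
  }
}
```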
val aggShift = 3 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label agg(aggIndex + 2) = agg(aggIndex + 2) + label*label - // increment featureIndex featureIndex += 1 } } @@ -547,15 +530,15 @@ object DecisionTree extends Serializable with Logging { agg } - // Calculating bin aggregate length for classification or regression + // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { - case Classification => 2*numBins * numFeatures * numNodes - case Regression => 3*numBins * numFeatures * numNodes + case Classification => 2 * numBins * numFeatures * numNodes + case Regression => 3 * numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) /** - * Combines the aggregates from partitions + * Combines the aggregates from partitions. * @param agg1 Array containing aggregates from one or more partitions * @param agg2 Array containing aggregates from one or more partitions * @return Combined aggregate from agg1 and agg2 @@ -563,24 +546,24 @@ object DecisionTree extends Serializable with Logging { def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { var index = 0 val combinedAggregate = new Array[Double](binAggregateLength) - while (index < binAggregateLength){ + while (index < binAggregateLength) { combinedAggregate(index) = agg1(index) + agg2(index) index += 1 } combinedAggregate } - // find feature bins for all nodes at a level + // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - // calculate bin aggregates + // Calculate bin aggregates. val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits based upon left/right split aggregates + * Calculates the information gain for all splits based upon left/right split aggregates. * @param leftNodeAgg left node aggregates * @param featureIndex feature index * @param splitIndex split index @@ -593,12 +576,9 @@ object DecisionTree extends Serializable with Logging { featureIndex: Int, splitIndex: Int, rightNodeAgg: Array[Array[Double]], - topImpurity: Double) - : InformationGainStats = { - + topImpurity: Double): InformationGainStats = { strategy.algo match { - case Classification => { - + case Classification => val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) val leftCount = left0Count + left1Count @@ -611,7 +591,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { - // Calculating impurity for root node + // Calculate impurity for root node. 
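The gain that calculateGainForSplit assembles from these pieces is the parent impurity minus the impurity of each child weighted by its share of the samples. A self-contained sketch using two-class Gini impurity as the criterion (the counts are illustrative):

```scala
// Sketch of the information gain formula: parent impurity minus the
// sample-weighted impurities of the left and right children.
object InfoGainSketch {
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    val (f0, f1) = (c0 / total, c1 / total)
    1.0 - f0 * f0 - f1 * f1
  }

  def gain(left: (Double, Double), right: (Double, Double)): Double = {
    val leftCount = left._1 + left._2
    val rightCount = right._1 + right._2
    val total = leftCount + rightCount
    val top = gini(left._1 + right._1, left._2 + right._2)
    top - (leftCount / total) * gini(left._1, left._2) -
      (rightCount / total) * gini(right._1, right._2)
  }

  def main(args: Array[String]): Unit = {
    // A perfect split of a balanced node: gain equals the parent Gini (0.5).
    assert(math.abs(gain((10.0, 0.0), (0.0, 10.0)) - 0.5) < 1e-12)
  }
}
```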
strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) } } @@ -640,8 +620,7 @@ object DecisionTree extends Serializable with Logging { val predict = (left1Count + right1Count) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - } - case Regression => { + case Regression => val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) @@ -654,7 +633,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { - // Calculating impurity for root node + // Calculate impurity for root node. val count = leftCount + rightCount val sum = leftSum + rightSum val sumSquares = leftSumSquares + rightSumSquares @@ -687,31 +666,27 @@ object DecisionTree extends Serializable with Logging { val predict = (leftSum + rightSum) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - - } } } /** - * Extracts left and right split aggregates + * Extracts left and right split aggregates. * @param binData Array[Double] of size 2*numFeatures*numSplits * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( - binData: Array[Double]) - : (Array[Array[Double]], Array[Array[Double]]) = { - + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { strategy.algo match { - case Classification => { - // Initializing left and right split aggregates + case Classification => + // Initialize left and right split aggregates. val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { // shift for this featureIndex - val shift = 2*featureIndex*numBins + val shift = 2 * featureIndex * numBins // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) @@ -723,7 +698,7 @@ object DecisionTree extends Serializable with Logging { rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1) - // Iterating over all splits + // Iterate over all splits. var splitIndex = 1 while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a @@ -747,17 +722,15 @@ object DecisionTree extends Serializable with Logging { featureIndex += 1 } (leftNodeAgg, rightNodeAgg) - } - - case Regression => { - // Initializing left and right split aggregates + case Regression => + // Initialize left and right split aggregates. val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - // Iterating over all features + // Iterate over all features. 
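The aggregate extraction that follows is, in effect, a prefix/suffix sum over bins: the left aggregate of each candidate split accumulates all lower bins and the right aggregate all higher bins. A compact sketch with a single counter per bin (the patch keeps two or three counters per bin, but the recurrence is the same):

```scala
// Sketch of the left/right split aggregates as prefix and suffix sums over
// per-bin counts; split i separates bins 0..i from bins i+1..last.
object SplitAggSketch {
  def main(args: Array[String]): Unit = {
    val binCounts = Array(4.0, 1.0, 3.0, 2.0) // one counter per bin
    val numSplits = binCounts.length - 1
    val left = binCounts.scanLeft(0.0)(_ + _).slice(1, numSplits + 1)
    val right = binCounts.scanRight(0.0)(_ + _).slice(1, numSplits + 1)
    assert(left.sameElements(Array(4.0, 5.0, 8.0)))  // bins left of each split
    assert(right.sameElements(Array(6.0, 5.0, 2.0))) // bins right of each split
  }
}
```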
var featureIndex = 0 while (featureIndex < numFeatures) { // shift for this featureIndex - val shift = 3*featureIndex*numBins + val shift = 3 * featureIndex * numBins // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) @@ -771,55 +744,49 @@ object DecisionTree extends Serializable with Logging { rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2) - // Iterating over all splits + // Iterate over all splits. var splitIndex = 1 while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) - = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) - = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) - = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) // calculating right node aggregate for a split as a sum of right node aggregate of a // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) - = binData(shift + (3 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) - = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) - = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) splitIndex += 1 } featureIndex += 1 } (leftNodeAgg, rightNodeAgg) - } } } /** - * Calculates information gain for all nodes splits + * Calculates information gain for all nodes splits. 
*/ def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], - nodeImpurity: Double) - : Array[Array[InformationGainStats]] = { - + nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins -1) { + for (splitIndex <- 0 until numBins - 1) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) } @@ -827,38 +794,37 @@ object DecisionTree extends Serializable with Logging { gains } - /** - * Find the best split for a node - * @param binData Array[Double] of size 2*numSplits*numFeatures + /** + * Find the best split for a node. + * @param binData Array[Double] of size 2 * numSplits * numFeatures * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ def binsToBestSplit( binData: Array[Double], - nodeImpurity: Double) - : (Split, InformationGainStats) = { + nodeImpurity: Double): (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) - //extract left right node aggregates + // Extract left right node aggregates. val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) - // calculate gains for all splits + // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) val (bestFeatureIndex,bestSplitIndex, gainStats) = { - // Initialization with infeasible values + // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) - // Iterating over features + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Iterating over all splits + // Iterate over all splits. var splitIndex = 0 - while (splitIndex < numBins - 1){ - val gainStats = gains(featureIndex)(splitIndex) - if(gainStats.gain > bestGainStats.gain) { + while (splitIndex < numBins - 1) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex @@ -867,29 +833,28 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - (bestFeatureIndex,bestSplitIndex,bestGainStats) + (bestFeatureIndex, bestSplitIndex, bestGainStats) } logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) - (splits(bestFeatureIndex)(bestSplitIndex),gainStats) + + (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } /** - * get bin data for one node + * Get bin data for one node. 
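binsToBestSplit above reduces to an argmax over the precomputed gain matrix once the aggregates are in place. A standalone sketch of that inner search, using a plain Double gain instead of InformationGainStats:

    // Find the (featureIndex, splitIndex) pair with the highest gain.
    def argmaxGain(gains: Array[Array[Double]]): (Int, Int, Double) = {
      var bestFeature = Int.MinValue // infeasible initial values, as in the patch
      var bestSplit = Int.MinValue
      var bestGain = Double.MinValue
      var featureIndex = 0
      while (featureIndex < gains.length) {
        var splitIndex = 0
        while (splitIndex < gains(featureIndex).length) {
          if (gains(featureIndex)(splitIndex) > bestGain) {
            bestGain = gains(featureIndex)(splitIndex)
            bestFeature = featureIndex
            bestSplit = splitIndex
          }
          splitIndex += 1
        }
        featureIndex += 1
      }
      (bestFeature, bestSplit, bestGain)
    }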
 */
  def getBinDataForNode(node: Int): Array[Double] = {
    strategy.algo match {
-      case Classification => {
+      case Classification =>
        val shift = 2 * node * numBins * numFeatures
        val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
        binsForNode
-      }
-      case Regression => {
+      case Regression =>
        val shift = 3 * node * numBins * numFeatures
        val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
        binsForNode
-      }
    }
  }

@@ -897,7 +862,7 @@ object DecisionTree extends Serializable with Logging {
    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
    // Iterating over all nodes at this level
    var node = 0
-    while (node < numNodes){
+    while (node < numNodes) {
      val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
      val binsForNode: Array[Double] = getBinDataForNode(node)
      logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
@@ -907,12 +872,9 @@ object DecisionTree extends Serializable with Logging {
      node += 1
    }

-    //Return best splits
    bestSplits
  }

-
-
  /**
   * Returns split and bins for decision tree calculation.
   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
   * parameters for construction the DecisionTree
   * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree
-   * .model.Split] of size (numFeatures,numSplits-1) and bins is an Array of [org.apache
-   * .spark.mllib.tree.model.Bin] of size (numFeatures,numSplits1)
+   * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
+   * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
   */
  def findSplitsBins(
      input: RDD[LabeledPoint],
-      strategy: Strategy)
-    : (Array[Array[Split]], Array[Array[Bin]]) = {
-
+      strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
    val count = input.count()

    // Find the number of features by looking at the first sample
@@ -937,15 +897,17 @@ object DecisionTree extends Serializable with Logging {
    val numBins = if (maxBins <= count) maxBins else count.toInt
    logDebug("numBins = " + numBins)

-    // I will also add a require statement ensuring #bins is always greater than the categories
-    // It's a limitation of the current implementation but a reasonable tradeoff since features
-    // with large number of categories get favored over continuous features.
-    if (strategy.categoricalFeaturesInfo.size > 0){
+    /*
+     * TODO: Add a require statement ensuring #bins is always greater than the categories.
+     * It's a limitation of the current implementation but a reasonable trade-off since features
+     * with a large number of categories get favored over continuous features.
+     */
+    if (strategy.categoricalFeaturesInfo.size > 0) {
      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
      require(numBins >= maxCategoriesForFeatures)
    }

-    // Calculate the number of sample for approximate quantile calculation
+    // Calculate the number of samples for approximate quantile calculation.
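getBinDataForNode above is pure offset arithmetic over one flat array: every node at the current level owns a contiguous slice of statsPerBin * numBins * numFeatures doubles, with two statistics per bin for classification and three for regression. A sketch of the same arithmetic with the multiplier made explicit:

    // A sketch of the slicing in getBinDataForNode; statsPerBin is 2 for
    // classification and 3 for regression.
    def sliceForNode(
        binAggregates: Array[Double],
        node: Int,
        numBins: Int,
        numFeatures: Int,
        statsPerBin: Int): Array[Double] = {
      val sliceSize = statsPerBin * numBins * numFeatures
      val shift = node * sliceSize
      binAggregates.slice(shift, shift + sliceSize)
    }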
val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 logDebug("fraction of data used for calculating quantiles = " + fraction) @@ -958,23 +920,23 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { - case Sort => { - val splits = Array.ofDim[Split](numFeatures, numBins-1) + case Sort => + val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) - // Find all splits + // Find all splits. - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures){ - // Checking whether the feature is continuous + // Check whether the feature is continuous. val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble/numBins + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - for (index <- 0 until numBins-1) { - val sampleIndex = (index + 1)*stride.toInt + for (index <- 0 until numBins - 1) { + val sampleIndex = (index + 1) * stride.toInt val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } @@ -984,87 +946,78 @@ object DecisionTree extends Serializable with Logging { "of bins") // For categorical variables, each bin is a category. The bins are sorted and they - // are ordered by calculating the centriod of their corresponding labels. - val centriodForCategories - = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) - .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) - - // Checking for missing categorical variables and putting them last in the sorted list - val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() + // are ordered by calculating the centroid of their corresponding labels. + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and putting them last in the sorted list. 
+          val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
          for (i <- 0 until maxFeatureValue) {
-            if (centriodForCategories.contains(i)) {
-              fullCentriodForCategories(i) = centriodForCategories(i)
+            if (centroidForCategories.contains(i)) {
+              fullCentroidForCategories(i) = centroidForCategories(i)
            } else {
-              fullCentriodForCategories(i) = Double.MaxValue
+              fullCentroidForCategories(i) = Double.MaxValue
            }
          }

-          //bins sorted by centriods
-          val categoriesSortedByCentriod
-          = fullCentriodForCategories.toList.sortBy{_._2}
+          // bins sorted by centroids
+          val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)

-          logDebug("centriod for categorical variable = " + categoriesSortedByCentriod)
+          logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)

          var categoriesForSplit = List[Double]()
-          categoriesSortedByCentriod.iterator.zipWithIndex foreach {
-            case((key, value), index) => {
+          categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+            case ((key, value), index) =>
              categoriesForSplit = key :: categoriesForSplit
              splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
                categoriesForSplit)
              bins(featureIndex)(index) = {
-                if(index == 0) {
+                if (index == 0) {
                  new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
                    splits(featureIndex)(0), Categorical, key)
-                }
-                else {
+                } else {
                  new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
                    Categorical, key)
                }
              }
-            }
          }
        }
        featureIndex += 1
      }

-      // Find all bins
+      // Find all bins.
      featureIndex = 0
-      while (featureIndex < numFeatures){
+      while (featureIndex < numFeatures) {
        val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-        if (isFeatureContinuous) { // bins for categorical variables are already assigned
-          bins(featureIndex)(0)
-          = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),
-          Continuous,Double.MinValue)
+        if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
+          bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+            splits(featureIndex)(0), Continuous, Double.MinValue)
          for (index <- 1 until numBins - 1){
            val bin = new Bin(splits(featureIndex)(index-1),
              splits(featureIndex)(index), Continuous, Double.MinValue)
            bins(featureIndex)(index) = bin
          }
-          bins(featureIndex)(numBins-1)
-          = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex,
-          Continuous), Continuous, Double.MinValue)
+          bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
+            new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
        }
        featureIndex += 1
      }
      (splits,bins)
-    }
-    case MinMax => {
+    case MinMax =>
      throw new UnsupportedOperationException("minmax not supported yet.")
-    }
-    case ApproxHist => {
+    case ApproxHist =>
      throw new UnsupportedOperationException("approximate histogram not supported yet.")
-    }
    }
  }

-
  val usage = """
    Usage: DecisionTreeRunner <master>[slices] --algo <Classification|Regression>
    --trainDataDir path --testDataDir path --maxDepth num
    [--impurity <Gini|Entropy|Variance>] [--maxBins num]
  """
-
  def main(args: Array[String]) {

    if (args.length < 2) {
@@ -1093,20 +1046,20 @@ object DecisionTree extends Serializable with Logging {
        sys.exit(1)
      }
    }
-    val options = nextOption(Map(),argList)
+    val options = nextOption(Map(), argList)
    logDebug(options.toString())

-    // Load training data
+    // Load training data.
    val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)

-    // Identify the type of algorithm
+    // Identify the type of algorithm.
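The categorical branch above orders bins by the centroid (mean label) of each category, so the split search only has to consider prefixes of one sorted list rather than arbitrary category subsets. A self-contained sketch of that ordering on plain collections (the names are illustrative; unseen categories sort last via Double.MaxValue, as in the code above):

    // Order the categories of one feature by the mean label of their samples.
    def categoriesByCentroid(
        samples: Seq[(Double, Double)], // (category value, label) pairs
        numCategories: Int): List[Double] = {
      val centroids = samples.groupBy(_._1).map { case (category, xs) =>
        category -> xs.map(_._2).sum / xs.size
      }
      (0 until numCategories)
        .map(c => c.toDouble -> centroids.getOrElse(c.toDouble, Double.MaxValue))
        .sortBy(_._2)
        .map(_._1)
        .toList
    }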
    val algoStr = options.get('algo).get.toString
    val algo = algoStr match {
      case "Classification" => Classification
      case "Regression" => Regression
    }

-    // Identify the type of impurity
+    // Identify the type of impurity.
    val impurityStr = options.getOrElse('impurity,
      if (algo == Classification) "Gini" else "Variance").toString
    val impurity = impurityStr match {
@@ -1115,22 +1068,22 @@ object DecisionTree extends Serializable with Logging {
      case "Variance" => Variance
    }

-    val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt
-    val maxBins = options.getOrElse('maxBins,"100").toString.toInt
+    val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
+    val maxBins = options.getOrElse('maxBins, "100").toString.toInt

    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
    val model = DecisionTree.train(trainData, strategy)

-    // Load test data
+    // Load test data.
    val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)

    // Measure algorithm accuracy
-    if (algo == Classification){
+    if (algo == Classification) {
      val accuracy = accuracyScore(model, testData)
      logDebug("accuracy = " + accuracy)
    }

-    if (algo == Regression){
+    if (algo == Regression) {
      val mse = meanSquaredError(model, testData)
      logDebug("mean square error = " + mse)
    }
@@ -1140,7 +1093,7 @@ object DecisionTree extends Serializable with Logging {

  /**
   * Load labeled data from a file. The data format used here is
-   * <L>, <f1> <f2> ...
+   * <L>, <f1> <f2> ...,
   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
   *
   * @param sc SparkContext
@@ -1157,12 +1110,12 @@ object DecisionTree extends Serializable with Logging {
    }
  }

-  // TODO: Port this method to a generic metrics package
+  // TODO: Port this method to a generic metrics package.
  /**
   * Calculates the classifier accuracy.
   */
  def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
-    threshold: Double = 0.5): Double = {
+      threshold: Double = 0.5): Double = {
    def predictedValue(features: Array[Double]) = {
      if (model.predict(features) < threshold) 0.0 else 1.0
    }
@@ -1175,9 +1128,12 @@ object DecisionTree extends Serializable with Logging {

  // TODO: Port this method to a generic metrics package
  /**
-   * Calculates the mean squared error for regression
+   * Calculates the mean squared error for regression.
   */
  def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
-    data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
+    data.map { y =>
+      val err = tree.predict(y.features) - y.label
+      err * err
+    }.mean()
  }
}

From f963ef5f806330433b2aaf585e8708ebc82db8ca Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 22 Mar 2014 23:34:47 -0700
Subject: [PATCH 44/48] making methods private

---
 .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index db91d90a7eaa4..834d617d85faf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -247,7 +247,7 @@ object DecisionTree extends Serializable with Logging {
   * @param bins possible bins for all features
   * @return array of splits with best splits for all nodes at a given level.
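Putting the runner logic above together, a minimal end-to-end sketch looks like the following; the local master, application name, and file path are placeholders, while the class and method names are the ones used in this series:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    object DecisionTreeUsageSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "DecisionTreeUsageSketch")
        // Expects the "<L>, <f1> <f2> ..." format parsed by loadLabeledData above.
        val trainData = DecisionTree.loadLabeledData(sc, "/tmp/train.csv")
        val strategy = new Strategy(Classification, Gini, maxDepth = 3, maxBins = 100)
        val model = DecisionTree.train(trainData, strategy)
        println("prediction for first point = " + model.predict(trainData.first().features))
        sc.stop()
      }
    }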
*/ - def findBestSplits( + private def findBestSplits( input: RDD[LabeledPoint], parentImpurities: Array[Double], strategy: Strategy, @@ -885,7 +885,7 @@ object DecisionTree extends Serializable with Logging { * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) */ - def findSplitsBins( + private def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() From 201702fc64d2be27309b59bddfce624fad765f70 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 23 Mar 2014 00:23:14 -0700 Subject: [PATCH 45/48] making some more methods private --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 834d617d85faf..b003f6fe54f3b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1114,7 +1114,7 @@ object DecisionTree extends Serializable with Logging { /** * Calculates the classifier accuracy. */ - def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], threshold: Double = 0.5): Double = { def predictedValue(features: Array[Double]) = { if (model.predict(features) < threshold) 0.0 else 1.0 @@ -1130,7 +1130,7 @@ object DecisionTree extends Serializable with Logging { /** * Calculates the mean squared error for regression. */ - def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err From 62dc723fac20409a04b3a47bb6d6a86be03bad37 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 23 Mar 2014 18:53:42 -0700 Subject: [PATCH 46/48] updating javadoc and converting helper methods to package private to allow unit testing --- .../spark/mllib/tree/DecisionTree.scala | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b003f6fe54f3b..3ab644e74df1b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -172,7 +172,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log object DecisionTree extends Serializable with Logging { /** - * Method to train a decision tree model over an RDD + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. The parameters for the algorithm are specified using the strategy parameter. 
+   *
   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
   * for DecisionTree
   * @param strategy The configuration parameters for the tree algorithm which specify the type
@@ -185,7 +189,11 @@ object DecisionTree extends Serializable with Logging {
  }

  /**
-   * Method to train a decision tree model over an RDD
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For the
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes.
+   *
   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
   * training data
   * @param algo algorithm, classification or regression
@@ -204,8 +212,13 @@ object DecisionTree extends Serializable with Logging {


  /**
-   * Method to train a decision tree model over an RDD
-   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The decision tree method supports binary classification and
+   * regression. For the binary classification, the label for each instance should either be 0 or
+   * 1 to denote the two classes. The method also supports categorical feature inputs where the
+   * number of categories can be specified using the categoricalFeaturesInfo option.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
   * training data for DecisionTree
   * @param algo classification or regression
   * @param impurity criterion used for information gain calculation
@@ -236,6 +249,7 @@ object DecisionTree extends Serializable with Logging {

  /**
   * Returns an array of optimal splits for all nodes at a given level
+   *
   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
   * for DecisionTree
   * @param parentImpurities Impurities for all parent nodes for the current level
@@ -247,7 +261,7 @@
   * @param bins possible bins for all features
   * @return array of splits with best splits for all nodes at a given level.
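For the categorical case documented above, the Strategy carries the arity of each categorical feature. A sketch of such a call, mirroring the Strategy construction exercised by the test suite later in the series (the feature arities here are illustrative):

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.rdd.RDD

    // Feature 0 takes values {0, 1}; feature 1 takes values {0, 1, 2}.
    def trainWithCategoricalFeatures(input: RDD[LabeledPoint]): DecisionTreeModel = {
      val strategy = new Strategy(
        Classification,
        Gini,
        maxDepth = 3,
        maxBins = 100,
        categoricalFeaturesInfo = Map(0 -> 2, 1 -> 3))
      DecisionTree.train(input, strategy)
    }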
*/ - private def findBestSplits( + protected[tree] def findBestSplits( input: RDD[LabeledPoint], parentImpurities: Array[Double], strategy: Strategy, @@ -885,7 +899,7 @@ object DecisionTree extends Serializable with Logging { * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) */ - private def findSplitsBins( + protected[tree] def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() From e1dd86ffdee4bd15c1a2d8c9c70b7eacf29e9bdb Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 24 Mar 2014 19:23:44 -0700 Subject: [PATCH 47/48] implementing code style suggestions --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- .../org/apache/spark/mllib/tree/impurity/Entropy.scala | 4 ++-- .../scala/org/apache/spark/mllib/tree/impurity/Gini.scala | 8 ++++---- .../org/apache/spark/mllib/tree/impurity/Variance.scala | 7 +++---- .../scala/org/apache/spark/mllib/tree/model/Node.scala | 1 - .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 7 ------- 6 files changed, 10 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3ab644e74df1b..5e8fc70bd3c04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -245,7 +245,7 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } - val InvalidBinIndex = -1 + private val InvalidBinIndex = -1 /** * Returns an array of optimal splits for all nodes at a given level diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 9018821abc875..8832d7a6929a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import javax.naming.OperationNotSupportedException +import java.lang.UnsupportedOperationException /** * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during @@ -45,5 +45,5 @@ object Entropy extends Impurity { } def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new OperationNotSupportedException("Entropy.calculate") + throw new UnsupportedOperationException("Entropy.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 20af8f6c1c2cd..3f043125a6aba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.tree.impurity -import javax.naming.OperationNotSupportedException +import java.lang.UnsupportedOperationException /** - * Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini - * coefficent]] during binary classification + * Class for calculating the [[http://en.wikipedia + * .org/wiki/Decision_tree_learning#Gini_impurity]] during binary classification */ object Gini extends Impurity { @@ -43,6 +43,6 @@ object Gini extends Impurity { } 
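For reference, the two classification impurities can be written as standalone functions. The Gini body below mirrors the code shown above; the entropy body is the standard binary entropy that the Entropy object's documentation links to, and is an assumption to the extent that the full file is not shown in this diff:

    def gini(c0: Double, c1: Double): Double =
      if (c0 == 0 || c1 == 0) 0.0
      else {
        val total = c0 + c1
        val f0 = c0 / total
        val f1 = c1 / total
        1 - f0 * f0 - f1 * f1
      }

    def entropy(c0: Double, c1: Double): Double =
      if (c0 == 0 || c1 == 0) 0.0
      else {
        val total = c0 + c1
        val f0 = c0 / total
        val f1 = c1 / total
        def log2(x: Double) = math.log(x) / math.log(2)
        -f0 * log2(f0) - f1 * log2(f1)
      }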
def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new OperationNotSupportedException("Gini.calculate") + throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 85b7be560fecb..35b1c4e5c3727 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impurity -import javax.naming.OperationNotSupportedException -import org.apache.spark.Logging +import java.lang.UnsupportedOperationException /** * Class for calculating variance during regression */ -object Variance extends Impurity with Logging { +object Variance extends Impurity { def calculate(c0: Double, c1: Double): Double - = throw new OperationNotSupportedException("Variance.calculate") + = throw new UnsupportedOperationException("Variance.calculate") /** * variance calculation diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 4a2c876a51b54..c3e5c00c8d53c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.Logging -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.FeatureType._ /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index f8914e03bd12f..2dfcdd857b504 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.mllib.tree -import scala.util.Random - import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ - -import org.jblas._ -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ -import scala.collection.mutable import org.apache.spark.mllib.tree.configuration.FeatureType._ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { From f536ae949e8ffaaef9f5c9e0dcebe093954b156b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 31 Mar 2014 16:05:49 -0700 Subject: [PATCH 48/48] another pass on code style --- .../mllib/tree/configuration/Strategy.scala | 15 +- .../spark/mllib/tree/impurity/Entropy.scala | 2 - .../spark/mllib/tree/impurity/Gini.scala | 16 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 14 +- .../apache/spark/mllib/tree/model/Bin.scala | 4 +- .../mllib/tree/model/DecisionTreeModel.scala | 2 - .../tree/model/InformationGainStats.scala | 2 - .../apache/spark/mllib/tree/model/Node.scala | 7 +- .../apache/spark/mllib/tree/model/Split.scala | 6 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 412 +++++++++--------- 11 files changed, 233 insertions(+), 249 
deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 9e461cfdbbd08..7c9b4796ed62b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -34,14 +34,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. */ -class Strategy ( - val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val maxBins: Int = 100, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { +class Strategy ( + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { var numBins: Int = Int.MinValue - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 8832d7a6929a9..b93995fcf9441 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.tree.impurity -import java.lang.UnsupportedOperationException - /** * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during * binary classification. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 3f043125a6aba..c0407554a91b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,32 +17,30 @@ package org.apache.spark.mllib.tree.impurity -import java.lang.UnsupportedOperationException - /** - * Class for calculating the [[http://en.wikipedia - * .org/wiki/Decision_tree_learning#Gini_impurity]] during binary classification + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during binary classification. 
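The impurity objects reworked below all implement the same two-method Impurity trait, so an additional impurity can be plugged in the same way. A hypothetical example, a misclassification-error impurity that is not part of this series, to show the contract (assuming the trait remains publicly extendable):

    import org.apache.spark.mllib.tree.impurity.Impurity

    // Hypothetical: minimum-error impurity for binary classification.
    object MinError extends Impurity {
      override def calculate(c0: Double, c1: Double): Double =
        if (c0 + c1 == 0) 0.0 else math.min(c0, c1) / (c0 + c1)

      override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
        throw new UnsupportedOperationException("MinError.calculate")
    }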
*/ object Gini extends Impurity { /** - * gini coefficient calculation + * Gini coefficient calculation * @param c0 count of instances with label 0 * @param c1 count of instances with label 1 - * @return gini coefficient value + * @return Gini coefficient value */ - def calculate(c0 : Double, c1 : Double): Double = { + override def calculate(c0: Double, c1: Double): Double = { if (c0 == 0 || c1 == 0) { 0 } else { val total = c0 + c1 val f0 = c0 / total val f1 = c1 / total - 1 - f0*f0 - f1*f1 + 1 - f0 * f0 - f1 * f1 } } def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 97092c85aea61..a4069063af2ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree.impurity /** - * Trail for calculating information gain + * Trait for calculating information gain. */ trait Impurity extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 35b1c4e5c3727..b74577dcec167 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,25 +17,21 @@ package org.apache.spark.mllib.tree.impurity -import java.lang.UnsupportedOperationException - /** * Class for calculating variance during regression */ object Variance extends Impurity { - def calculate(c0: Double, c1: Double): Double - = throw new UnsupportedOperationException("Variance.calculate") + override def calculate(c0: Double, c1: Double): Double = + throw new UnsupportedOperationException("Variance.calculate") /** * variance calculation * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return */ - def calculate(count: Double, sum: Double, sumSquares: Double): Double = { - val squaredLoss = sumSquares - (sum*sum)/count - squaredLoss/count + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum * sum) / count + squaredLoss / count } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 47afe3aed2b1b..a57faa13745f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -30,6 +30,4 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin */ -case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) { - -} +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a056da77641ee..a8bbf21daec01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -46,6 +46,4 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Array[Double]]): RDD[Double] = { features.map(x => predict(x)) } - - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 64ff826486f5b..99bf79cf12e45 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -36,6 +36,4 @@ class InformationGainStats( "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" .format(gain, impurity, leftImpurity, rightImpurity, predict) } - - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index c3e5c00c8d53c..ea4693c5c2f4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -37,7 +37,7 @@ class Node ( val split: Option[Split], var leftNode: Option[Node], var rightNode: Option[Node], - val stats: Option[InformationGainStats]) extends Serializable with Logging{ + val stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats @@ -46,7 +46,7 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ - def build(nodes : Array[Node]): Unit = { + def build(nodes: Array[Node]): Unit = { logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) @@ -68,7 +68,7 @@ class Node ( * @param feature feature value * @return predicted value */ - def predictIfLeaf(feature : Array[Double]) : Double = { + def predictIfLeaf(feature: Array[Double]) : Double = { if (isLeaf) { predict } else{ @@ -87,5 +87,4 @@ class Node ( } } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index fffd68d7a64b5..4e64a81dda74e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -42,7 +42,7 @@ case class Split( * @param feature feature index * @param featureType type of feature -- categorical or continuous */ -class DummyLowSplit(feature: Int, featureType : FeatureType) +class DummyLowSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MinValue, featureType, List()) /** @@ -50,7 +50,7 @@ class DummyLowSplit(feature: Int, featureType : FeatureType) * @param feature feature index * @param featureType type of feature -- categorical or continuous */ -class DummyHighSplit(feature: Int, featureType : FeatureType) +class DummyHighSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) /** @@ -59,6 +59,6 @@ class DummyHighSplit(feature: Int, featureType : FeatureType) * @param feature feature index * @param featureType type of feature -- categorical or continuous */ -class DummyCategoricalSplit(feature: Int, featureType : FeatureType) +class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends 
Split(feature, Double.MaxValue, featureType, List()) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2dfcdd857b504..a359bf3a76ce1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -41,246 +41,254 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { System.clearProperty("spark.driver.port") } - test("split and bin calculation"){ + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(bins.length==2) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) } - test("split and bin calculation for categorical variables"){ + test("split and bin calculation for categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, - 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(bins.length==2) - assert(splits(0).length==99) - assert(bins(0).length==100) - - //Checking splits - - assert(splits(0)(0).feature == 0) - assert(splits(0)(0).threshold == Double.MinValue) - assert(splits(0)(0).featureType == Categorical) - assert(splits(0)(0).categories.length == 1) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + // Check splits. 
+ + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature == 0) - assert(splits(0)(1).threshold == Double.MinValue) - assert(splits(0)(1).featureType == Categorical) - assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) assert(splits(0)(1).categories.contains(1.0)) assert(splits(0)(1).categories.contains(0.0)) - assert(splits(0)(2) == null) + assert(splits(0)(2) === null) - assert(splits(1)(0).feature == 1) - assert(splits(1)(0).threshold == Double.MinValue) - assert(splits(1)(0).featureType == Categorical) - assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature == 1) - assert(splits(1)(1).threshold == Double.MinValue) - assert(splits(1)(1).featureType == Categorical) - assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) assert(splits(1)(1).categories.contains(1.0)) assert(splits(1)(1).categories.contains(0.0)) - assert(splits(1)(2) == null) - + assert(splits(1)(2) === null) - // Checks bins + // Check bins. - assert(bins(0)(0).category == 1.0) - assert(bins(0)(0).lowSplit.categories.length == 0) - assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category == 0.0) - assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.length === 2) assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(2) == null) + assert(bins(0)(2) === null) - assert(bins(1)(0).category == 0.0) - assert(bins(1)(0).lowSplit.categories.length == 0) - assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) assert(bins(1)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(1).category == 1.0) - assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.length === 2) assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(2) == null) - + assert(bins(1)(2) === null) } - test("split and bin calculations for categorical variables with no sample for one 
category"){ + test("split and bin calculations for categorical variables with no sample for one category") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, - 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - - //Checking splits - - assert(splits(0)(0).feature == 0) - assert(splits(0)(0).threshold == Double.MinValue) - assert(splits(0)(0).featureType == Categorical) - assert(splits(0)(0).categories.length == 1) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) assert(splits(0)(0).categories.contains(1.0)) - assert(splits(0)(1).feature == 0) - assert(splits(0)(1).threshold == Double.MinValue) - assert(splits(0)(1).featureType == Categorical) - assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) assert(splits(0)(1).categories.contains(1.0)) assert(splits(0)(1).categories.contains(0.0)) - assert(splits(0)(2).feature == 0) - assert(splits(0)(2).threshold == Double.MinValue) - assert(splits(0)(2).featureType == Categorical) - assert(splits(0)(2).categories.length == 3) + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) assert(splits(0)(2).categories.contains(1.0)) assert(splits(0)(2).categories.contains(0.0)) assert(splits(0)(2).categories.contains(2.0)) - assert(splits(0)(3) == null) + assert(splits(0)(3) === null) - assert(splits(1)(0).feature == 1) - assert(splits(1)(0).threshold == Double.MinValue) - assert(splits(1)(0).featureType == Categorical) - assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) assert(splits(1)(0).categories.contains(0.0)) - assert(splits(1)(1).feature == 1) - assert(splits(1)(1).threshold == Double.MinValue) - assert(splits(1)(1).featureType == Categorical) - assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) assert(splits(1)(1).categories.contains(1.0)) assert(splits(1)(1).categories.contains(0.0)) - assert(splits(1)(2).feature == 1) - assert(splits(1)(2).threshold == Double.MinValue) - assert(splits(1)(2).featureType == Categorical) - assert(splits(1)(2).categories.length == 3) + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 3) assert(splits(1)(2).categories.contains(1.0)) assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(2.0)) - assert(splits(1)(3) == 
null) + assert(splits(1)(3) === null) + // Check bins. - // Checks bins - - assert(bins(0)(0).category == 1.0) - assert(bins(0)(0).lowSplit.categories.length == 0) - assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category == 0.0) - assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.length === 2) assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(2).category == 2.0) - assert(bins(0)(2).lowSplit.categories.length == 2) + assert(bins(0)(2).category === 2.0) + assert(bins(0)(2).lowSplit.categories.length === 2) assert(bins(0)(2).lowSplit.categories.contains(1.0)) assert(bins(0)(2).lowSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.length == 3) + assert(bins(0)(2).highSplit.categories.length === 3) assert(bins(0)(2).highSplit.categories.contains(1.0)) assert(bins(0)(2).highSplit.categories.contains(0.0)) assert(bins(0)(2).highSplit.categories.contains(2.0)) - assert(bins(0)(3) == null) + assert(bins(0)(3) === null) - assert(bins(1)(0).category == 0.0) - assert(bins(1)(0).lowSplit.categories.length == 0) - assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) assert(bins(1)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(1).category == 1.0) - assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.length === 2) assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(2).category == 2.0) - assert(bins(1)(2).lowSplit.categories.length == 2) + assert(bins(1)(2).category === 2.0) + assert(bins(1)(2).lowSplit.categories.length === 2) assert(bins(1)(2).lowSplit.categories.contains(0.0)) assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length == 3) + assert(bins(1)(2).highSplit.categories.length === 3) assert(bins(1)(2).highSplit.categories.contains(0.0)) assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(2.0)) - assert(bins(1)(3) == null) - - + assert(bins(1)(3) === null) } - test("classification stump with all categorical variables"){ + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, - 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, 
strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 - assert(split.categories.length == 1) + assert(split.categories.length === 1) assert(split.categories.contains(1.0)) - assert(split.featureType == Categorical) - assert(split.threshold == Double.MinValue) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict > 0.4) assert(stats.predict < 0.5) assert(stats.impurity > 0.2) - } - test("regression stump with all categorical variables"){ + test("regression stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, - 1-> 3)) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 - assert(split.categories.length == 1) + assert(split.categories.length === 1) assert(split.categories.contains(1.0)) - assert(split.featureType == Categorical) - assert(split.threshold == Double.MinValue) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 assert(stats.gain > 0) @@ -289,110 +297,104 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(stats.impurity > 0.2) } - - test("stump with fixed label 0 for Gini"){ + test("stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) } - test("stump with fixed label 1 for Gini"){ + test("stump with fixed label 1 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = 
sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - assert(1==bestSplits(0)._2.predict) - + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) } - - test("stump with fixed label 0 for Entropy"){ + test("stump with fixed label 0 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Entropy,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - assert(0==bestSplits(0)._2.predict) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 0) } - test("stump with fixed label 1 for Entropy"){ + test("stump with fixed label 1 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Entropy,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = 
DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - assert(1==bestSplits(0)._2.predict) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) } - - } object DecisionTreeSuite { @@ -406,7 +408,6 @@ object DecisionTreeSuite { arr } - def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ @@ -427,5 +428,4 @@ object DecisionTreeSuite { } arr } - }
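The generator bodies referenced by the suite are partly elided in this diff; a fixture in the same style would look roughly like the following (the exact label/category pattern is hypothetical):

    import org.apache.spark.mllib.regression.LabeledPoint

    // Roughly 60% of points get categories (0.0, 1.0) with label 1.0, the rest
    // categories (1.0, 0.0) with label 0.0, so a single stump can separate them.
    def generateCategoricalDataPointsSketch(): Array[LabeledPoint] = {
      Array.tabulate(1000) { i =>
        if (i < 600) {
          LabeledPoint(1.0, Array(0.0, 1.0))
        } else {
          LabeledPoint(0.0, Array(1.0, 0.0))
        }
      }
    }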