From eaf84c0cdcf2eee5c00addbc4c73d24aa90e68b8 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 5 Aug 2014 16:36:21 -0700 Subject: [PATCH 1/6] Added static DecisionTree train() methods to match the Python API, but without default parameters --- .../spark/mllib/tree/DecisionTree.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1d03e6e3b36c..da1546147913 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -300,6 +300,30 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } + def train( + input: RDD[LabeledPoint], + algo: Algo, + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int,Int], + impurity: Impurity, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = ??? + + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int,Int], + impurity: Impurity, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = ??? + + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int,Int], + impurity: Impurity, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = ??? + private val InvalidBinIndex = -1 /** From c69985063c22b1802d7755bd6b12e1e8cd76ba74 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 5 Aug 2014 16:46:05 -0700 Subject: [PATCH 2/6] Added a few doc comments --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index da1546147913..36c47187789c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -300,6 +300,7 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } + // Optional arguments in Python: maxBins def train( input: RDD[LabeledPoint], algo: Algo, @@ -309,6 +310,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, maxBins: Int): DecisionTreeModel = ??? + // Optional arguments in Python: all but input, numClassesForClassification def trainClassifier( input: RDD[LabeledPoint], numClassesForClassification: Int, @@ -317,6 +319,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, maxBins: Int): DecisionTreeModel = ??? + // Optional arguments in Python: all but input def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int,Int], From e35866176f69ac408fa422cd6c09d6a2db6e9435 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 5 Aug 2014 17:43:22 -0700 Subject: [PATCH 3/6] DecisionTree API change: * Added 6 static train methods to match the Python API, without default arguments (the Python default args are noted in the docs). * Added factories for Algo and Impurity, but made them private[mllib].
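For illustration, the new basic-types overloads can be invoked as in the sketch below (not part of the diff; "data" is an assumed RDD[LabeledPoint] with labels in {0, 1}, and the argument values are arbitrary):

    // Hypothetical call to the string-based train() overload added here.
    // Assumes: import org.apache.spark.mllib.tree.DecisionTree
    val model = DecisionTree.train(
      data,              // RDD[LabeledPoint]
      "classification",  // algo
      2,                 // numClassesForClassification
      Map(0 -> 3),       // feature 0 is categorical with 3 categories
      "gini",            // impurity
      5,                 // maxDepth
      100)               // maxBins (matches the Python default)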
--- .../mllib/api/python/PythonMLLibAPI.scala | 17 +- .../spark/mllib/tree/DecisionTree.scala | 228 +++++++++++++++--- .../spark/mllib/tree/configuration/Algo.scala | 6 + .../mllib/tree/impurity/Impurities.scala | 32 +++ 4 files changed, 235 insertions(+), 48 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 1d5d3762ed8e..b6da4da0af6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -32,9 +32,9 @@ import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames @@ -498,17 +498,8 @@ class PythonMLLibAPI extends Serializable { val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val algo: Algo = algoStr match { - case "classification" => Classification - case "regression" => Regression - case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") - } - val impurity: Impurity = impurityStr match { - case "gini" => Gini - case "entropy" => Entropy - case "variance" => Variance - case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") - } + val algo = Algo.stringToAlgo(algoStr) + val impurity = Impurities.stringToImpurity(impurityStr) val strategy = new Strategy( algo = algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36c47187789c..8cb246ea48c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,14 +17,16 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -213,10 +215,8 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. 
For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -237,10 +237,8 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -263,11 +261,8 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -279,11 +274,9 @@ object DecisionTree extends Serializable with Logging { * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ def train( @@ -300,32 +293,197 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } - // Optional arguments in Python: maxBins + /** + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * This version takes basic types, for consistency with Python API. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param algo "classification" or "regression" + * @param numClassesForClassification number of classes for classification. Default value of 2. 
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxBins maximum number of bins used for splitting features + * (default Python value = 100) + * @return DecisionTreeModel that can be used for prediction + */ def train( input: RDD[LabeledPoint], - algo: Algo, + algo: String, numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int,Int], - impurity: Impurity, + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, maxDepth: Int, - maxBins: Int): DecisionTreeModel = ??? + maxBins: Int): DecisionTreeModel = { + val algoType = Algo.stringToAlgo(algo) + val impurityType = Impurities.stringToImpurity(impurity) + train(input, algoType, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo) + } - // Optional arguments in Python: all but input, numClassesForClassification + /** + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * This version takes basic types, for consistency with Python API. + * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param algo "classification" or "regression" + * @param numClassesForClassification number of classes for classification. Default value of 2. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxBins maximum number of bins used for splitting features + * (default Python value = 100) + * @return DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: String, + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + train(input, algo, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for binary or multiclass classification. + * This version takes basic types, for consistency with Python API. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. 
+ * (default Python value = {}, i.e., no categorical features) + * @param impurity criterion used for information gain calculation + * (default Python value = "gini") + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default Python value = 4) + * @param maxBins maximum number of bins used for splitting features + * (default Python value = 100) + * @return DecisionTreeModel that can be used for prediction + */ def trainClassifier( input: RDD[LabeledPoint], numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int,Int], - impurity: Impurity, + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + train(input, "classification", numClassesForClassification, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for binary or multiclass classification. + * This version takes basic types, for consistency with Python API. + * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * (default Python value = {}, i.e., no categorical features) + * @param impurity criterion used for information gain calculation + * (default Python value = "gini") + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default Python value = 4) + * @param maxBins maximum number of bins used for splitting features + * (default Python value = 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainClassifier(input, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for regression. + * This version takes basic types, for consistency with Python API. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * (default Python value = {}, i.e., no categorical features) + * @param impurity criterion used for information gain calculation + * (default Python value = "variance") + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
+ * (default Python value = 4) + * @param maxBins maximum number of bins used for splitting features + * (default Python value = 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, maxDepth: Int, - maxBins: Int): DecisionTreeModel = ??? + maxBins: Int): DecisionTreeModel = { + train(input, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins) + } - // Optional arguments in Python: all but input + /** + * Method to train a decision tree model for regression. + * This version takes basic types, for consistency with Python API. + * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * (default Python value = {}, i.e., no categorical features) + * @param impurity criterion used for information gain calculation + * (default Python value = "variance") + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default Python value = 4) + * @param maxBins maximum number of bins used for splitting features + * (default Python value = 100) + * @return DecisionTreeModel that can be used for prediction + */ def trainRegressor( input: RDD[LabeledPoint], - categoricalFeaturesInfo: Map[Int,Int], - impurity: Impurity, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, maxDepth: Int, - maxBins: Int): DecisionTreeModel = ??? + maxBins: Int): DecisionTreeModel = { + trainRegressor(input, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + private val InvalidBinIndex = -1 @@ -1361,10 +1519,10 @@ object DecisionTree extends Serializable with Logging { * (a) For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are 2^(maxFeatureValue - 1) - 1 splits. + * There are math.pow(2, (maxFeatureValue - 1) - 1) splits. * (b) For regression and binary classification, * and for multiclass classification with a high-arity feature, - * there is one split per category. + * * Categorical case (a) features are called unordered features. * Other cases are called ordered features. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 79a01f58319e..a70fc5712f59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value + + private[mllib] def stringToAlgo(name: String): Algo = name match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala new file mode 100644 index 000000000000..305555b1501d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Factory class for Impurity types. + */ +private[mllib] object Impurities { + + def stringToImpurity(name: String): Impurity = name match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name") + } + +} From fe6dbfad5f2734eb6cb54ef1708312df3e67b8fc Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 5 Aug 2014 19:34:06 -0700 Subject: [PATCH 4/6] removed unnecessary imports --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b6da4da0af6f..a4a4a6381a85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -25,13 +25,11 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.impurity._ From ee1d236582e33db5d5636405233740d66b0a539c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 6 Aug 2014 11:43:44 -0700 Subject: [PATCH 5/6] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types) DecisionTree internal updates: * Renamed Algo and Impurity factory methods to fromString() DecisionTree doc updates: * Added notes recommending use of trainClassifier, trainRegressor * Say supported values for impurity * Shortened doc for Java-friendly train* functions. --- .../mllib/api/python/PythonMLLibAPI.scala | 4 +- .../spark/mllib/tree/DecisionTree.scala | 164 ++++-------------- .../spark/mllib/tree/configuration/Algo.scala | 2 +- .../mllib/tree/impurity/Impurities.scala | 4 +- python/pyspark/mllib/tree.py | 51 ++---- 5 files changed, 57 insertions(+), 168 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 9d88a38831d8..ba7ccd8ce4b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -521,8 +521,8 @@ class PythonMLLibAPI extends Serializable { val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val algo = Algo.stringToAlgo(algoStr) - val impurity = Impurities.stringToImpurity(impurityStr) + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) val strategy = new Strategy( algo = algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8cb246ea48c3..bff771650295 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -202,6 +202,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. 
* + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -218,6 +222,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -240,6 +248,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -264,6 +276,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -293,77 +309,8 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } - /** - * Method to train a decision tree model. - * The method supports binary and multiclass classification and regression. - * This version takes basic types, for consistency with Python API. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param algo "classification" or "regression" - * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
- * @param maxBins maximum number of bins used for splitting features - * (default Python value = 100) - * @return DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: String, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - impurity: String, - maxDepth: Int, - maxBins: Int): DecisionTreeModel = { - val algoType = Algo.stringToAlgo(algo) - val impurityType = Impurities.stringToImpurity(impurity) - train(input, algoType, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, - categoricalFeaturesInfo) - } - - /** - * Method to train a decision tree model. - * The method supports binary and multiclass classification and regression. - * This version takes basic types, for consistency with Python API. - * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param algo "classification" or "regression" - * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param maxBins maximum number of bins used for splitting features - * (default Python value = 100) - * @return DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: String, - numClassesForClassification: Int, - categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], - impurity: String, - maxDepth: Int, - maxBins: Int): DecisionTreeModel = { - train(input, algo, numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - impurity, maxDepth, maxBins) - } - /** * Method to train a decision tree model for binary or multiclass classification. - * This version takes basic types, for consistency with Python API. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. @@ -371,14 +318,13 @@ object DecisionTree extends Serializable with Logging { * @param categoricalFeaturesInfo Map storing arity of categorical features. * E.g., an entry (n -> k) indicates that feature n is categorical * with k categories indexed from 0: {0, 1, ..., k-1}. - * (default Python value = {}, i.e., no categorical features) - * @param impurity criterion used for information gain calculation - * (default Python value = "gini") + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
- * (default Python value = 4) + * (suggested value: 4) * @param maxBins maximum number of bins used for splitting features - * (default Python value = 100) + * (suggested value: 100) * @return DecisionTreeModel that can be used for prediction */ def trainClassifier( @@ -388,30 +334,13 @@ object DecisionTree extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel = { - train(input, "classification", numClassesForClassification, categoricalFeaturesInfo, impurity, - maxDepth, maxBins) + val impurityType = Impurities.fromString(impurity) + train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo) } /** - * Method to train a decision tree model for binary or multiclass classification. - * This version takes basic types, for consistency with Python API. - * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClassesForClassification number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * (default Python value = {}, i.e., no categorical features) - * @param impurity criterion used for information gain calculation - * (default Python value = "gini") - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (default Python value = 4) - * @param maxBins maximum number of bins used for splitting features - * (default Python value = 100) - * @return DecisionTreeModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] */ def trainClassifier( input: RDD[LabeledPoint], @@ -427,21 +356,19 @@ object DecisionTree extends Serializable with Logging { /** * Method to train a decision tree model for regression. - * This version takes basic types, for consistency with Python API. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. * @param categoricalFeaturesInfo Map storing arity of categorical features. * E.g., an entry (n -> k) indicates that feature n is categorical * with k categories indexed from 0: {0, 1, ..., k-1}. - * (default Python value = {}, i.e., no categorical features) - * @param impurity criterion used for information gain calculation - * (default Python value = "variance") + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
- * (default Python value = 4) + * (suggested value: 4) * @param maxBins maximum number of bins used for splitting features - * (default Python value = 100) + * (suggested value: 100) * @return DecisionTreeModel that can be used for prediction */ def trainRegressor( @@ -450,28 +377,12 @@ object DecisionTree extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel = { - train(input, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins) + val impurityType = Impurities.fromString(impurity) + train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) } /** - * Method to train a decision tree model for regression. - * This version takes basic types, for consistency with Python API. - * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * (default Python value = {}, i.e., no categorical features) - * @param impurity criterion used for information gain calculation - * (default Python value = "variance") - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (default Python value = 4) - * @param maxBins maximum number of bins used for splitting features - * (default Python value = 100) - * @return DecisionTreeModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] */ def trainRegressor( input: RDD[LabeledPoint], @@ -1516,16 +1427,15 @@ object DecisionTree extends Serializable with Logging { * Categorical features: * For each feature, there is 1 bin per split. * Splits and bins are handled in 2 ways: - * (a) For multiclass classification with a low-arity feature + * (a) "unordered features" + * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are math.pow(2, (maxFeatureValue - 1) - 1) splits. - * (b) For regression and binary classification, + * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * (b) "ordered features" + * For regression and binary classification, * and for multiclass classification with a high-arity feature, - * - - * Categorical case (a) features are called unordered features. - * Other cases are called ordered features. + * there is one bin per category. 
* * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index a70fc5712f59..0ef9c6181a0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -28,7 +28,7 @@ object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value - private[mllib] def stringToAlgo(name: String): Algo = name match { + private[mllib] def fromString(name: String): Algo = name match { case "classification" => Classification case "regression" => Regression case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala index 305555b1501d..15fad38aab0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -18,11 +18,11 @@ package org.apache.spark.mllib.tree.impurity /** - * Factory class for Impurity types. + * Factory for Impurity. */ private[mllib] object Impurities { - def stringToImpurity(name: String): Impurity = name match { + def fromString(name: String): Impurity = name match { case "gini" => Gini case "entropy" => Entropy case "variance" => Variance diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 1e0006df75ac..754580db2e3b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -128,7 +128,7 @@ class DecisionTree(object): """ @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for classification. @@ -147,12 +147,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, "classification", + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) @staticmethod - def trainRegressor(data, categoricalFeaturesInfo={}, + def trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for regression. @@ -170,43 +178,14 @@ def trainRegressor(data, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. 
:return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - - - @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins=100): - """ - Train a DecisionTreeModel for classification or regression. - - :param data: Training data: RDD of LabeledPoint. - For classification, labels are integers - {0,1,...,numClasses}. - For regression, labels are real numbers. - :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: For classification: "entropy" or "gini". - For regression: "variance". - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :return: DecisionTreeModel - """ sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, algo, - numClasses, categoricalFeaturesInfoJMap, + dataBytes._jrdd, "regression", + 0, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) dataBytes.unpersist() return DecisionTreeModel(sc, model) From a0d7dbe09a2b4672151ff993f9ac9d2074d66fb6 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 6 Aug 2014 13:28:51 -0700 Subject: [PATCH 6/6] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD. --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 10 ++++++---- .../apache/spark/mllib/tree/impurity/Impurities.scala | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bff771650295..c8a865659682 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import org.apache.spark.api.java.JavaRDD + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -343,13 +345,13 @@ object DecisionTree extends Serializable with Logging { * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] */ def trainClassifier( - input: RDD[LabeledPoint], + input: JavaRDD[LabeledPoint], numClassesForClassification: Int, categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel = { - trainClassifier(input, numClassesForClassification, + trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, impurity, maxDepth, maxBins) } @@ -385,12 +387,12 @@ object DecisionTree extends Serializable with Logging { * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] */ def trainRegressor( - input: RDD[LabeledPoint], + input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel = { - trainRegressor(input, + 
trainRegressor(input.rdd, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, impurity, maxDepth, maxBins) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala index 15fad38aab0f..9a6452aa13a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree.impurity /** - * Factory for Impurity. + * Factory for Impurity instances. */ private[mllib] object Impurities {
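As a closing sketch (illustrative only, and not part of any patch above): after patch 6, the Scala entry points recommended in the docs can be exercised as follows, where "training" is an assumed RDD[LabeledPoint] and the parameter values are arbitrary.

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.rdd.RDD

    def sketch(training: RDD[LabeledPoint]): Unit = {
      // Multiclass classification: 3 classes; feature 0 is categorical with 4 categories.
      val classifier = DecisionTree.trainClassifier(
        training, 3, Map(0 -> 4), "gini", 4, 100)
      // Regression: an empty map means all features are treated as continuous.
      val regressor = DecisionTree.trainRegressor(
        training, Map.empty[Int, Int], "variance", 4, 100)
    }

Java callers go through the JavaRDD / java.util.Map overloads added in patch 6; Python callers use DecisionTree.trainClassifier and trainRegressor in pyspark.mllib.tree.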