
Commit 47ccd5e

jkbradley authored and mengxr committed
[SPARK-2851] [mllib] DecisionTree Python consistency update
Added 6 static train methods to match the Python API, but without default arguments (the Python default values are noted in the docs). Added factory classes for Algo and Impurity, but made them private[mllib].

CC: mengxr dorx Please let me know if there are other changes which would help with API consistency---thanks!

Author: Joseph K. Bradley <[email protected]>

Closes #1798 from jkbradley/dt-python-consistency and squashes the following commits:

6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD.
ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types)
00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
fe6dbfa [Joseph K. Bradley] removed unnecessary imports
e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs).
c699850 [Joseph K. Bradley] a few doc comments
eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters
1 parent ffd1f59 commit 47ccd5e
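
As a usage sketch (not part of the commit): the new Scala entry points take every argument explicitly, since they have no Scala default arguments. The SparkContext `sc`, the data file, and the parameter values below are assumptions for illustration only.

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.util.MLUtils

    // Hypothetical setup: an existing SparkContext `sc` and a LIBSVM file
    // whose labels are in {0, 1}.
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

    // All parameters are spelled out; the Python API keeps its defaults
    // (impurity="gini", maxDepth=4, maxBins=100).
    val model = DecisionTree.trainClassifier(
      input = data,
      numClassesForClassification = 2,
      categoricalFeaturesInfo = Map[Int, Int](),  // empty map: all features continuous
      impurity = "gini",
      maxDepth = 4,
      maxBins = 100)

The returned DecisionTreeModel can then be used for prediction, e.g. model.predict on a feature vector.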

5 files changed: +181 -77 lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 4 additions & 15 deletions
@@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
-import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
 import org.apache.spark.mllib.recommendation._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model.DecisionTreeModel
 import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
@@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable {

     val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)

-    val algo: Algo = algoStr match {
-      case "classification" => Classification
-      case "regression" => Regression
-      case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
-    }
-    val impurity: Impurity = impurityStr match {
-      case "gini" => Gini
-      case "entropy" => Entropy
-      case "variance" => Variance
-      case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
-    }
+    val algo = Algo.fromString(algoStr)
+    val impurity = Impurities.fromString(impurityStr)

     val strategy = new Strategy(
       algo = algo,
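
Net effect of this hunk: the two inline string matches collapse into one-line factory calls. A minimal sketch of the resulting behavior, callable only from code inside the mllib package since both factories are private[mllib]:

    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.impurity.Impurities

    val algo = Algo.fromString("classification")   // Algo.Classification
    val impurity = Impurities.fromString("gini")   // Gini
    // Unrecognized names still fail fast:
    // Algo.fromString("clustering") throws IllegalArgumentException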

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 124 additions & 27 deletions
@@ -17,14 +17,18 @@
 
 package org.apache.spark.mllib.tree
 
+import org.apache.spark.api.java.JavaRDD
+
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
@@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging {
    * Method to train a decision tree model.
    * The method supports binary and multiclass classification and regression.
    *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
+   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
@@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The decision tree method supports binary classification and
-   * regression. For the binary classification, the label for each instance should either be 0 or
-   * 1 to denote the two classes. The method also supports categorical features inputs where the
-   * number of categories can specified using the categoricalFeaturesInfo option.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging {
    * @param numClassesForClassification number of classes for classification. Default value of 2.
    * @param maxBins maximum number of bins used for splitting features
    * @param quantileCalculationStrategy algorithm for calculating quantiles
-   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
-   *                                the number of discrete values they take. For example,
-   *                                an entry (n -> k) implies the feature n is categorical with k
-   *                                categories 0, 1, 2, ... , k-1. It's important to note that
-   *                                features are zero-indexed.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(
@@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging {
     new DecisionTree(strategy).train(input)
   }
 
+  /**
+   * Method to train a decision tree model for binary or multiclass classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param numClassesForClassification number of classes for classification.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "gini" (recommended) or "entropy".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (suggested value: 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                (suggested value: 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    val impurityType = Impurities.fromString(impurity)
+    train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   */
+  def trainClassifier(
+      input: JavaRDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    trainClassifier(input.rdd, numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+
+  /**
+   * Method to train a decision tree model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "variance".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (suggested value: 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                (suggested value: 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    val impurityType = Impurities.fromString(impurity)
+    train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   */
+  def trainRegressor(
+      input: JavaRDD[LabeledPoint],
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    trainRegressor(input.rdd,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+
+
   private val InvalidBinIndex = -1
 
   /**
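
Aside (not part of the diff): a hedged sketch of the categoricalFeaturesInfo contract the new trainRegressor doc comment describes. The SparkContext `sc` and the toy data are assumptions for illustration only.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree

    // Hypothetical data: feature 0 is continuous; feature 1 is categorical
    // with 3 categories, encoded as 0.0, 1.0, 2.0.
    val points = sc.parallelize(Seq(
      LabeledPoint(1.5, Vectors.dense(0.2, 0.0)),
      LabeledPoint(2.0, Vectors.dense(0.9, 2.0))))

    // Entry (1 -> 3): feature 1 is categorical with categories {0, 1, 2}.
    val model = DecisionTree.trainRegressor(
      input = points,
      categoricalFeaturesInfo = Map(1 -> 3),
      impurity = "variance",  // the only supported value for regression
      maxDepth = 4,
      maxBins = 100)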
@@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging {
    * Categorical features:
    *   For each feature, there is 1 bin per split.
    *   Splits and bins are handled in 2 ways:
-   *   (a) For multiclass classification with a low-arity feature
+   *   (a) "unordered features"
+   *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are 2^(maxFeatureValue - 1) - 1 splits.
-   *   (b) For regression and binary classification,
+   *       There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+   *   (b) "ordered features"
+   *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
-   *       there is one split per category.
-
-   * Categorical case (a) features are called unordered features.
-   * Other cases are called ordered features.
+   *       there is one bin per category.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
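
A quick check on the counts in the updated comment: an unordered feature with k categories admits 2^(k-1) - 1 subset splits, while an ordered feature gets one bin per category. A small sketch for the Scala REPL (the function name is ours, not the commit's):

    // Number of candidate splits for an unordered categorical feature.
    def unorderedSplitCount(arity: Int): Int = (1 << (arity - 1)) - 1

    unorderedSplitCount(2)  // 1: a binary feature has a single subset split
    unorderedSplitCount(4)  // 7
    unorderedSplitCount(8)  // 127: why high-arity features fall back to ordered bins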

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala

Lines changed: 6 additions & 0 deletions
@@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental
 object Algo extends Enumeration {
   type Algo = Value
   val Classification, Regression = Value
+
+  private[mllib] def fromString(name: String): Algo = name match {
+    case "classification" => Classification
+    case "regression" => Regression
+    case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name")
+  }
 }
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Factory for Impurity instances.
+ */
+private[mllib] object Impurities {
+
+  def fromString(name: String): Impurity = name match {
+    case "gini" => Gini
+    case "entropy" => Entropy
+    case "variance" => Variance
+    case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
+  }
+
+}
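
A minimal round-trip sketch of the new factory, again assuming the caller sits inside the mllib package (the object is private[mllib]):

    import org.apache.spark.mllib.tree.impurity._

    assert(Impurities.fromString("gini") == Gini)
    assert(Impurities.fromString("entropy") == Entropy)
    assert(Impurities.fromString("variance") == Variance)
    // Impurities.fromString("mse") throws IllegalArgumentException

This mirrors Algo.fromString above; together they replace the ad hoc string matches previously inlined in PythonMLLibAPI.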

python/pyspark/mllib/tree.py

Lines changed: 15 additions & 35 deletions
@@ -131,7 +131,7 @@ class DecisionTree(object):
     """

     @staticmethod
-    def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+    def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                         impurity="gini", maxDepth=4, maxBins=100):
         """
         Train a DecisionTreeModel for classification.
@@ -150,12 +150,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
         :param maxBins: Number of bins used for finding splits at each node.
         :return: DecisionTreeModel
         """
-        return DecisionTree.train(data, "classification", numClasses,
-                                  categoricalFeaturesInfo,
-                                  impurity, maxDepth, maxBins)
+        sc = data.context
+        dataBytes = _get_unmangled_labeled_point_rdd(data)
+        categoricalFeaturesInfoJMap = \
+            MapConverter().convert(categoricalFeaturesInfo,
+                                   sc._gateway._gateway_client)
+        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+            dataBytes._jrdd, "classification",
+            numClasses, categoricalFeaturesInfoJMap,
+            impurity, maxDepth, maxBins)
+        dataBytes.unpersist()
+        return DecisionTreeModel(sc, model)

     @staticmethod
-    def trainRegressor(data, categoricalFeaturesInfo={},
+    def trainRegressor(data, categoricalFeaturesInfo,
                        impurity="variance", maxDepth=4, maxBins=100):
         """
         Train a DecisionTreeModel for regression.
@@ -173,42 +181,14 @@ def trainRegressor(data, categoricalFeaturesInfo={},
         :param maxBins: Number of bins used for finding splits at each node.
         :return: DecisionTreeModel
         """
-        return DecisionTree.train(data, "regression", 0,
-                                  categoricalFeaturesInfo,
-                                  impurity, maxDepth, maxBins)
-
-    @staticmethod
-    def train(data, algo, numClasses, categoricalFeaturesInfo,
-              impurity, maxDepth, maxBins=100):
-        """
-        Train a DecisionTreeModel for classification or regression.
-
-        :param data: Training data: RDD of LabeledPoint.
-                     For classification, labels are integers
-                     {0,1,...,numClasses}.
-                     For regression, labels are real numbers.
-        :param algo: "classification" or "regression"
-        :param numClasses: Number of classes for classification.
-        :param categoricalFeaturesInfo: Map from categorical feature index
-                                        to number of categories.
-                                        Any feature not in this map
-                                        is treated as continuous.
-        :param impurity: For classification: "entropy" or "gini".
-                         For regression: "variance".
-        :param maxDepth: Max depth of tree.
-                         E.g., depth 0 means 1 leaf node.
-                         Depth 1 means 1 internal node + 2 leaf nodes.
-        :param maxBins: Number of bins used for finding splits at each node.
-        :return: DecisionTreeModel
-        """
         sc = data.context
         dataBytes = _get_unmangled_labeled_point_rdd(data)
         categoricalFeaturesInfoJMap = \
             MapConverter().convert(categoricalFeaturesInfo,
                                    sc._gateway._gateway_client)
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
-            dataBytes._jrdd, algo,
-            numClasses, categoricalFeaturesInfoJMap,
+            dataBytes._jrdd, "regression",
+            0, categoricalFeaturesInfoJMap,
             impurity, maxDepth, maxBins)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)
