From a4aea51c9f867aa81251a1d2eb9851c7076514a4 Mon Sep 17 00:00:00 2001 From: Basin Date: Fri, 16 Jan 2015 09:34:40 +0800 Subject: [PATCH 1/7] boostingstrategy.defaultParam string algo to enumeration. --- .../tree/configuration/BoostingStrategy.scala | 21 +++++++++++++++++++ .../mllib/tree/configuration/Strategy.scala | 13 ++++++++++++ 2 files changed, 34 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..49ceb83a54154 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -88,4 +88,25 @@ object BoostingStrategy { throw new IllegalArgumentException(s"$algo is not supported by boosting.") } } + + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 + algo match { + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by boosting.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..bef066783082c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -181,4 +181,17 @@ object Strategy { new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClasses = 2) + case Algo.Regression => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClasses = 0) + } } From 68cf5444a38d5575d4973e2570474599cb834abf Mon Sep 17 00:00:00 2001 From: Basin Date: Fri, 16 Jan 2015 13:39:33 +0800 Subject: [PATCH 2/7] mllib-ensembles doc modified. --- docs/mllib-ensembles.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5b..9fd49d5c9bc9a 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -456,7 +456,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -val boostingStrategy = BoostingStrategy.defaultParams("Classification") +val boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification) boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. boostingStrategy.treeStrategy.numClassesForClassification = 2 boostingStrategy.treeStrategy.maxDepth = 5 @@ -506,7 +506,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); boostingStrategy.getTreeStrategy().setMaxDepth(5); @@ -564,7 +564,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -val boostingStrategy = BoostingStrategy.defaultParams("Regression") +val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression) boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. @@ -614,7 +614,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setMaxDepth(5); // Empty categoricalFeaturesInfo indicates all features are continuous. From e04a5aa7e86daab66de86742e55e033bbafc3a2a Mon Sep 17 00:00:00 2001 From: Basin Date: Fri, 16 Jan 2015 09:34:40 +0800 Subject: [PATCH 3/7] boostingstrategy.defaultParam string algo to enumeration. --- .../tree/configuration/BoostingStrategy.scala | 21 +++++++++++++++++++ .../mllib/tree/configuration/Strategy.scala | 13 ++++++++++++ 2 files changed, 34 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..49ceb83a54154 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -88,4 +88,25 @@ object BoostingStrategy { throw new IllegalArgumentException(s"$algo is not supported by boosting.") } } + + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 + algo match { + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by boosting.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..bef066783082c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -181,4 +181,17 @@ object Strategy { new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClasses = 2) + case Algo.Regression => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClasses = 0) + } } From 65f96ce6b8e3a555cec28127ba5316e17b7a3a8f Mon Sep 17 00:00:00 2001 From: Basin Date: Fri, 16 Jan 2015 13:39:33 +0800 Subject: [PATCH 4/7] mllib-ensembles doc modified. --- docs/mllib-ensembles.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5b..9fd49d5c9bc9a 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -456,7 +456,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -val boostingStrategy = BoostingStrategy.defaultParams("Classification") +val boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification) boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. boostingStrategy.treeStrategy.numClassesForClassification = 2 boostingStrategy.treeStrategy.maxDepth = 5 @@ -506,7 +506,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); boostingStrategy.getTreeStrategy().setMaxDepth(5); @@ -564,7 +564,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -val boostingStrategy = BoostingStrategy.defaultParams("Regression") +val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression) boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. @@ -614,7 +614,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setMaxDepth(5); // Empty categoricalFeaturesInfo indicates all features are continuous. From 7c1e6eeb1689ba11daed4fc5be885a9054dd722e Mon Sep 17 00:00:00 2001 From: Basin Date: Wed, 21 Jan 2015 10:57:16 +0800 Subject: [PATCH 5/7] Doc of Java updated. algo -> algoStr instead. --- docs/mllib-ensembles.md | 4 ++-- .../mllib/tree/configuration/BoostingStrategy.scala | 10 +++++----- .../spark/mllib/tree/configuration/Strategy.scala | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 9fd49d5c9bc9a..bc6fac30cb75d 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -506,7 +506,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification()); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); boostingStrategy.getTreeStrategy().setMaxDepth(5); @@ -614,7 +614,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression()); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setMaxDepth(5); // Empty categoricalFeaturesInfo indicates all features are continuous. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 49ceb83a54154..0203aff2b03ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -70,22 +70,22 @@ object BoostingStrategy { /** * Returns default configuration for the boosting algorithm - * @param algo Learning goal. Supported: + * @param algoStr Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) + def defaultParams(algoStr: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy(algoStr) treeStrategy.maxDepth = 3 - algo match { + algoStr match { case "Classification" => treeStrategy.numClasses = 2 new BoostingStrategy(treeStrategy, LogLoss) case "Regression" => new BoostingStrategy(treeStrategy, SquaredError) case _ => - throw new IllegalArgumentException(s"$algo is not supported by boosting.") + throw new IllegalArgumentException(s"$algoStr is not supported by boosting.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index bef066783082c..6345c041e7aed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -171,9 +171,9 @@ object Strategy { /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] - * @param algo "Classification" or "Regression" + * @param algoStr "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { + def defaultStrategy(algoStr: String): Strategy = algoStr match { case "Classification" => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) From 3b728750151f40d4f7ea8168b2888b1dec1d10b9 Mon Sep 17 00:00:00 2001 From: Basin Date: Thu, 22 Jan 2015 08:09:51 +0800 Subject: [PATCH 6/7] defaultParams(algoStr: String) call defaultParams(algo: Algo). --- .../mllib/tree/configuration/BoostingStrategy.scala | 12 +----------- .../spark/mllib/tree/configuration/Strategy.scala | 9 +-------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 0203aff2b03ec..c0f464669b9c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -76,17 +76,7 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algoStr: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algoStr) - treeStrategy.maxDepth = 3 - algoStr match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) - case _ => - throw new IllegalArgumentException(s"$algoStr is not supported by boosting.") - } + defaultParams(Algo.fromString(algoStr)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 6345c041e7aed..3b79719f20034 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,14 +173,7 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algoStr "Classification" or "Regression" */ - def defaultStrategy(algoStr: String): Strategy = algoStr match { - case "Classification" => - new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, - numClasses = 2) - case "Regression" => - new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, - numClasses = 0) - } + def defaultStrategy(algoStr: String): Strategy = defaultStategy(Algo.fromString(algoStr)) /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] From 87bab1c427f917a757a30af56288f010d16a4107 Mon Sep 17 00:00:00 2001 From: Basin Date: Thu, 22 Jan 2015 13:02:12 +0800 Subject: [PATCH 7/7] Docs and Code documentations updated. --- docs/mllib-ensembles.md | 8 ++++---- .../spark/mllib/tree/configuration/BoostingStrategy.scala | 8 +++----- .../apache/spark/mllib/tree/configuration/Strategy.scala | 6 ++++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index bc6fac30cb75d..23ede04b62d5b 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -456,7 +456,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -val boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification) +val boostingStrategy = BoostingStrategy.defaultParams("Classification") boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. boostingStrategy.treeStrategy.numClassesForClassification = 2 boostingStrategy.treeStrategy.maxDepth = 5 @@ -506,7 +506,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Classification()); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); boostingStrategy.getTreeStrategy().setMaxDepth(5); @@ -564,7 +564,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression) +val boostingStrategy = BoostingStrategy.defaultParams("Regression") boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. @@ -614,7 +614,7 @@ JavaRDD testData = splits[1]; // Train a GradientBoostedTrees model. // The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression()); +BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. boostingStrategy.getTreeStrategy().setMaxDepth(5); // Empty categoricalFeaturesInfo indicates all features are continuous. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index c0f464669b9c1..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -70,13 +70,11 @@ object BoostingStrategy { /** * Returns default configuration for the boosting algorithm - * @param algoStr Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param algo Learning goal. Supported: "Classification" or "Regression" * @return Configuration for boosting algorithm */ - def defaultParams(algoStr: String): BoostingStrategy = { - defaultParams(Algo.fromString(algoStr)) + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3b79719f20034..972959885f396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -171,9 +171,11 @@ object Strategy { /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] - * @param algoStr "Classification" or "Regression" + * @param algo "Classification" or "Regression" */ - def defaultStrategy(algoStr: String): Strategy = defaultStategy(Algo.fromString(algoStr)) + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]