[SPARK-15767][R][ML] Decision Tree Regression wrapper in SparkR #13690
Conversation
Test build #60600 has finished for PR 13690 at commit
7ea9544 to 378607f
Test build #61050 has finished for PR 13690 at commit
Hi @vectorijk, would you be interested in continuing this work?
Yes, sure. But I'm on vacation this week. I will keep working on this.
Great! Based on earlier discussions we might want to call this
ping @vectorijk Have you started working on the random forest wrapper? If not, or if you're busy, I can also work on that :)
Also, if you need any help with this PR, just let me know and we can work on it together.
@junyangq I have started working on the random forest wrapper and will open a PR as soon as possible. I'll also update this PR very soon. Thanks.
Sounds great. Thank you @vectorijk
378607f to f8b3484
Test build #64777 has finished for PR 13690 at commit
Test build #64776 has finished for PR 13690 at commit
@vectorijk Is this ready for another round of review?
@vectorijk hi - would you have time to update this?
hi @vectorijk - would you have time to update this? If not, I will try to follow up based on your work.
@felixcheung I'll update the changes in the next two days.
R/pkg/R/mllib.R (outdated)
I think this has been updated to use an internal function - could you check?
R/pkg/R/mllib.R (outdated)
@note since 2.1.0 like others?
R/pkg/R/mllib.R (outdated)
please add #' doc block
Thanks! Aside from having to rebase, there are some leftover references to "spark.rpart", a few other changes to make, and it would also be great to add tests for this.
2835a7a to b18b718
Test build #66438 has finished for PR 13690 at commit
Test build #66442 has finished for PR 13690 at commit
Test build #66448 has finished for PR 13690 at commit
@felixcheung @shivaram @junyangq It's ready for review.
could you fix the test failure? |
R/pkg/R/mllib.R (outdated)
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg},
#' @seealso \link{spark.decisionTree},
let's keep this list sorted?
R/pkg/R/mllib.R (outdated)
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.decisionTree}
same here
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32 ) {
            formula <- paste(deparse(formula), collapse = "")
use match.arg to check type?
https://stat.ethz.ch/R-manual/R-devel/library/base/html/match.arg.html
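For illustration, a small standalone sketch of the match.arg pattern being suggested (not the actual SparkR method; the choices simply mirror the PR's type argument):

```r
# Hedged sketch: shows how match.arg() would validate the "type" argument.
fitType <- function(type = c("regression", "classification")) {
  # match.arg() returns the first choice by default and raises an
  # informative error for values outside the allowed set.
  type <- match.arg(type)
  type
}

fitType()                   # "regression" (default)
fitType("classification")   # "classification"
# fitType("clustering")     # error: 'arg' should be one of "regression", "classification"
```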
test_that("spark.decisionTree Regression", {
  data <- suppressWarnings(createDataFrame(longley))
  model <- spark.decisionTree(data, Employed~., "regression", maxDepth = 5, maxBins = 16)
could be more readable as Employed ~ . (with spaces)
addressed comments above
Test build #66567 has finished for PR 13690 at commit
})

test_that("spark.decisionTree Regression", {
  data <- suppressWarnings(createDataFrame(longley))
please add a test for print (see spark.glm)
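A possible shape for such a test, loosely modeled on the existing spark.glm tests; the assertions below are illustrative, and the exact printed text would depend on the wrapper's summary output:

```r
test_that("spark.decisionTree summary prints", {
  df <- suppressWarnings(createDataFrame(longley))
  model <- spark.decisionTree(df, Employed ~ ., "regression",
                              maxDepth = 5, maxBins = 16)
  # Capture whatever print/summary emits and check it is non-empty and
  # mentions the response column; the exact format is up to the wrapper.
  out <- capture.output(print(summary(model)))
  expect_true(length(out) > 0)
  expect_true(any(grepl("Employed", out)))
})
```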
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see \href{https://en.wikipedia.org/wiki/Decision_tree_learning}{Decision Tree}
could you point this url to the Spark programming guide, like http://spark.apache.org/docs/latest/ml-classification-regression.html
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit
please add the types supported, e.g. one of "regression" or "classification" as the type of model
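For example, the parameter documentation could be worded along these lines (suggested wording only):

```r
#' @param type type of model to fit; supported values are "regression" and "classification".
```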
#'
#' # fit a Decision Tree Regression Model
#' model <- spark.decisionTree(data, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
Could we add an example for "classification" too?
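Something along these lines could work as a classification example (a hedged sketch; it assumes the iris dataset, whose column names SparkR rewrites with underscores, and the same argument names as the regression example):

```r
#' # fit a Decision Tree Classification Model
#' df <- createDataFrame(iris)
#' model <- spark.decisionTree(df, Species ~ Petal_Length + Petal_Width,
#'                             type = "classification", maxDepth = 5, maxBins = 32)
#' summary(model)
```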
#' @note spark.decisionTree since 2.1.0
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32 ) {
nit: extra space after 32 )
#' Save the Decision Tree Classification model to the input path.
#'
#' @param object A fitted Decision tree classification model
could you check the output doc by running create-doc.sh - I think this will duplicate the object when the @rdname is changed - in that case, just have one instance of this and say "regression or classification model"
#' which means throw exception if the output path exists.
#'
#' @aliases write.ml,DecisionTreeClassificationModel,character-method
#' @rdname spark.decisionTreeClassification
change to @rdname spark.decisionTree
#' @export
#' @note summary(DecisionTreeRegressionModel) since 2.1.0
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
          function(object, ...) {
do not put ... in signature here
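As a point of reference, a hedged sketch of the method with `...` dropped from the signature, following the pattern of other SparkR summary methods; the body is a placeholder, not the PR's implementation:

```r
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
          function(object) {
            # Placeholder body: the real method would pull fields such as the
            # formula, depth, and number of nodes from the JVM wrapper object.
            list()
          })
```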
val rFormula = new RFormula()
  .setFormula(formula)
  .setFeaturesCol("features")
could you take a look at another model wrapper (like NaiveBayesWrapper) and RWrapperUtils on how to handle DataFrame column name - this shouldn't be hardcoded here?
val rFormula = new RFormula()
  .setFormula(formula)
  .setFeaturesCol("features")
ditto
#' @export
#' @note summary(DecisionTreeClassificationModel) since 2.1.0
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
          function(object, ...) {
ditto
@felixcheung @vectorijk Should we close this PR?
@shivaram I will update this today.
gentle ping @vectorijk |
What changes were proposed in this pull request?
Implement a wrapper in SparkR to support decision tree regression. R's native decision tree regression implementation comes from the rpart package, with the signature rpart(formula, dataframe, method = "anova"). I propose we implement an API like spark.rpart(dataframe, formula, ...). After decision tree classification is implemented, we could refactor the two into an API more like rpart().
How was this patch tested?
Tested with unit tests in SparkR.
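For illustration only, a minimal sketch of how the proposed SparkR wrapper compares to base R's rpart on the longley dataset used in the PR's tests; the wrapper call uses the spark.decisionTree() signature quoted in the review above, and the output path is hypothetical:

```r
library(SparkR)
sparkR.session()

# Base R: decision tree regression with rpart (method = "anova")
library(rpart)
r_model <- rpart(Employed ~ ., data = longley, method = "anova")

# SparkR wrapper proposed in this PR
df <- suppressWarnings(createDataFrame(longley))
model <- spark.decisionTree(df, Employed ~ ., type = "regression",
                            maxDepth = 5, maxBins = 16)
summary(model)                              # summary of the fitted tree
predictions <- predict(model, df)           # predictions as a SparkDataFrame
write.ml(model, "/tmp/decisionTreeModel")   # hypothetical save path
```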