-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-18291][SparkR][ML] SparkR glm predict should output original label when family = binomial. #15788
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-18291][SparkR][ML] SparkR glm predict should output original label when family = binomial. #15788
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,11 +23,16 @@ import org.json4s.JsonDSL._ | |
| import org.json4s.jackson.JsonMethods._ | ||
|
|
||
| import org.apache.spark.ml.{Pipeline, PipelineModel} | ||
| import org.apache.spark.ml.attribute.AttributeGroup | ||
| import org.apache.spark.ml.feature.RFormula | ||
| import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} | ||
| import org.apache.spark.ml.feature.{IndexToString, RFormula} | ||
| import org.apache.spark.ml.regression._ | ||
| import org.apache.spark.ml.Transformer | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql._ | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| private[r] class GeneralizedLinearRegressionWrapper private ( | ||
| val pipeline: PipelineModel, | ||
|
|
@@ -42,6 +47,8 @@ private[r] class GeneralizedLinearRegressionWrapper private ( | |
| val rNumIterations: Int, | ||
| val isLoaded: Boolean = false) extends MLWritable { | ||
|
|
||
| import GeneralizedLinearRegressionWrapper._ | ||
|
|
||
| private val glm: GeneralizedLinearRegressionModel = | ||
| pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] | ||
|
|
||
|
|
@@ -52,7 +59,15 @@ private[r] class GeneralizedLinearRegressionWrapper private ( | |
| def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType) | ||
|
|
||
| def transform(dataset: Dataset[_]): DataFrame = { | ||
| pipeline.transform(dataset).drop(glm.getFeaturesCol) | ||
| if (rFamily == "binomial") { | ||
| pipeline.transform(dataset) | ||
| .drop(PREDICTED_LABEL_PROB_COL) | ||
| .drop(PREDICTED_LABEL_INDEX_COL) | ||
| .drop(glm.getFeaturesCol) | ||
| } else { | ||
| pipeline.transform(dataset) | ||
| .drop(glm.getFeaturesCol) | ||
| } | ||
| } | ||
|
|
||
| override def write: MLWriter = | ||
|
|
@@ -62,6 +77,10 @@ private[r] class GeneralizedLinearRegressionWrapper private ( | |
| private[r] object GeneralizedLinearRegressionWrapper | ||
| extends MLReadable[GeneralizedLinearRegressionWrapper] { | ||
|
|
||
| val PREDICTED_LABEL_PROB_COL = "pred_label_prob" | ||
| val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" | ||
| val PREDICTED_LABEL_COL = "prediction" | ||
|
|
||
| def fit( | ||
| formula: String, | ||
| data: DataFrame, | ||
|
|
@@ -71,8 +90,8 @@ private[r] object GeneralizedLinearRegressionWrapper | |
| maxIter: Int, | ||
| weightCol: String, | ||
| regParam: Double): GeneralizedLinearRegressionWrapper = { | ||
| val rFormula = new RFormula() | ||
| .setFormula(formula) | ||
| val rFormula = new RFormula().setFormula(formula) | ||
| if (family == "binomial") rFormula.setForceIndexLabel(true) | ||
| RWrapperUtils.checkDataColumns(rFormula, data) | ||
| val rFormulaModel = rFormula.fit(data) | ||
| // get labels and feature names from output schema | ||
|
|
@@ -90,9 +109,27 @@ private[r] object GeneralizedLinearRegressionWrapper | |
| .setWeightCol(weightCol) | ||
| .setRegParam(regParam) | ||
| .setFeaturesCol(rFormula.getFeaturesCol) | ||
| val pipeline = new Pipeline() | ||
| .setStages(Array(rFormulaModel, glr)) | ||
| .fit(data) | ||
| val pipeline = if (family == "binomial") { | ||
| // Convert prediction from probability to label index. | ||
| val probToPred = new ProbabilityToPrediction() | ||
| .setInputCol(PREDICTED_LABEL_PROB_COL) | ||
| .setOutputCol(PREDICTED_LABEL_INDEX_COL) | ||
| // Convert prediction from label index to original label. | ||
| val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) | ||
| .asInstanceOf[NominalAttribute] | ||
| val labels = labelAttr.values.get | ||
| val idxToStr = new IndexToString() | ||
| .setInputCol(PREDICTED_LABEL_INDEX_COL) | ||
| .setOutputCol(PREDICTED_LABEL_COL) | ||
| .setLabels(labels) | ||
|
|
||
| new Pipeline() | ||
| .setStages(Array(rFormulaModel, glr.setPredictionCol(PREDICTED_LABEL_PROB_COL), | ||
| probToPred, idxToStr)) | ||
| .fit(data) | ||
| } else { | ||
| new Pipeline().setStages(Array(rFormulaModel, glr)).fit(data) | ||
| } | ||
|
|
||
| val glm: GeneralizedLinearRegressionModel = | ||
| pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] | ||
|
|
@@ -200,3 +237,27 @@ private[r] object GeneralizedLinearRegressionWrapper | |
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps this could be reusable and should go to RWrapperUtils.scala?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an inherent feature of other classification algorithms who extends |
||
| * with "binomial" family from probability to prediction according to threshold 0.5. | ||
| */ | ||
| private[r] class ProbabilityToPrediction private[r] (override val uid: String) | ||
| extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { | ||
|
|
||
| def this() = this(Identifiable.randomUID("probToPred")) | ||
|
|
||
| def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
|
||
| def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
|
||
| override def transformSchema(schema: StructType): StructType = { | ||
| StructType(schema.fields :+ StructField($(outputCol), DoubleType)) | ||
| } | ||
|
|
||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| dataset.withColumn($(outputCol), round(col($(inputCol)))) | ||
| } | ||
|
|
||
| override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to make R models incompatible with other languages when you persist them. (They already are because of having hidden Pipelines, but that is fixable.) This, however, will encode special behavior which is only respected when the model is loaded from R, not from other languages.
One option is to encode this in a SQLTransformer.
I'm also worried that these hard-coded columns names will lead to future bug reports about conflicting input column names.
It looks like this same issue appears in other PRs for R, such as [https://issues.apache.org/jira/browse/SPARK-18401]. Given the pervasiveness and that we're in QA right now, I'd recommend we not worry about it for 2.1 and delay fixing it until 2.2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I totally agree the hard-coded column names issues should be fixed, and already have some ideas in my mind to improve SparkR ML wrappers(which include this). This can be placed in the plan of next release version and I will write simple design documents for reviewing.
For the
ProbabilityToPredictionissue, the idea ofSQLTransformersounds good and I will try to fix it in follow-up PR. Thanks.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds great--thanks!