Skip to content

Commit 7644e97

Browse files
committed
Encode probability to prediction by SQLTransformer.
1 parent e2318ed commit 7644e97

File tree

1 file changed

+5
-34
lines changed

1 file changed

+5
-34
lines changed

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 5 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -24,16 +24,11 @@ import org.json4s.jackson.JsonMethods._
2424

2525
import org.apache.spark.ml.{Pipeline, PipelineModel}
2626
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
27-
import org.apache.spark.ml.feature.{IndexToString, RFormula}
28-
import org.apache.spark.ml.regression._
29-
import org.apache.spark.ml.Transformer
30-
import org.apache.spark.ml.param.ParamMap
31-
import org.apache.spark.ml.param.shared._
27+
import org.apache.spark.ml.feature.{IndexToString, RFormula, SQLTransformer}
3228
import org.apache.spark.ml.r.RWrapperUtils._
29+
import org.apache.spark.ml.regression._
3330
import org.apache.spark.ml.util._
3431
import org.apache.spark.sql._
35-
import org.apache.spark.sql.functions._
36-
import org.apache.spark.sql.types._
3732

3833
private[r] class GeneralizedLinearRegressionWrapper private (
3934
val pipeline: PipelineModel,
@@ -114,9 +109,9 @@ private[r] object GeneralizedLinearRegressionWrapper
114109
.setLabelCol(rFormula.getLabelCol)
115110
val pipeline = if (family == "binomial") {
116111
// Convert prediction from probability to label index.
117-
val probToPred = new ProbabilityToPrediction()
118-
.setInputCol(PREDICTED_LABEL_PROB_COL)
119-
.setOutputCol(PREDICTED_LABEL_INDEX_COL)
112+
val statement =
113+
s"SELECT *, ROUND($PREDICTED_LABEL_PROB_COL) AS $PREDICTED_LABEL_INDEX_COL FROM __THIS__"
114+
val probToPred = new SQLTransformer().setStatement(statement)
120115
// Convert prediction from label index to original label.
121116
val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
122117
.asInstanceOf[NominalAttribute]
@@ -248,27 +243,3 @@ private[r] object GeneralizedLinearRegressionWrapper
248243
}
249244
}
250245
}
251-
252-
/**
253-
* This utility transformer converts the predicted value of GeneralizedLinearRegressionModel
254-
* with "binomial" family from probability to prediction according to threshold 0.5.
255-
*/
256-
private[r] class ProbabilityToPrediction private[r] (override val uid: String)
257-
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
258-
259-
def this() = this(Identifiable.randomUID("probToPred"))
260-
261-
def setInputCol(value: String): this.type = set(inputCol, value)
262-
263-
def setOutputCol(value: String): this.type = set(outputCol, value)
264-
265-
override def transformSchema(schema: StructType): StructType = {
266-
StructType(schema.fields :+ StructField($(outputCol), DoubleType))
267-
}
268-
269-
override def transform(dataset: Dataset[_]): DataFrame = {
270-
dataset.withColumn($(outputCol), round(col($(inputCol))))
271-
}
272-
273-
override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra)
274-
}

0 commit comments

Comments
 (0)