@@ -24,16 +24,11 @@ import org.json4s.jackson.JsonMethods._
2424
2525import org .apache .spark .ml .{Pipeline , PipelineModel }
2626import org .apache .spark .ml .attribute .{Attribute , AttributeGroup , NominalAttribute }
27- import org .apache .spark .ml .feature .{IndexToString , RFormula }
28- import org .apache .spark .ml .regression ._
29- import org .apache .spark .ml .Transformer
30- import org .apache .spark .ml .param .ParamMap
31- import org .apache .spark .ml .param .shared ._
27+ import org .apache .spark .ml .feature .{IndexToString , RFormula , SQLTransformer }
3228import org .apache .spark .ml .r .RWrapperUtils ._
29+ import org .apache .spark .ml .regression ._
3330import org .apache .spark .ml .util ._
3431import org .apache .spark .sql ._
35- import org .apache .spark .sql .functions ._
36- import org .apache .spark .sql .types ._
3732
3833private [r] class GeneralizedLinearRegressionWrapper private (
3934 val pipeline : PipelineModel ,
@@ -114,9 +109,9 @@ private[r] object GeneralizedLinearRegressionWrapper
114109 .setLabelCol(rFormula.getLabelCol)
115110 val pipeline = if (family == " binomial" ) {
116111 // Convert prediction from probability to label index.
117- val probToPred = new ProbabilityToPrediction ()
118- .setInputCol( PREDICTED_LABEL_PROB_COL )
119- .setOutputCol( PREDICTED_LABEL_INDEX_COL )
112+ val statement =
113+ s " SELECT *, ROUND( $ PREDICTED_LABEL_PROB_COL) AS $PREDICTED_LABEL_INDEX_COL FROM __THIS__ "
114+ val probToPred = new SQLTransformer ().setStatement(statement )
120115 // Convert prediction from label index to original label.
121116 val labelAttr = Attribute .fromStructField(schema(rFormulaModel.getLabelCol))
122117 .asInstanceOf [NominalAttribute ]
@@ -248,27 +243,3 @@ private[r] object GeneralizedLinearRegressionWrapper
248243 }
249244 }
250245}
251-
252- /**
253- * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel
254- * with "binomial" family from probability to prediction according to threshold 0.5.
255- */
256- private [r] class ProbabilityToPrediction private [r] (override val uid : String )
257- extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
258-
259- def this () = this (Identifiable .randomUID(" probToPred" ))
260-
261- def setInputCol (value : String ): this .type = set(inputCol, value)
262-
263- def setOutputCol (value : String ): this .type = set(outputCol, value)
264-
265- override def transformSchema (schema : StructType ): StructType = {
266- StructType (schema.fields :+ StructField ($(outputCol), DoubleType ))
267- }
268-
269- override def transform (dataset : Dataset [_]): DataFrame = {
270- dataset.withColumn($(outputCol), round(col($(inputCol))))
271- }
272-
273- override def copy (extra : ParamMap ): ProbabilityToPrediction = defaultCopy(extra)
274- }
0 commit comments