Skip to content

Commit 1305492

Browse files
committed
uniformized udf calls in Classifier
1 parent a672228 commit 1305492

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
102102
var outputData = dataset
103103
var numColsOutput = 0
104104
if (getRawPredictionCol != "") {
105-
outputData = outputData.withColumn(getRawPredictionCol,
106-
callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
105+
val predictRawUDF = udf { (features: Any) =>
106+
predictRaw(features.asInstanceOf[FeaturesType])
107+
}
108+
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
107109
numColsOutput += 1
108110
}
109111
if (getPredictionCol != "") {
110112
val predUDF = if (getRawPredictionCol != "") {
111-
udf[Double, Vector](raw2prediction).apply(col(getRawPredictionCol))
113+
udf(raw2prediction _).apply(col(getRawPredictionCol))
112114
} else {
113-
callUDF(predict _, DoubleType, col(getFeaturesCol))
115+
val predictUDF = udf { (features: Any) =>
116+
predict(features.asInstanceOf[FeaturesType])
117+
}
118+
predictUDF(col(getFeaturesCol))
114119
}
115120
outputData = outputData.withColumn(getPredictionCol, predUDF)
116121
numColsOutput += 1

0 commit comments

Comments
 (0)