Skip to content

Commit 0ebd0da

Browse files
committed
replace the now deprecated callUDF by udf in VectorIndexer
1 parent 8013409 commit 0ebd0da

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.ml.param.shared._
3030
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
3131
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
3232
import org.apache.spark.sql.{DataFrame, Row}
33-
import org.apache.spark.sql.functions.callUDF
33+
import org.apache.spark.sql.functions.udf
3434
import org.apache.spark.sql.types.{StructField, StructType}
3535
import org.apache.spark.util.collection.OpenHashSet
3636

@@ -339,7 +339,10 @@ class VectorIndexerModel private[ml] (
339339
override def transform(dataset: DataFrame): DataFrame = {
340340
transformSchema(dataset.schema, logging = true)
341341
val newField = prepOutputField(dataset.schema)
342-
val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol)))
342+
val transformUDF = udf { (vector: Any) =>
343+
transformFunc(vector.asInstanceOf[Vector])
344+
}
345+
val newCol = transformUDF(dataset($(inputCol)))
343346
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
344347
}
345348

0 commit comments

Comments
 (0)