Skip to content

Commit 197ec82

Browse files
committed
callUDF => udf in OneVsRest
1 parent: 84d6780 · commit: 197ec82

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ final class OneVsRestModel private[ml] (
9090
val accColName = "mbc$acc" + UUID.randomUUID().toString
9191
val init: () => Map[Int, Double] = () => {Map()}
9292
val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
93-
val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
93+
val newDataset = dataset.withColumn(accColName, udf(init).apply())
9494

9595
// persist if underlying dataset is not persistent.
9696
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -110,9 +110,9 @@ final class OneVsRestModel private[ml] (
110110
(predictions: Map[Int, Double], prediction: Vector) => {
111111
predictions + ((index, prediction(1)))
112112
}
113-
val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
113+
val updateUDF = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
114114
val transformedDataset = model.transform(df).select(columns : _*)
115-
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
115+
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUDF)
116116
val newColumns = origCols ++ List(col(tmpColName))
117117

118118
// switch out the intermediate column with the accumulator column
@@ -129,8 +129,8 @@ final class OneVsRestModel private[ml] (
129129
}
130130

131131
// output label and label metadata as prediction
132-
val labelUdf = callUDF(label, DoubleType, col(accColName))
133-
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
132+
val labelUDF = udf(label).apply(col(accColName))
133+
aggregatedDataset.withColumn($(predictionCol), labelUDF.as($(predictionCol), labelMetadata))
134134
.drop(accColName)
135135
}
136136

@@ -175,12 +175,12 @@ final class OneVsRest(override val uid: String)
175175
}
176176
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
177177

178-
val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
178+
val multiClassLabeled = dataset.select($(labelCol), $(featuresCol))
179179

180180
// persist if underlying dataset is not persistent.
181181
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
182182
if (handlePersistence) {
183-
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
183+
multiClassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
184184
}
185185

186186
// create k columns, one for each binary classifier.
@@ -192,17 +192,17 @@ final class OneVsRest(override val uid: String)
192192

193193
// generate new label metadata for the binary problem.
194194
// TODO: use when ... otherwise after SPARK-7321 is merged
195-
val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
195+
val labelUDF = udf(label).apply(col($(labelCol)))
196196
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
197197
val labelColName = "mc2b$" + index
198198
val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
199-
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
199+
val trainingDataset = multiClassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
200200
val classifier = getClassifier
201201
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
202202
}.toArray[ClassificationModel[_, _]]
203203

204204
if (handlePersistence) {
205-
multiclassLabeled.unpersist()
205+
multiClassLabeled.unpersist()
206206
}
207207

208208
// extract label metadata from label column if present, or create a nominal attribute

0 commit comments

Comments
 (0)