Skip to content

Commit a672228

Browse files
committed
uniformized udf calls in OneVsRest
1 parent 49e4904 commit a672228

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,12 @@ final class OneVsRestModel private[ml] (
106106

107107
// add temporary column to store intermediate scores and update
108108
val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
109-
val update: (Map[Int, Double], Vector) => Map[Int, Double] =
110-
(predictions: Map[Int, Double], prediction: Vector) => {
111-
predictions + ((index, prediction(1)))
112-
}
113-
val updateUDF = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
109+
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
110+
predictions + ((index, prediction(1)))
111+
}
114112
val transformedDataset = model.transform(df).select(columns : _*)
115-
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUDF)
113+
val updatedDataset = transformedDataset
114+
.withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
116115
val newColumns = origCols ++ List(col(tmpColName))
117116

118117
// switch out the intermediate column with the accumulator column
@@ -124,13 +123,13 @@ final class OneVsRestModel private[ml] (
124123
}
125124

126125
// output the index of the classifier with highest confidence as prediction
127-
val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
126+
val labelUDF = udf { (predictions: Map[Int, Double]) =>
128127
predictions.maxBy(_._2)._1.toDouble
129128
}
130129

131130
// output label and label metadata as prediction
132-
val labelUDF = udf(label).apply(col(accColName))
133-
aggregatedDataset.withColumn($(predictionCol), labelUDF.as($(predictionCol), labelMetadata))
131+
aggregatedDataset
132+
.withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
134133
.drop(accColName)
135134
}
136135

@@ -175,34 +174,32 @@ final class OneVsRest(override val uid: String)
175174
}
176175
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
177176

178-
val multiClassLabeled = dataset.select($(labelCol), $(featuresCol))
177+
val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
179178

180179
// persist if underlying dataset is not persistent.
181180
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
182181
if (handlePersistence) {
183-
multiClassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
182+
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
184183
}
185184

186185
// create k columns, one for each binary classifier.
187186
val models = Range(0, numClasses).par.map { index =>
188-
189-
val label: Double => Double = (label: Double) => {
187+
val labelUDF = udf { (label: Double) =>
190188
if (label.toInt == index) 1.0 else 0.0
191189
}
192190

193191
// generate new label metadata for the binary problem.
194192
// TODO: use when ... otherwise after SPARK-7321 is merged
195-
val labelUDF = udf(label).apply(col($(labelCol)))
196193
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
197194
val labelColName = "mc2b$" + index
198-
val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
199-
val trainingDataset = multiClassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
195+
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
196+
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
200197
val classifier = getClassifier
201198
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
202199
}.toArray[ClassificationModel[_, _]]
203200

204201
if (handlePersistence) {
205-
multiClassLabeled.unpersist()
202+
multiclassLabeled.unpersist()
206203
}
207204

208205
// extract label metadata from label column if present, or create a nominal attribute

0 commit comments

Comments
 (0)