@@ -106,13 +106,12 @@ final class OneVsRestModel private[ml] (
106106
107107 // add temporary column to store intermediate scores and update
108108 val tmpColName = " mbc$tmp" + UUID .randomUUID().toString
109- val update : (Map [Int , Double ], Vector ) => Map [Int , Double ] =
110- (predictions : Map [Int , Double ], prediction : Vector ) => {
111- predictions + ((index, prediction(1 )))
112- }
113- val updateUDF = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
109+ val updateUDF = udf { (predictions : Map [Int , Double ], prediction : Vector ) =>
110+ predictions + ((index, prediction(1 )))
111+ }
114112 val transformedDataset = model.transform(df).select(columns : _* )
115- val updatedDataset = transformedDataset.withColumn(tmpColName, updateUDF)
113+ val updatedDataset = transformedDataset
114+ .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
116115 val newColumns = origCols ++ List (col(tmpColName))
117116
118117 // switch out the intermediate column with the accumulator column
@@ -124,13 +123,13 @@ final class OneVsRestModel private[ml] (
124123 }
125124
126125 // output the index of the classifier with highest confidence as prediction
127- val label : Map [ Int , Double ] => Double = (predictions : Map [Int , Double ]) => {
126+ val labelUDF = udf { (predictions : Map [Int , Double ]) =>
128127 predictions.maxBy(_._2)._1.toDouble
129128 }
130129
131130 // output label and label metadata as prediction
132- val labelUDF = udf(label).apply(col(accColName))
133- aggregatedDataset .withColumn($(predictionCol), labelUDF.as($(predictionCol), labelMetadata))
131+ aggregatedDataset
132+ .withColumn($(predictionCol), labelUDF(col(accColName)) .as($(predictionCol), labelMetadata))
134133 .drop(accColName)
135134 }
136135
@@ -175,34 +174,32 @@ final class OneVsRest(override val uid: String)
175174 }
176175 val numClasses = MetadataUtils .getNumClasses(labelSchema).fold(computeNumClasses())(identity)
177176
178- val multiClassLabeled = dataset.select($(labelCol), $(featuresCol))
177+ val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
179178
180179 // persist if underlying dataset is not persistent.
181180 val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel .NONE
182181 if (handlePersistence) {
183- multiClassLabeled .persist(StorageLevel .MEMORY_AND_DISK )
182+ multiclassLabeled .persist(StorageLevel .MEMORY_AND_DISK )
184183 }
185184
186185 // create k columns, one for each binary classifier.
187186 val models = Range (0 , numClasses).par.map { index =>
188-
189- val label : Double => Double = (label : Double ) => {
187+ val labelUDF = udf { (label : Double ) =>
190188 if (label.toInt == index) 1.0 else 0.0
191189 }
192190
193191 // generate new label metadata for the binary problem.
194192 // TODO: use when ... otherwise after SPARK-7321 is merged
195- val labelUDF = udf(label).apply(col($(labelCol)))
196193 val newLabelMeta = BinaryAttribute .defaultAttr.withName(" label" ).toMetadata()
197194 val labelColName = " mc2b$" + index
198- val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
199- val trainingDataset = multiClassLabeled .withColumn(labelColName, labelUDFWithNewMeta)
195+ val labelUDFWithNewMeta = labelUDF(col($(labelCol))) .as(labelColName, newLabelMeta)
196+ val trainingDataset = multiclassLabeled .withColumn(labelColName, labelUDFWithNewMeta)
200197 val classifier = getClassifier
201198 classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
202199 }.toArray[ClassificationModel [_, _]]
203200
204201 if (handlePersistence) {
205- multiClassLabeled .unpersist()
202+ multiclassLabeled .unpersist()
206203 }
207204
208205 // extract label metadata from label column if present, or create a nominal attribute
0 commit comments