@@ -90,7 +90,7 @@ final class OneVsRestModel private[ml] (
9090 val accColName = " mbc$acc" + UUID .randomUUID().toString
9191 val init : () => Map [Int , Double ] = () => {Map ()}
9292 val mapType = MapType (IntegerType , DoubleType , valueContainsNull = false )
93- val newDataset = dataset.withColumn(accColName, callUDF (init, mapType ))
93+ val newDataset = dataset.withColumn(accColName, udf (init).apply( ))
9494
9595 // persist if underlying dataset is not persistent.
9696 val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel .NONE
@@ -110,9 +110,9 @@ final class OneVsRestModel private[ml] (
110110 (predictions : Map [Int , Double ], prediction : Vector ) => {
111111 predictions + ((index, prediction(1 )))
112112 }
113- val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
113+ val updateUDF = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
114114 val transformedDataset = model.transform(df).select(columns : _* )
115- val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf )
115+ val updatedDataset = transformedDataset.withColumn(tmpColName, updateUDF )
116116 val newColumns = origCols ++ List (col(tmpColName))
117117
118118 // switch out the intermediate column with the accumulator column
@@ -129,8 +129,8 @@ final class OneVsRestModel private[ml] (
129129 }
130130
131131 // output label and label metadata as prediction
132- val labelUdf = callUDF (label, DoubleType , col(accColName))
133- aggregatedDataset.withColumn($(predictionCol), labelUdf .as($(predictionCol), labelMetadata))
132+ val labelUDF = udf (label).apply( col(accColName))
133+ aggregatedDataset.withColumn($(predictionCol), labelUDF .as($(predictionCol), labelMetadata))
134134 .drop(accColName)
135135 }
136136
@@ -175,12 +175,12 @@ final class OneVsRest(override val uid: String)
175175 }
176176 val numClasses = MetadataUtils .getNumClasses(labelSchema).fold(computeNumClasses())(identity)
177177
178- val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
178+ val multiClassLabeled = dataset.select($(labelCol), $(featuresCol))
179179
180180 // persist if underlying dataset is not persistent.
181181 val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel .NONE
182182 if (handlePersistence) {
183- multiclassLabeled .persist(StorageLevel .MEMORY_AND_DISK )
183+ multiClassLabeled .persist(StorageLevel .MEMORY_AND_DISK )
184184 }
185185
186186 // create k columns, one for each binary classifier.
@@ -192,17 +192,17 @@ final class OneVsRest(override val uid: String)
192192
193193 // generate new label metadata for the binary problem.
194194 // TODO: use when ... otherwise after SPARK-7321 is merged
195- val labelUDF = callUDF (label, DoubleType , col($(labelCol)))
195+ val labelUDF = udf (label).apply( col($(labelCol)))
196196 val newLabelMeta = BinaryAttribute .defaultAttr.withName(" label" ).toMetadata()
197197 val labelColName = " mc2b$" + index
198198 val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
199- val trainingDataset = multiclassLabeled .withColumn(labelColName, labelUDFWithNewMeta)
199+ val trainingDataset = multiClassLabeled .withColumn(labelColName, labelUDFWithNewMeta)
200200 val classifier = getClassifier
201201 classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
202202 }.toArray[ClassificationModel [_, _]]
203203
204204 if (handlePersistence) {
205- multiclassLabeled .unpersist()
205+ multiClassLabeled .unpersist()
206206 }
207207
208208 // extract label metadata from label column if present, or create a nominal attribute
0 commit comments