@@ -1054,18 +1054,12 @@ def _transform(self, dataset):
10541054 initUDF = udf (lambda _ : [], ArrayType (DoubleType ()))
10551055 newDataset = dataset .withColumn (accColName , initUDF (dataset [origCols [0 ]]))
10561056
1057- newDataset .show ()
1058-
10591057 # persist if underlying dataset is not persistent.
10601058 handlePersistence = \
10611059 dataset .rdd .getStorageLevel () == StorageLevel (False , False , False , False )
10621060 if handlePersistence :
10631061 newDataset .persist (StorageLevel .MEMORY_AND_DISK )
10641062
1065- def updateDict (predictions , i , prediction ):
1066- predictions [i ] = prediction [1 ]
1067- return predictions
1068-
10691063 # update the accumulator column with the result of prediction of models
10701064 aggregatedDataset = newDataset
10711065 for index , model in enumerate (self .models ):
@@ -1075,7 +1069,7 @@ def updateDict(predictions, i, prediction):
10751069 # add temporary column to store intermediate scores and update
10761070 tmpColName = "mbc$tmp" + str (uuid .uuid4 ())
10771071 updateUDF = udf (
1078- lambda predictions , prediction : predictions + [prediction [1 ]],
1072+ lambda predictions , prediction : predictions + [prediction . tolist () [1 ]],
10791073 ArrayType (DoubleType ()))
10801074 transformedDataset = model .transform (aggregatedDataset ).select (* columns )
10811075 updatedDataset = transformedDataset .withColumn (
@@ -1084,19 +1078,19 @@ def updateDict(predictions, i, prediction):
10841078 newColumns = origCols + [tmpColName ]
10851079
10861080 # switch out the intermediate column with the accumulator column
1087- updatedDataset . select ( * newColumns ). withColumnRenamed ( tmpColName , accColName )
1088- aggregatedDataset = updatedDataset
1081+ aggregatedDataset = updatedDataset \
1082+ . select ( * newColumns ). withColumnRenamed ( tmpColName , accColName )
10891083
10901084 if handlePersistence :
10911085 newDataset .unpersist ()
10921086
1093- return aggregatedDataset
10941087 # output the index of the classifier with highest confidence as prediction
1095- # labelUDF = udf(lambda predictions: float(max(predictions, key=predictions.get)))
1088+ labelUDF = udf (
1089+ lambda predictions : float (max (enumerate (predictions ), key = operator .itemgetter (1 ))[0 ]))
10961090
10971091 # output label and label metadata as prediction
1098- # return aggregatedDataset.withColumn(
1099- # self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
1092+ return aggregatedDataset .withColumn (
1093+ self .getPredictionCol (), labelUDF (aggregatedDataset [accColName ])).drop (accColName )
11001094
11011095 # @since("1.4.0")
11021096 # def copy(self, extra=None):
0 commit comments