
Commit 417d13f

fix error caused by treating nparray as list

1 parent a296a86
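
Per the commit message, the failure came from treating a NumPy array as a plain Python list. A plausible minimal illustration of the mismatch, outside Spark (variable names here are made up, not from the patch): indexing a NumPy-backed score vector yields a numpy.float64 rather than a built-in float, which a udf declared as ArrayType(DoubleType()) may not serialize, whereas tolist() converts the elements to native floats first.

    import numpy as np

    probs = np.array([0.3, 0.7])
    print(type(probs[1]))           # <class 'numpy.float64'>, a NumPy scalar
    print(type(probs.tolist()[1]))  # <class 'float'>, safe for DoubleType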

File tree

1 file changed: +7 -13 lines changed


python/pyspark/ml/classification.py

Lines changed: 7 additions & 13 deletions
@@ -1054,18 +1054,12 @@ def _transform(self, dataset):
         initUDF = udf(lambda _: [], ArrayType(DoubleType()))
         newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))
 
-        newDataset.show()
-
         # persist if underlying dataset is not persistent.
         handlePersistence =\
             dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
         if handlePersistence:
             newDataset.persist(StorageLevel.MEMORY_AND_DISK)
 
-        def updateDict(predictions, i, prediction):
-            predictions[i] = prediction[1]
-            return predictions
-
         # update the accumulator column with the result of prediction of models
         aggregatedDataset = newDataset
         for index, model in enumerate(self.models):
@@ -1075,7 +1069,7 @@ def updateDict(predictions, i, prediction):
             # add temporary column to store intermediate scores and update
             tmpColName = "mbc$tmp" + str(uuid.uuid4())
             updateUDF = udf(
-                lambda predictions, prediction: predictions + [prediction[1]],
+                lambda predictions, prediction: predictions + [prediction.tolist()[1]],
                 ArrayType(DoubleType()))
             transformedDataset = model.transform(aggregatedDataset).select(*columns)
             updatedDataset = transformedDataset.withColumn(
@@ -1084,19 +1078,19 @@ def updateDict(predictions, i, prediction):
             newColumns = origCols + [tmpColName]
 
             # switch out the intermediate column with the accumulator column
-            updatedDataset.select(*newColumns).withColumnRenamed(tmpColName, accColName)
-            aggregatedDataset = updatedDataset
+            aggregatedDataset = updatedDataset\
+                .select(*newColumns).withColumnRenamed(tmpColName, accColName)
 
         if handlePersistence:
             newDataset.unpersist()
 
-        return aggregatedDataset
         # output the index of the classifier with highest confidence as prediction
-        # labelUDF = udf(lambda predictions: float(max(predictions, key=predictions.get)))
+        labelUDF = udf(
+            lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]))
 
         # output label and label metadata as prediction
-        # return aggregatedDataset.withColumn(
-        #     self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
+        return aggregatedDataset.withColumn(
+            self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
 
         # @since("1.4.0")
         # def copy(self, extra=None):
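
For reference, the label UDF re-enabled at the bottom of this diff is an argmax-by-index: enumerate pairs each accumulated confidence with its classifier index, max(..., key=operator.itemgetter(1)) picks the pair with the highest confidence, and [0] keeps the index as the predicted label. A standalone sketch with made-up scores:

    import operator

    predictions = [0.12, 0.83, 0.05]  # one confidence per binary classifier
    label = float(max(enumerate(predictions), key=operator.itemgetter(1))[0])
    print(label)                      # 1.0: classifier 1 scored highest

The previously commented-out variant, max(predictions, key=predictions.get), is the dict idiom; it would raise AttributeError on the list that the accumulator column holds. The new lambda also assumes classification.py imports operator at module level (the import is outside this hunk).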
