diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 2c30a1d9aa94..ad49d2595adc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -231,8 +231,12 @@ class PowerIterationClustering private[clustering] (
       dataset.schema($(idCol)).dataType match {
         case _: LongType =>
           uncastPredictions
+        case _: IntegerType =>
+          uncastPredictions.withColumn($(idCol), col($(idCol)).cast(LongType))
         case otherType =>
-          uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol)))
+          throw new IllegalArgumentException(s"PowerIterationClustering had an unexpected error: " +
+            s"ID col was found to be of type ${otherType.simpleString}, despite initial schema " +
+            s"checks. Please report this bug.")
       }
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index 65328df17baf..db0f470dca85 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -103,24 +103,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setK(2)
       .setMaxIter(1)
 
-    def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = {
+    def runTest(idType: DataType, similarityType: DataType): Unit = {
       val typedData = data.select(
         col("id").cast(idType).alias("id"),
-        col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"),
+        col("neighbors").cast(ArrayType(idType, containsNull = false)).alias("neighbors"),
         col("similarities").cast(ArrayType(similarityType, containsNull = false))
           .alias("similarities")
       )
-      model.transform(typedData).collect()
+      model.transform(typedData).select("id", "prediction").collect()
     }
 
     for (idType <- Seq(IntegerType, LongType)) {
-      runTest(idType, LongType, DoubleType)
-    }
-    for (neighborType <- Seq(IntegerType, LongType)) {
-      runTest(LongType, neighborType, DoubleType)
+      runTest(idType, DoubleType)
     }
     for (similarityType <- Seq(FloatType, DoubleType)) {
-      runTest(LongType, LongType, similarityType)
+      runTest(LongType, similarityType)
     }
   }
 