diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index fcfe83ceb863a..66cdfd91cd831 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -22,9 +22,15 @@ package org.apache.spark.sql.catalyst.expressions
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
-class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
-  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean = false)
+  extends Projection {
+
+  def this(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      mutableRow: Boolean = false) = {
+    this(expressions.map(BindReferences.bindReference(_, inputSchema)), mutableRow)
+  }
 
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -36,7 +42,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
       outputArray(i) = exprArray(i).eval(input)
       i += 1
     }
-    new GenericInternalRow(outputArray)
+    if (mutableRow) new GenericMutableRow(outputArray) else new GenericInternalRow(outputArray)
   }
 
   override def toString: String = s"Row => [${exprArray.mkString(",")}]"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 44930f82b53a0..a6d66fff30290 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -241,7 +241,7 @@ case class GeneratedAggregate(
     child.execute().mapPartitions { iter =>
       // Builds a new custom class for holding the results of aggregation for a group.
       val initialValues = computeFunctions.flatMap(_.initialValues)
-      val newAggregationBuffer = newProjection(initialValues, child.output)
+      val newAggregationBuffer = newProjection(initialValues, child.output, mutableRow = true)
       log.info(s"Initial values: ${initialValues.mkString(",")}")
 
       // A projection that computes the group given an input tuple.
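
Why the new flag matters: with codegen disabled, GeneratedAggregate updates its aggregation buffer in place, which requires the row built by newProjection to be a MutableRow. The interpreted path previously always returned an immutable GenericInternalRow, so the downstream cast failed with a ClassCastException (SPARK-8826). A minimal sketch of the distinction, using the Catalyst row classes named in this diff (the sketch is illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow, MutableRow}

    object MutableRowSketch {
      def main(args: Array[String]): Unit = {
        // With mutableRow = true the projection yields a mutable row, so the
        // cast performed by the aggregation code succeeds and aggregate state
        // can be updated in place.
        val buffer: MutableRow = new GenericMutableRow(Array[Any](0L))
        buffer.update(0, 1L)

        // Before the fix the projection always produced an immutable row, and
        // the equivalent cast blew up at runtime with a ClassCastException:
        // new GenericInternalRow(Array[Any](0L)).asInstanceOf[MutableRow]
      }
    }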
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 7739a9f949c77..a885a4df3c07c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -153,13 +153,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   protected def newProjection(
-      expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      mutableRow: Boolean = false): Projection = {
     log.debug(
       s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if (codegenEnabled) {
       GenerateProjection.generate(expressions, inputSchema)
     } else {
-      new InterpretedProjection(expressions, inputSchema)
+      new InterpretedProjection(expressions, inputSchema, mutableRow)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
new file mode 100644
index 0000000000000..ffee6ff469c4a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.DataTypes._
+
+class AggregateSuite extends SparkPlanTest {
+
+  test("SPARK-8826 Fix ClassCastException in GeneratedAggregate") {
+
+    // Without codegen, a ClassCastException was thrown when the group-by
+    // expressions were empty or unsafe mode was disabled.
+    val input = Seq(("Hello", 4, 2.0))
+
+    val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
+    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, false)
+    try {
+      val df = input.toDF("a", "b", "c")
+      val colB = df.col("b").expr
+      val colC = df.col("c").expr
+      val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()
+
+      for (groupExpr <- Seq(Seq.empty, Seq(colB))) {
+        val aggregate = GeneratedAggregate(true, groupExpr, Seq(aggrExpr), false, _: SparkPlan)
+        // Passes as long as execution does not throw an exception.
+        checkAnswer(df, aggregate, (_, _) => None)
+      }
+    } finally {
+      TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 13f3be8ca28d6..301f757ee6411 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -54,10 +54,7 @@ class SparkPlanTest extends SparkFunSuite {
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
       expectedAnswer: Seq[Row]): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
+    checkAnswer(input, planFunction, compareCheck(expectedAnswer))
   }
 
   /**
@@ -71,12 +68,57 @@ class SparkPlanTest extends SparkFunSuite {
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
       expectedAnswer: Seq[A]): Unit = {
-    val expectedRows = expectedAnswer.map(Row.fromTuple)
-    SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match {
+    checkAnswer(input, planFunction, expectedAnswer.map(Row.fromTuple))
+  }
+
+  protected def checkAnswer[A <: Product : TypeTag](
+      input: DataFrame,
+      planFunction: SparkPlan => SparkPlan,
+      f: (SparkPlan, Seq[Row]) => Option[String]): Unit = {
+    SparkPlanTest.checkAnswer(input, planFunction, f) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
   }
+
+  private def compareCheck(expectedAnswer: Seq[Row]): (SparkPlan, Seq[Row]) => Option[String] = {
+    (outputPlan: SparkPlan, sparkAnswer: Seq[Row]) => {
+      if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
+        val errorMessage =
+          s"""
+            | Results do not match for Spark plan:
+            | $outputPlan
+            | == Results ==
+            | ${sideBySide(
+                s"== Correct Answer - ${expectedAnswer.size} ==" +:
+                  prepareAnswer(expectedAnswer).map(_.toString()),
+                s"== Spark Answer - ${sparkAnswer.size} ==" +:
+                  prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+          """.stripMargin
+        Some(errorMessage)
+      } else {
+        None
+      }
+    }
+  }
+
+  protected def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+    // Converts data to types that we can compare for equality using Scala collections.
+    // For BigDecimal, the Scala type has a better definition of equality (similar to
+    // Java's java.math.BigDecimal.compareTo).
+    // Binary arrays are converted to Seq to avoid calling java.util.Arrays.equals for the
+    // equality test.
+    // This function is copied from Catalyst's QueryTest.
+    val converted: Seq[Row] = answer.map { s =>
+      Row.fromSeq(s.toSeq.map {
+        case d: java.math.BigDecimal => BigDecimal(d)
+        case b: Array[Byte] => b.toSeq
+        case o => o
+      })
+    }
+    converted.sortBy(_.toString())
+  }
 }
 
 /**
@@ -89,12 +131,12 @@ object SparkPlanTest {
    * @param input the input data to be used.
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param checker a function that validates the produced rows, returning an error message on failure.
    */
   def checkAnswer(
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedAnswer: Seq[Row]): Option[String] = {
+      checker: (SparkPlan, Seq[Row]) => Option[String]): Option[String] = {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
 
@@ -114,23 +156,6 @@ object SparkPlanTest {
       }
     }
 
-  def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
-    // Converts data to types that we can do equality comparison using Scala collections.
-    // For BigDecimal type, the Scala type has a better definition of equality test (similar to
-    // Java's java.math.BigDecimal.compareTo).
-    // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
-    // equality test.
-    // This function is copied from Catalyst's QueryTest
-    val converted: Seq[Row] = answer.map { s =>
-      Row.fromSeq(s.toSeq.map {
-        case d: java.math.BigDecimal => BigDecimal(d)
-        case b: Array[Byte] => b.toSeq
-        case o => o
-      })
-    }
-    converted.sortBy(_.toString())
-  }
-
     val sparkAnswer: Seq[Row] = try {
       resolvedPlan.executeCollect().toSeq
     } catch {
@@ -146,22 +171,7 @@ object SparkPlanTest {
       return Some(errorMessage)
     }
 
-    if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
-      val errorMessage =
-        s"""
-          | Results do not match for Spark plan:
-          | $outputPlan
-          | == Results ==
-          | ${sideBySide(
-            s"== Correct Answer - ${expectedAnswer.size} ==" +:
-              prepareAnswer(expectedAnswer).map(_.toString()),
-            s"== Spark Answer - ${sparkAnswer.size} ==" +:
-              prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
-        """.stripMargin
-      return Some(errorMessage)
-    }
-
-    None
+    checker.apply(outputPlan, sparkAnswer)
   }
 }
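
The SparkPlanTest refactoring generalizes result checking: rather than always comparing against an expected Seq[Row], a test can now pass any (SparkPlan, Seq[Row]) => Option[String] function, where None signals success and Some(message) fails the test. A hypothetical sketch of two such checkers (the names below are illustrative, not part of the patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.execution.SparkPlan

    object CheckerSketch {
      // Succeeds whenever the plan executed without throwing; this is how
      // AggregateSuite above exercises the SPARK-8826 regression.
      val neverFails: (SparkPlan, Seq[Row]) => Option[String] = (_, _) => None

      // A stricter checker: fail with a message when no rows come back.
      def nonEmpty(plan: SparkPlan, rows: Seq[Row]): Option[String] =
        if (rows.nonEmpty) None else Some(s"Plan returned no rows:\n$plan")
    }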