From 8ee56bbfaacdd64b1712d72650a39939ca3b13f2 Mon Sep 17 00:00:00 2001
From: yucai
Date: Thu, 9 Aug 2018 22:19:43 -0700
Subject: [PATCH 1/3] "distribute by" on multiple columns may lead to codegen
 issue

---
 .../scala/org/apache/spark/sql/catalyst/expressions/hash.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index cec00b66f873c..5dfe9866fd604 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -411,7 +411,7 @@ abstract class HashExpression[E] extends Expression {
     ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, hashResultType -> result),
+      arguments = Seq(hashResultType -> result),
       returnType = hashResultType,
       makeSplitFunction = body =>
         s"""

From 931fa28861f15ef1c31a51787f3bd59f2284de89 Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 10 Aug 2018 19:35:54 +0800
Subject: [PATCH 2/3] fix UT failure and add new test

---
 .../spark/sql/catalyst/expressions/hash.scala | 23 +++++++++++++++++------
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 13 +++++++++++++
 2 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 5dfe9866fd604..a754e87a17968 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
-      nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
+      nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
     }
     val hashResultType = CodeGenerator.javaType(dataType)
-    ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq(hashResultType -> result),
+      arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
       returnType = hashResultType,
       makeSplitFunction = body =>
         s"""
@@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression {
            |return $result;
          """.stripMargin,
       foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |$code
+     """.stripMargin
   }
 
   @tailrec
@@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val childResult = ctx.freshName("childResult")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
       val computeFieldHash = nullSafeElementHash(
-        input, index.toString, field.nullable, field.dataType, childResult, ctx)
+        tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
       s"""
          |$childResult = 0;
          |$computeFieldHash
@@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
        """.stripMargin
     }
 
-    s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+    val code = ctx.splitExpressions(
      expressions = fieldsHash,
      funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
+      arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
      returnType = CodeGenerator.JAVA_INT,
      makeSplitFunction = body =>
        s"""
@@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
           |return $result;
         """.stripMargin,
      foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |${CodeGenerator.JAVA_INT} $childResult = 0;
+       |$code
+     """.stripMargin
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 3a393d766b0bc..8014571b80437 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2831,4 +2831,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") {
+    withView("spark_25084") {
+      val count = 1000
+      val df = spark.range(count)
+      val columns = (0 until 400).map{ i => s"id as id$i" }
+      val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
+      df.selectExpr(columns : _*).createTempView("spark_25084")
+      assert(
+        spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count()
+          === count)
+    }
+  }
 }

From 5da6d4d437c6eb2bdf7a64c031b7c9281a5a8b83 Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 10 Aug 2018 22:09:28 +0800
Subject: [PATCH 3/3] minor

---
 .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8014571b80437..36a999b4a234e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2840,8 +2840,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
       df.selectExpr(columns : _*).createTempView("spark_25084")
       assert(
-        spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count()
-          === count)
+        spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count)
     }
   }
 }