Skip to content

Commit e7a6211

Browse files
committed
[SPARK-33338][SQL] GROUP BY using literal map should not fail
### What changes were proposed in this pull request? This PR aims to fix `semanticEquals` works correctly on `GetMapValue` expressions having literal maps with `ArrayBasedMapData` and `GenericArrayData`. ### Why are the changes needed? This is a regression from Apache Spark 1.6.x. ```scala scala> sc.version res1: String = 1.6.3 scala> sqlContext.sql("SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]").show +---+ |_c0| +---+ | v1| +---+ ``` Apache Spark 2.x ~ 3.0.1 raise`RuntimeException` for the following queries. ```sql CREATE TABLE t USING ORC AS SELECT map('k1', 'v1') m, 'k1' k SELECT map('k1', 'v1')[k] FROM t GROUP BY 1 SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k] SELECT map('k1', 'v1')[k] a FROM t GROUP BY a ``` **BEFORE** ```scala Caused by: java.lang.RuntimeException: Couldn't find k#3 in [keys: [k1], values: [v1][k#3]#6] at scala.sys.package$.error(package.scala:27) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:85) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:79) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) ``` **AFTER** ```sql spark-sql> SELECT map('k1', 'v1')[k] FROM t GROUP BY 1; v1 Time taken: 1.278 seconds, Fetched 1 row(s) spark-sql> SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]; v1 Time taken: 0.313 seconds, Fetched 1 row(s) spark-sql> SELECT map('k1', 'v1')[k] a FROM t GROUP BY a; v1 Time taken: 0.265 seconds, Fetched 1 row(s) ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. Closes apache#30246 from dongjoon-hyun/SPARK-33338. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]> (cherry picked from commit 42c0b17) Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 5dd36f3 commit e7a6211

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
316316
(value, o.value) match {
317317
case (null, null) => true
318318
case (a: Array[Byte], b: Array[Byte]) => util.Arrays.equals(a, b)
319+
case (a: ArrayBasedMapData, b: ArrayBasedMapData) =>
320+
a.keyArray == b.keyArray && a.valueArray == b.valueArray
319321
case (a, b) => a != null && a.equals(b)
320322
}
321323
case _ => false

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.Row
2222
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
2323
import org.apache.spark.sql.catalyst.dsl.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
25+
import org.apache.spark.sql.catalyst.util._
2526
import org.apache.spark.sql.internal.SQLConf
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.unsafe.types.UTF8String
@@ -466,4 +467,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
466467
CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx)
467468
assert(ctx.inlinedMutableStates.isEmpty)
468469
}
470+
471+
test("SPARK-33338: semanticEquals should handle static GetMapValue correctly") {
472+
val keys = new Array[UTF8String](1)
473+
val values = new Array[UTF8String](1)
474+
keys(0) = UTF8String.fromString("key")
475+
values(0) = UTF8String.fromString("value")
476+
477+
val d1 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
478+
val d2 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
479+
val m1 = GetMapValue(Literal.create(d1, MapType(StringType, StringType)), Literal("a"))
480+
val m2 = GetMapValue(Literal.create(d2, MapType(StringType, StringType)), Literal("a"))
481+
482+
assert(m1.semanticEquals(m2))
483+
}
469484
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3523,6 +3523,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
35233523
}
35243524
}
35253525
}
3526+
3527+
test("SPARK-33338: GROUP BY using literal map should not fail") {
3528+
withTempDir { dir =>
3529+
sql(s"CREATE TABLE t USING ORC LOCATION '${dir.toURI}' AS SELECT map('k1', 'v1') m, 'k1' k")
3530+
Seq(
3531+
"SELECT map('k1', 'v1')[k] FROM t GROUP BY 1",
3532+
"SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]",
3533+
"SELECT map('k1', 'v1')[k] a FROM t GROUP BY a").foreach { statement =>
3534+
checkAnswer(sql(statement), Row("v1"))
3535+
}
3536+
}
3537+
}
35263538
}
35273539

35283540
case class Foo(bar: Option[String])

0 commit comments

Comments
 (0)