From 7e1a8dbbdd8a4646a2b2c87bab97f535f7e51714 Mon Sep 17 00:00:00 2001
From: Max Gekk
Date: Sat, 25 Apr 2020 13:11:53 +0300
Subject: [PATCH 1/2] Add a test for InSet.sql with a UTF8String collection

---
 .../org/apache/spark/sql/ColumnExpressionSuite.scala | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a9ee25b10dc0..b72d92b9e2a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -26,12 +26,13 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
 
-import org.apache.spark.sql.catalyst.expressions.{In, InSet, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal, NamedExpression}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -869,4 +870,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       df.select(typedLit(("a", 2, 1.0))),
       Row(Row("a", 2, 1.0)) :: Nil)
   }
+
+  test("SPARK-31563: sql of InSet for UTF8String collection") {
+    val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
+    assert(inSet.sql === "('a' IN ('a', 'b'))")
+  }
 }

From d46c177b4cfe191ed73ec319886084f4b482b784 Mon Sep 17 00:00:00 2001
From: Max Gekk
Date: Sat, 25 Apr 2020 13:12:02 +0300
Subject: [PATCH 2/2] Convert InSet values to Scala types when generating SQL

---
 .../apache/spark/sql/catalyst/expressions/predicates.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index bd190c3e5abc..ac492cf22730 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.immutable.TreeSet
 
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
@@ -519,7 +520,9 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
 
   override def sql: String = {
     val valueSQL = child.sql
-    val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
+    val listSQL = hset.toSeq
+      .map(elem => Literal(convertToScala(elem, child.dataType)).sql)
+      .mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
 }