From da849aaec90d21649e6c9de860213300a684a408 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 23 Mar 2021 13:29:29 +0900 Subject: [PATCH 1/5] Fix --- .../sql/catalyst/analysis/Analyzer.scala | 18 ++++++++- .../spark/sql/CharVarcharTestSuite.scala | 39 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 98663a8d807b..97635b098145 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3994,7 +3994,9 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { - case operator if operator.resolved => operator.transformExpressionsUp { + case operator => operator.transformExpressionsUp { + case e if !e.childrenResolved => e + // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => @@ -4027,6 +4029,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { case b @ BinaryComparison(left: Attribute, right: Attribute) => b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))) + case b @ BinaryComparison(OuterReference(left: Attribute), right: Attribute) => + b.withNewChildren(padOuterRefAttrCmp(left, right)) + + case b @ BinaryComparison(left: Attribute, OuterReference(right: Attribute)) => + b.withNewChildren(padOuterRefAttrCmp(right, left).reverse) + case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) => val newChildren = CharVarcharUtils.addPaddingInStringComparison( attr +: list.map(_.asInstanceOf[Attribute])) @@ -4035,6 +4043,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } + private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = { + val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr)) + val newOuterRef = r.transform { + case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar) + } + Seq(newOuterRef, newAttr) + } + private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { if (attr.dataType == StringType) { CharVarcharUtils.getRawType(attr.metadata).flatMap { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 1e561747b615..f0f08fa02b3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -631,6 +631,45 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("t"), Row("c ")) } } + + test("SPARK-34833: right-padding applied correctly for correlated subqueries - join keys") { + withTable("t1", "t2") { + sql(s"CREATE TABLE t1(v VARCHAR(3), c CHAR(5)) USING $format") + sql(s"CREATE TABLE t2(v VARCHAR(5), c CHAR(8)) USING $format") + sql("INSERT INTO t1 VALUES ('c', 'b')") + sql("INSERT INTO t2 VALUES ('a', 'b')") + Seq("t1.c = t2.c", "t2.c = t1.c").foreach { predicate => + checkAnswer(sql( + s""" + |SELECT v FROM t1 + |WHERE 'a' IN (SELECT v FROM t2 WHERE $predicate) + """.stripMargin), + 
Row("c")) + } + } + } + + test("SPARK-34833: right-padding applied correctly for correlated subqueries - other preds") { + withTable("t") { + sql("CREATE TABLE t(c0 INT, c1 CHAR(5), c2 CHAR(7)) USING parquet") + sql("INSERT INTO t VALUES (1, 'abc', 'abc')") + Seq("c1 = 'abc'", "'abc' = c1", "c1 = c2", "c1 IN ('abc', 'defghijk')", "c1 IN (c2)") + .foreach { predicate => + + checkAnswer(sql( + s""" + |SELECT c0 FROM t t1 + |WHERE ( + | SELECT count(*) AS c + | FROM t + | WHERE c0 = t1.c0 AND $predicate + |) > 0 + |LIMIT 3 + """.stripMargin), + Row(1)) + } + } + } } // Some basic char/varchar tests which doesn't rely on table implementation. From d02fbec9fe5e07b9c0fa9a3792704b263cd25778 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 23 Mar 2021 16:28:51 +0900 Subject: [PATCH 2/5] review --- .../test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index f0f08fa02b3d..df96e0269019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -651,7 +651,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { test("SPARK-34833: right-padding applied correctly for correlated subqueries - other preds") { withTable("t") { - sql("CREATE TABLE t(c0 INT, c1 CHAR(5), c2 CHAR(7)) USING parquet") + sql(s"CREATE TABLE t(c0 INT, c1 CHAR(5), c2 CHAR(7)) USING $format") sql("INSERT INTO t VALUES (1, 'abc', 'abc')") Seq("c1 = 'abc'", "'abc' = c1", "c1 = c2", "c1 IN ('abc', 'defghijk')", "c1 IN (c2)") .foreach { predicate => From f869736c9d82344fd718f29d3a82a8c4973324d1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 23 Mar 2021 22:21:27 +0900 Subject: [PATCH 3/5] Review --- .../sql/catalyst/analysis/Analyzer.scala | 37 +++++++++++++++---- .../spark/sql/CharVarcharTestSuite.scala | 17 +-------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 97635b098145..48ecefcc686d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -4009,6 +4009,16 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { b.withNewChildren(newChildren.reverse) }.getOrElse(b) + case b @ BinaryComparison(OuterReference(attr: Attribute), lit) if lit.foldable => + padOuterAttrLitComp(attr, lit).map { newChildren => + b.withNewChildren(newChildren) + }.getOrElse(b) + + case b @ BinaryComparison(lit, OuterReference(attr: Attribute)) if lit.foldable => + padOuterAttrLitComp(attr, lit).map { newChildren => + b.withNewChildren(newChildren.reverse) + }.getOrElse(b) + case i @ In(attr: Attribute, list) if attr.dataType == StringType && list.forall(_.foldable) => CharVarcharUtils.getRawType(attr.metadata).flatMap { @@ -4043,14 +4053,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = { - val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr)) - val newOuterRef = r.transform { - case ar: Attribute if ar.semanticEquals(outerAttr) => 
OuterReference(ar) - } - Seq(newOuterRef, newAttr) - } - private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { if (attr.dataType == StringType) { CharVarcharUtils.getRawType(attr.metadata).flatMap { @@ -4075,6 +4077,25 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } + private def padOuterAttrLitComp( + outerAttr: Attribute, + lit: Expression): Option[Seq[Expression]] = { + padAttrLitCmp(outerAttr, lit).map { case Seq(newAttr, newLit) => + val newOuterRef = newAttr.transform { + case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar) + } + Seq(newOuterRef, newLit) + } + } + + private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = { + val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr)) + val newOuterRef = r.transform { + case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar) + } + Seq(newOuterRef, newAttr) + } + private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index df96e0269019..03ccc8ea74fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -582,21 +582,6 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { } } - test("SPARK-33992: char/varchar resolution in correlated sub query") { - withTable("t1", "t2") { - sql(s"CREATE TABLE t1(v VARCHAR(3), c CHAR(5)) USING $format") - sql(s"CREATE TABLE t2(v VARCHAR(3), c CHAR(5)) USING $format") - sql("INSERT INTO t1 VALUES ('c', 'b')") - sql("INSERT INTO t2 VALUES ('a', 'b')") - - checkAnswer(sql( - """ - |SELECT v FROM t1 - |WHERE 'a' IN (SELECT v FROM t2 WHERE t1.c = t2.c )""".stripMargin), - Row("c")) - } - } - test("SPARK-34003: fix char/varchar fails w/ both group by and order by ") { withTable("t") { sql(s"CREATE TABLE t(v VARCHAR(3), i INT) USING $format") @@ -638,7 +623,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t2(v VARCHAR(5), c CHAR(8)) USING $format") sql("INSERT INTO t1 VALUES ('c', 'b')") sql("INSERT INTO t2 VALUES ('a', 'b')") - Seq("t1.c = t2.c", "t2.c = t1.c").foreach { predicate => + Seq("t1.c = t2.c", "t2.c = t1.c", "t1.c = 'b'", "'b' = t1.c").foreach { predicate => checkAnswer(sql( s""" |SELECT v FROM t1 From acdff360f0df82806ee1ce2bc24c98b6b5f60d5d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 24 Mar 2021 08:17:56 +0900 Subject: [PATCH 4/5] Add more tests --- .../spark/sql/CharVarcharTestSuite.scala | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 03ccc8ea74fa..9786d95bde0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -623,7 +623,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t2(v VARCHAR(5), c CHAR(8)) USING $format") sql("INSERT INTO t1 VALUES ('c', 'b')") sql("INSERT INTO t2 VALUES ('a', 'b')") - Seq("t1.c = t2.c", "t2.c = t1.c", "t1.c = 'b'", "'b' = 
t1.c").foreach { predicate => + Seq("t1.c = t2.c", "t2.c = t1.c", + "t1.c = 'b'", "'b' = t1.c", "t1.c = 'b '", "'b ' = t1.c", + "t1.c = 'b '", "'b ' = t1.c").foreach { predicate => checkAnswer(sql( s""" |SELECT v FROM t1 @@ -638,21 +640,23 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { withTable("t") { sql(s"CREATE TABLE t(c0 INT, c1 CHAR(5), c2 CHAR(7)) USING $format") sql("INSERT INTO t VALUES (1, 'abc', 'abc')") - Seq("c1 = 'abc'", "'abc' = c1", "c1 = c2", "c1 IN ('abc', 'defghijk')", "c1 IN (c2)") - .foreach { predicate => - - checkAnswer(sql( - s""" - |SELECT c0 FROM t t1 - |WHERE ( - | SELECT count(*) AS c - | FROM t - | WHERE c0 = t1.c0 AND $predicate - |) > 0 - |LIMIT 3 - """.stripMargin), - Row(1)) - } + Seq("c1 = 'abc'", "'abc' = c1", "c1 = 'abc '", "'abc ' = c1", + "c1 = 'abc '", "'abc ' = c1", "c1 = c2", "c2 = c1", + "c1 IN ('xxx', 'abc', 'xxxxx')", "c1 IN ('xxx', 'abc ', 'xxxxx')", + "c1 IN ('xxx', 'abc ', 'xxxxx')", + "c1 IN (c2)", "c2 IN (c1)").foreach { predicate => + checkAnswer(sql( + s""" + |SELECT c0 FROM t t1 + |WHERE ( + | SELECT count(*) AS c + | FROM t + | WHERE c0 = t1.c0 AND $predicate + |) > 0 + |LIMIT 3 + """.stripMargin), + Row(1)) + } } } } From 9142bfc7ea082133bd60f7dc1090c64f7835199d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 24 Mar 2021 22:10:53 +0900 Subject: [PATCH 5/5] review --- .../sql/catalyst/analysis/Analyzer.scala | 36 ++++++++----------- .../spark/sql/CharVarcharTestSuite.scala | 1 - 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 48ecefcc686d..5ac3eb3daf7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -4000,22 +4000,22 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. 
case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => - padAttrLitCmp(attr, lit).map { newChildren => + padAttrLitCmp(attr, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable => - padAttrLitCmp(attr, lit).map { newChildren => + padAttrLitCmp(attr, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) - case b @ BinaryComparison(OuterReference(attr: Attribute), lit) if lit.foldable => - padOuterAttrLitComp(attr, lit).map { newChildren => + case b @ BinaryComparison(or @ OuterReference(attr: Attribute), lit) if lit.foldable => + padAttrLitCmp(or, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) - case b @ BinaryComparison(lit, OuterReference(attr: Attribute)) if lit.foldable => - padOuterAttrLitComp(attr, lit).map { newChildren => + case b @ BinaryComparison(lit, or @ OuterReference(attr: Attribute)) if lit.foldable => + padAttrLitCmp(or, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) @@ -4053,9 +4053,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { - if (attr.dataType == StringType) { - CharVarcharUtils.getRawType(attr.metadata).flatMap { + private def padAttrLitCmp( + expr: Expression, + metadata: Metadata, + lit: Expression): Option[Seq[Expression]] = { + if (expr.dataType == StringType) { + CharVarcharUtils.getRawType(metadata).flatMap { case CharType(length) => val str = lit.eval().asInstanceOf[UTF8String] if (str == null) { @@ -4063,9 +4066,9 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } else { val stringLitLen = str.numChars() if (length < stringLitLen) { - Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit)) + Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) } else if (length > stringLitLen) { - Some(Seq(attr, StringRPad(lit, Literal(length)))) + Some(Seq(expr, StringRPad(lit, Literal(length)))) } else { None } @@ -4077,17 +4080,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def padOuterAttrLitComp( - outerAttr: Attribute, - lit: Expression): Option[Seq[Expression]] = { - padAttrLitCmp(outerAttr, lit).map { case Seq(newAttr, newLit) => - val newOuterRef = newAttr.transform { - case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar) - } - Seq(newOuterRef, newLit) - } - } - private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = { val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr)) val newOuterRef = r.transform { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 9786d95bde0a..7d1e4ff04050 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -653,7 +653,6 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { | FROM t | WHERE c0 = t1.c0 AND $predicate |) > 0 - |LIMIT 3 """.stripMargin), Row(1)) }
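
A quick end-to-end illustration of the behavior the new tests exercise (a minimal sketch, not part of the patch series itself: it assumes a local SparkSession, substitutes `parquet` for the test suite's `$format` placeholder, and the object name is made up for the example). With ApplyCharTypePadding now also matching OuterReference(attr) on either side of a binary comparison, a char column referenced from a correlated subquery is right-padded to the longer length before the comparison, and so is a string literal compared against it:

// Sketch only: assumes spark-sql 3.1+ on the classpath and a writable
// local warehouse directory; table/object names are illustrative.
import org.apache.spark.sql.SparkSession

object CharPaddingInCorrelatedSubqueryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    spark.sql("CREATE TABLE t1(v VARCHAR(3), c CHAR(5)) USING parquet")
    spark.sql("CREATE TABLE t2(v VARCHAR(5), c CHAR(8)) USING parquet")
    spark.sql("INSERT INTO t1 VALUES ('c', 'b')")
    spark.sql("INSERT INTO t2 VALUES ('a', 'b')")

    // t1.c stores 'b' blank-padded to CHAR(5), t2.c blank-padded to CHAR(8).
    // The rule pads the outer-referenced t1.c to length 8 before comparing,
    // so the correlated predicate holds, the subquery yields 'a', and the
    // outer query returns a single row 'c'.
    spark.sql(
      """
        |SELECT v FROM t1
        |WHERE 'a' IN (SELECT v FROM t2 WHERE t1.c = t2.c)
      """.stripMargin).show()

    // String literals compared against an outer-referenced char column are
    // padded too: 'b  ' (3 chars) is right-padded to the CHAR(5) length,
    // so this query also returns 'c'.
    spark.sql(
      """
        |SELECT v FROM t1
        |WHERE 'a' IN (SELECT v FROM t2 WHERE t1.c = 'b  ')
      """.stripMargin).show()

    spark.stop()
  }
}

Both queries are expected to print a single row with value 'c', mirroring the checkAnswer expectations added in patches 1 and 4.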