diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 98663a8d807b..5ac3eb3daf7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3994,16 +3994,28 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { - case operator if operator.resolved => operator.transformExpressionsUp { + case operator => operator.transformExpressionsUp { + case e if !e.childrenResolved => e + // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => - padAttrLitCmp(attr, lit).map { newChildren => + padAttrLitCmp(attr, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable => - padAttrLitCmp(attr, lit).map { newChildren => + padAttrLitCmp(attr, attr.metadata, lit).map { newChildren => + b.withNewChildren(newChildren.reverse) + }.getOrElse(b) + + case b @ BinaryComparison(or @ OuterReference(attr: Attribute), lit) if lit.foldable => + padAttrLitCmp(or, attr.metadata, lit).map { newChildren => + b.withNewChildren(newChildren) + }.getOrElse(b) + + case b @ BinaryComparison(lit, or @ OuterReference(attr: Attribute)) if lit.foldable => + padAttrLitCmp(or, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) @@ -4027,6 +4039,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { case b @ BinaryComparison(left: Attribute, right: Attribute) => b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))) + case b @ BinaryComparison(OuterReference(left: Attribute), right: Attribute) => + b.withNewChildren(padOuterRefAttrCmp(left, right)) + + case b @ BinaryComparison(left: Attribute, OuterReference(right: Attribute)) => + b.withNewChildren(padOuterRefAttrCmp(right, left).reverse) + case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) => val newChildren = CharVarcharUtils.addPaddingInStringComparison( attr +: list.map(_.asInstanceOf[Attribute])) @@ -4035,9 +4053,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { - if (attr.dataType == StringType) { - CharVarcharUtils.getRawType(attr.metadata).flatMap { + private def padAttrLitCmp( + expr: Expression, + metadata: Metadata, + lit: Expression): Option[Seq[Expression]] = { + if (expr.dataType == StringType) { + CharVarcharUtils.getRawType(metadata).flatMap { case CharType(length) => val str = lit.eval().asInstanceOf[UTF8String] if (str == null) { @@ -4045,9 +4066,9 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } else { val stringLitLen = str.numChars() if (length < stringLitLen) { - Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit)) + Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) } else if (length > stringLitLen) { - Some(Seq(attr, StringRPad(lit, Literal(length)))) + Some(Seq(expr, StringRPad(lit, Literal(length)))) } else { None } @@ -4059,6 +4080,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } + private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = { + val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr)) + val newOuterRef = r.transform { + case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar) + } + Seq(newOuterRef, newAttr) + } + private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 1e561747b615..7d1e4ff04050 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -582,21 +582,6 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { } } - test("SPARK-33992: char/varchar resolution in correlated sub query") { - withTable("t1", "t2") { - sql(s"CREATE TABLE t1(v VARCHAR(3), c CHAR(5)) USING $format") - sql(s"CREATE TABLE t2(v VARCHAR(3), c CHAR(5)) USING $format") - sql("INSERT INTO t1 VALUES ('c', 'b')") - sql("INSERT INTO t2 VALUES ('a', 'b')") - - checkAnswer(sql( - """ - |SELECT v FROM t1 - |WHERE 'a' IN (SELECT v FROM t2 WHERE t1.c = t2.c )""".stripMargin), - Row("c")) - } - } - test("SPARK-34003: fix char/varchar fails w/ both group by and order by ") { withTable("t") { sql(s"CREATE TABLE t(v VARCHAR(3), i INT) USING $format") @@ -631,6 +616,48 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("t"), Row("c ")) } } + + test("SPARK-34833: right-padding applied correctly for correlated subqueries - join keys") { + withTable("t1", "t2") { + sql(s"CREATE TABLE t1(v VARCHAR(3), c CHAR(5)) USING $format") + sql(s"CREATE TABLE t2(v VARCHAR(5), c CHAR(8)) USING $format") + sql("INSERT INTO t1 VALUES ('c', 'b')") + sql("INSERT INTO t2 VALUES ('a', 'b')") + Seq("t1.c = t2.c", "t2.c = t1.c", + "t1.c = 'b'", "'b' = t1.c", "t1.c = 'b '", "'b ' = t1.c", + "t1.c = 'b '", "'b ' = t1.c").foreach { predicate => + checkAnswer(sql( + s""" + |SELECT v FROM t1 + |WHERE 'a' IN (SELECT v FROM t2 WHERE $predicate) + """.stripMargin), + Row("c")) + } + } + } + + test("SPARK-34833: right-padding applied correctly for correlated subqueries - other preds") { + withTable("t") { + sql(s"CREATE TABLE t(c0 INT, c1 CHAR(5), c2 CHAR(7)) USING $format") + sql("INSERT INTO t VALUES (1, 'abc', 'abc')") + Seq("c1 = 'abc'", "'abc' = c1", "c1 = 'abc '", "'abc ' = c1", + "c1 = 'abc '", "'abc ' = c1", "c1 = c2", "c2 = c1", + "c1 IN ('xxx', 'abc', 'xxxxx')", "c1 IN ('xxx', 'abc ', 'xxxxx')", + "c1 IN ('xxx', 'abc ', 'xxxxx')", + "c1 IN (c2)", "c2 IN (c1)").foreach { predicate => + checkAnswer(sql( + s""" + |SELECT c0 FROM t t1 + |WHERE ( + | SELECT count(*) AS c + | FROM t + | WHERE c0 = t1.c0 AND $predicate + |) > 0 + """.stripMargin), + Row(1)) + } + } + } } // Some basic char/varchar tests which doesn't rely on table implementation.