diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4daa7ef02bc3..66546f85fcc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -4008,6 +4008,14 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { */ object ApplyCharTypePadding extends Rule[LogicalPlan] { + object AttrOrOuterRef { + def unapply(e: Expression): Option[Attribute] = e match { + case a: Attribute => Some(a) + case OuterReference(a: Attribute) => Some(a) + case _ => None + } + } + override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { case operator => operator.transformExpressionsUp { @@ -4015,27 +4023,17 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. - case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => - padAttrLitCmp(attr, attr.metadata, lit).map { newChildren => - b.withNewChildren(newChildren) - }.getOrElse(b) - - case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable => - padAttrLitCmp(attr, attr.metadata, lit).map { newChildren => - b.withNewChildren(newChildren.reverse) - }.getOrElse(b) - - case b @ BinaryComparison(or @ OuterReference(attr: Attribute), lit) if lit.foldable => - padAttrLitCmp(or, attr.metadata, lit).map { newChildren => + case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable => + padAttrLitCmp(e, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) - case b @ BinaryComparison(lit, or @ OuterReference(attr: Attribute)) if lit.foldable => - padAttrLitCmp(or, attr.metadata, lit).map { newChildren => + case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable => + padAttrLitCmp(e, attr.metadata, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) - case i @ In(attr: Attribute, list) + case i @ In(e @ AttrOrOuterRef(attr), list) if attr.dataType == StringType && list.forall(_.foldable) => CharVarcharUtils.getRawType(attr.metadata).flatMap { case CharType(length) => @@ -4044,7 +4042,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { val literalCharLengths = literalChars.map(_.numChars()) val targetLen = (length +: literalCharLengths).max Some(i.copy( - value = addPadding(attr, length, targetLen), + value = addPadding(e, length, targetLen), list = list.zip(literalCharLengths).map { case (lit, charLength) => addPadding(lit, charLength, targetLen) } ++ nulls.map(Literal.create(_, StringType)))) @@ -4052,19 +4050,36 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { }.getOrElse(i) // For char type column or inner field comparison, pad the shorter one to the longer length. - case b @ BinaryComparison(left: Attribute, right: Attribute) => - b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))) - - case b @ BinaryComparison(OuterReference(left: Attribute), right: Attribute) => - b.withNewChildren(padOuterRefAttrCmp(left, right)) - - case b @ BinaryComparison(left: Attribute, OuterReference(right: Attribute)) => - b.withNewChildren(padOuterRefAttrCmp(right, left).reverse) + case b @ BinaryComparison(e1 @ AttrOrOuterRef(left), e2 @ AttrOrOuterRef(right)) + // For the same attribute, they must be the same length and no padding is needed. + if !left.semanticEquals(right) => + val outerRefs = (e1, e2) match { + case (_: OuterReference, _: OuterReference) => Seq(left, right) + case (_: OuterReference, _) => Seq(left) + case (_, _: OuterReference) => Seq(right) + case _ => Nil + } + val newChildren = CharVarcharUtils.addPaddingInStringComparison(Seq(left, right)) + if (outerRefs.nonEmpty) { + b.withNewChildren(newChildren.map(_.transform { + case a: Attribute if outerRefs.exists(_.semanticEquals(a)) => OuterReference(a) + })) + } else { + b.withNewChildren(newChildren) + } - case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) => + case i @ In(e @ AttrOrOuterRef(attr), list) if list.forall(_.isInstanceOf[Attribute]) => val newChildren = CharVarcharUtils.addPaddingInStringComparison( attr +: list.map(_.asInstanceOf[Attribute])) - i.copy(value = newChildren.head, list = newChildren.tail) + if (e.isInstanceOf[OuterReference]) { + i.copy( + value = newChildren.head.transform { + case a: Attribute if a.semanticEquals(attr) => OuterReference(a) + }, + list = newChildren.tail) + } else { + i.copy(value = newChildren.head, list = newChildren.tail) + } } } }