Skip to content

Commit 179c6fd

Browse files
committed
Fix
1 parent fbe266c commit 179c6fd

File tree

5 files changed

+23
-30
lines changed

5 files changed

+23
-30
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,9 +680,8 @@ object TypeCoercion {
680680
// Skip nodes if unresolved or empty children
681681
case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
682682

683-
case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) =>
684-
typeCastToString(c)
685-
case c @ Concat(children) if conf.concatBinaryAsString =>
683+
case c @ Concat(children) if conf.concatBinaryAsString ||
684+
!children.map(_.dataType).forall(_ == BinaryType) =>
686685
typeCastToString(c)
687686
}
688687
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,16 +646,17 @@ object CombineConcats extends Rule[LogicalPlan] {
646646
stack.pop() match {
647647
case Concat(children) =>
648648
stack.pushAll(children.reverse)
649-
case Cast(Concat(children), StringType, _) =>
650-
stack.pushAll(children.reverse)
649+
// If `spark.sql.function.concatBinaryAsString` is false, a nested `Concat` may produce
650+
// binary output, which `TypeCoercion` then wraps in a cast to string. Match that cast here
651+
// and push it down onto each child so that all nested `Concat`s can still be combined.
652+
case c @ Cast(Concat(children), StringType, _) =>
653+
val newChildren = children.map { e => c.copy(child = e) }
654+
stack.pushAll(newChildren.reverse)
651655
case child =>
652656
flattened += child
653657
}
654658
}
655-
val newChildren = flattened.map { e =>
656-
ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
657-
}
658-
Concat(newChildren)
659+
Concat(flattened)
659660
}
660661

661662
def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.plans.PlanTest
2323
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.rules._
25-
import org.apache.spark.sql.types.StringType
2625

2726

2827
class CombineConcatsSuite extends PlanTest {
@@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest {
3736
comparePlans(actual, correctAnswer)
3837
}
3938

39+
def str(s: String): Literal = Literal(s)
40+
def binary(s: String): Literal = Literal(s.getBytes)
41+
4042
test("combine nested Concat exprs") {
41-
def str(s: String): Literal = Literal(s, StringType)
4243
assertEquivalent(
4344
Concat(
4445
Concat(str("a") :: str("b") :: Nil) ::
@@ -72,4 +73,13 @@ class CombineConcatsSuite extends PlanTest {
7273
Nil),
7374
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
7475
}
76+
77+
test("combine string and binary exprs") {
78+
assertEquivalent(
79+
Concat(
80+
Concat(str("a") :: str("b") :: Nil) ::
81+
Concat(binary("c") :: binary("d") :: Nil) ::
82+
Nil),
83+
Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil))
84+
}
7585
}

sql/core/src/test/resources/sql-tests/inputs/string-functions.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');
2929
set spark.sql.function.concatBinaryAsString=false;
3030

3131
-- Check that Catalyst combines nested `Concat`s when concatBinaryAsString=false
32-
EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col
32+
EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
3333
FROM (
3434
SELECT
3535
string(id) col1,

sql/core/src/test/resources/sql-tests/results/string-functions.sql.out

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ spark.sql.function.concatBinaryAsString false
129129

130130

131131
-- !query 13
132-
EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col
132+
EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
133133
FROM (
134134
SELECT
135135
string(id) col1,
@@ -141,23 +141,6 @@ FROM (
141141
-- !query 13 schema
142142
struct<plan:string>
143143
-- !query 13 output
144-
== Parsed Logical Plan ==
145-
'Project [concat(concat('col1, 'col2), concat('col3, 'col4)) AS col#x]
146-
+- 'SubqueryAlias __auto_generated_subquery_name
147-
+- 'Project ['string('id) AS col1#x, 'string(('id + 1)) AS col2#x, 'encode('string(('id + 2)), utf-8) AS col3#x, 'encode('string(('id + 3)), utf-8) AS col4#x]
148-
+- 'UnresolvedTableValuedFunction range, [10]
149-
150-
== Analyzed Logical Plan ==
151-
col: string
152-
Project [concat(concat(col1#x, col2#x), cast(concat(col3#x, col4#x) as string)) AS col#x]
153-
+- SubqueryAlias __auto_generated_subquery_name
154-
+- Project [cast(id#xL as string) AS col1#x, cast((id#xL + cast(1 as bigint)) as string) AS col2#x, encode(cast((id#xL + cast(2 as bigint)) as string), utf-8) AS col3#x, encode(cast((id#xL + cast(3 as bigint)) as string), utf-8) AS col4#x]
155-
+- Range (0, 10, step=1, splits=None)
156-
157-
== Optimized Logical Plan ==
158-
Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
159-
+- Range (0, 10, step=1, splits=None)
160-
161144
== Physical Plan ==
162145
*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
163146
+- *Range (0, 10, step=1, splits=2)

0 commit comments

Comments
 (0)