Skip to content

Commit fbe266c

Browse files
committed
Fix optimizer issues
1 parent 766e0e6 commit fbe266c

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet
2121
import scala.collection.mutable.{ArrayBuffer, Stack}
2222

2323
import org.apache.spark.sql.catalyst.analysis._
24+
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
2425
import org.apache.spark.sql.catalyst.expressions._
2526
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
2627
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -645,11 +646,16 @@ object CombineConcats extends Rule[LogicalPlan] {
645646
stack.pop() match {
646647
case Concat(children) =>
647648
stack.pushAll(children.reverse)
649+
case Cast(Concat(children), StringType, _) =>
650+
stack.pushAll(children.reverse)
648651
case child =>
649652
flattened += child
650653
}
651654
}
652-
Concat(flattened)
655+
val newChildren = flattened.map { e =>
656+
ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
657+
}
658+
Concat(newChildren)
653659
}
654660

655661
def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {

sql/core/src/test/resources/sql-tests/inputs/string-functions.sql

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,17 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null);
2424
select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a');
2525
select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null);
2626
select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');
27+
28+
-- turn on concatBinaryAsString
29+
set spark.sql.function.concatBinaryAsString=false;
30+
31+
-- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false
32+
EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col
33+
FROM (
34+
SELECT
35+
string(id) col1,
36+
string(id + 1) col2,
37+
encode(string(id + 2), 'utf-8') col3,
38+
encode(string(id + 3), 'utf-8') col4
39+
FROM range(10)
40+
);

sql/core/src/test/resources/sql-tests/results/string-functions.sql.out

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 12
2+
-- Number of queries: 14
33

44

55
-- !query 0
@@ -118,3 +118,46 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a')
118118
struct<right(NULL, -2):string,right('abcd', -2):string,right('abcd', 0):string,right('abcd', 'a'):string>
119119
-- !query 11 output
120120
NULL NULL
121+
122+
123+
-- !query 12
124+
set spark.sql.function.concatBinaryAsString=false
125+
-- !query 12 schema
126+
struct<key:string,value:string>
127+
-- !query 12 output
128+
spark.sql.function.concatBinaryAsString false
129+
130+
131+
-- !query 13
132+
EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col
133+
FROM (
134+
SELECT
135+
string(id) col1,
136+
string(id + 1) col2,
137+
encode(string(id + 2), 'utf-8') col3,
138+
encode(string(id + 3), 'utf-8') col4
139+
FROM range(10)
140+
)
141+
-- !query 13 schema
142+
struct<plan:string>
143+
-- !query 13 output
144+
== Parsed Logical Plan ==
145+
'Project [concat(concat('col1, 'col2), concat('col3, 'col4)) AS col#x]
146+
+- 'SubqueryAlias __auto_generated_subquery_name
147+
+- 'Project ['string('id) AS col1#x, 'string(('id + 1)) AS col2#x, 'encode('string(('id + 2)), utf-8) AS col3#x, 'encode('string(('id + 3)), utf-8) AS col4#x]
148+
+- 'UnresolvedTableValuedFunction range, [10]
149+
150+
== Analyzed Logical Plan ==
151+
col: string
152+
Project [concat(concat(col1#x, col2#x), cast(concat(col3#x, col4#x) as string)) AS col#x]
153+
+- SubqueryAlias __auto_generated_subquery_name
154+
+- Project [cast(id#xL as string) AS col1#x, cast((id#xL + cast(1 as bigint)) as string) AS col2#x, encode(cast((id#xL + cast(2 as bigint)) as string), utf-8) AS col3#x, encode(cast((id#xL + cast(3 as bigint)) as string), utf-8) AS col4#x]
155+
+- Range (0, 10, step=1, splits=None)
156+
157+
== Optimized Logical Plan ==
158+
Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
159+
+- Range (0, 10, step=1, splits=None)
160+
161+
== Physical Plan ==
162+
*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
163+
+- *Range (0, 10, step=1, splits=2)

0 commit comments

Comments
 (0)