Skip to content

Commit b0888d1

Browse files
maropucloud-fan
authored and committed
[SPARK-20730][SQL] Add an optimizer rule to combine nested Concat
## What changes were proposed in this pull request? This pr added a new Optimizer rule to combine nested Concat. The master supports a pipeline operator '||' to concatenate strings in apache#17711 (This pr is follow-up). Since the parser currently generates nested Concat expressions, the optimizer needs to combine the nested expressions. ## How was this patch tested? Added tests in `CombineConcatSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro <[email protected]> Closes apache#17970 from maropu/SPARK-20730.
1 parent 8da6e8b commit b0888d1

File tree

5 files changed

+134
-2
lines changed

5 files changed

+134
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
111111
RemoveRedundantProject,
112112
SimplifyCreateStructOps,
113113
SimplifyCreateArrayOps,
114-
SimplifyCreateMapOps) ++
114+
SimplifyCreateMapOps,
115+
CombineConcats) ++
115116
extendedOperatorOptimizationRules: _*) ::
116117
Batch("Check Cartesian Products", Once,
117118
CheckCartesianProducts(conf)) ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.immutable.HashSet
21+
import scala.collection.mutable.{ArrayBuffer, Stack}
2122

2223
import org.apache.spark.sql.catalyst.analysis._
2324
import org.apache.spark.sql.catalyst.expressions._
@@ -543,3 +544,28 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
543544
}
544545
}
545546
}
547+
548+
/**
549+
* Combine nested [[Concat]] expressions.
550+
*/
551+
object CombineConcats extends Rule[LogicalPlan] {
552+
553+
private def flattenConcats(concat: Concat): Concat = {
554+
val stack = Stack[Expression](concat)
555+
val flattened = ArrayBuffer.empty[Expression]
556+
while (stack.nonEmpty) {
557+
stack.pop() match {
558+
case Concat(children) =>
559+
stack.pushAll(children.reverse)
560+
case child =>
561+
flattened += child
562+
}
563+
}
564+
Concat(flattened)
565+
}
566+
567+
def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
568+
case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
569+
flattenConcats(concat)
570+
}
571+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.dsl.plans._
21+
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.plans.PlanTest
23+
import org.apache.spark.sql.catalyst.plans.logical._
24+
import org.apache.spark.sql.catalyst.rules._
25+
import org.apache.spark.sql.types.StringType
26+
27+
28+
class CombineConcatsSuite extends PlanTest {
29+
30+
object Optimize extends RuleExecutor[LogicalPlan] {
31+
val batches = Batch("CombineConcatsSuite", FixedPoint(50), CombineConcats) :: Nil
32+
}
33+
34+
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
35+
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
36+
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
37+
comparePlans(actual, correctAnswer)
38+
}
39+
40+
test("combine nested Concat exprs") {
41+
def str(s: String): Literal = Literal(s, StringType)
42+
assertEquivalent(
43+
Concat(
44+
Concat(str("a") :: str("b") :: Nil) ::
45+
str("c") ::
46+
str("d") ::
47+
Nil),
48+
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
49+
assertEquivalent(
50+
Concat(
51+
str("a") ::
52+
Concat(str("b") :: str("c") :: Nil) ::
53+
str("d") ::
54+
Nil),
55+
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
56+
assertEquivalent(
57+
Concat(
58+
str("a") ::
59+
str("b") ::
60+
Concat(str("c") :: str("d") :: Nil) ::
61+
Nil),
62+
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
63+
assertEquivalent(
64+
Concat(
65+
Concat(
66+
str("a") ::
67+
Concat(
68+
str("b") ::
69+
Concat(str("c") :: str("d") :: Nil) ::
70+
Nil) ::
71+
Nil) ::
72+
Nil),
73+
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
74+
}
75+
}

sql/core/src/test/resources/sql-tests/inputs/string-functions.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ select format_string();
44

55
-- A pipe operator for string concatenation
66
select 'a' || 'b' || 'c';
7+
8+
-- Check if catalyst combines nested `Concat`s
9+
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
10+
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10));

sql/core/src/test/resources/sql-tests/results/string-functions.sql.out

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 3
2+
-- Number of queries: 4
33

44

55
-- !query 0
@@ -26,3 +26,29 @@ select 'a' || 'b' || 'c'
2626
struct<concat(concat(a, b), c):string>
2727
-- !query 2 output
2828
abc
29+
30+
31+
-- !query 3
32+
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
33+
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10))
34+
-- !query 3 schema
35+
struct<plan:string>
36+
-- !query 3 output
37+
== Parsed Logical Plan ==
38+
'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x]
39+
+- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x]
40+
+- 'UnresolvedTableValuedFunction range, [10]
41+
42+
== Analyzed Logical Plan ==
43+
col: string
44+
Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x]
45+
+- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL]
46+
+- Range (0, 10, step=1, splits=None)
47+
48+
== Optimized Logical Plan ==
49+
Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
50+
+- Range (0, 10, step=1, splits=None)
51+
52+
== Physical Plan ==
53+
*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
54+
+- *Range (0, 10, step=1, splits=2)

0 commit comments

Comments
 (0)