From 1789a74c2f270136ed5ec66713d9244299873874 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 2 May 2016 15:38:31 -0700 Subject: [PATCH 1/5] [SPARK-15076][SQL] Improve ConstantFolding optimizer by using integral associative property. --- .../sql/catalyst/optimizer/Optimizer.scala | 36 +++++++++++++++++++ .../optimizer/ConstantFoldingSuite.scala | 28 ++++++++++++--- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 93762ad1b91c2..7e975f3f8e702 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -742,6 +742,23 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe * equivalent [[Literal]] values. */ object ConstantFolding extends Rule[LogicalPlan] { + private def isAssociativelyFoldable(e: Expression): Boolean = + e.isInstanceOf[BinaryArithmetic] && + e.dataType.isInstanceOf[IntegralType] && + isSingleOperatorExpr(e) + + private def isSingleOperatorExpr(e: Expression): Boolean = e.find { + case a: Add if a.getClass == e.getClass => false + case m: Multiply if m.getClass == e.getClass => false + case _: BinaryArithmetic => true + case _ => false + }.isEmpty + + private def getOperandList(e: Expression): Seq[Expression] = e match { + case BinaryArithmetic(a, b) => getOperandList(a) ++ getOperandList(b) + case other => other :: Nil + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { // Skip redundant folding of literals. This rule is technically not necessary. Placing this @@ -751,6 +768,25 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. 
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + + // Use associative property for integral type + case e if isAssociativelyFoldable(e) => + val (foldables, others) = getOperandList(e).partition(_.foldable) + if (foldables.size > 1) { + e match { + case a: Add => + val foldableExpr = foldables.reduce((x, y) => Add(x, y)) + val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType) + Add(others.reduce((x, y) => Add(x, y)), c) + case m: Multiply => + val foldableExpr = foldables.reduce((x, y) => Multiply(x, y)) + val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType) + Multiply(others.reduce((x, y) => Multiply(x, y)), c) + case _ => e + } + } else { + e + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index d9655bbcc2ce1..70c2b5c17caab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -33,8 +33,8 @@ class ConstantFoldingSuite extends PlanTest { val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: - Batch("ConstantFolding", Once, - OptimizeIn(SimpleCatalystConf(true)), + Batch("ConstantFolding", FixedPoint(10), + OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true)), ConstantFolding, BooleanSimplification) :: Nil } @@ -104,8 +104,8 @@ class ConstantFoldingSuite extends PlanTest { val correctAnswer = testRelation .select( - Literal(5) + 'a as Symbol("c1"), - 'a + Literal(2) + Literal(3) as Symbol("c2"), + 'a + Literal(5) as Symbol("c1"), + 'a + Literal(5) as Symbol("c2"), Literal(2) * 'a + Literal(4) as Symbol("c3"), 'a * Literal(7) as Symbol("c4")) .analyze @@ -149,7 +149,7 @@ class ConstantFoldingSuite extends PlanTest { val correctAnswer = testRelation .select( - Literal(5) + 'a as Symbol("c1"), + 'a + 5 as Symbol("c1"), Literal(3) as Symbol("c2")) .analyze @@ -264,4 +264,22 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("Constant folding test: associative property") { + val originalQuery = + testRelation + .select((Literal(3) + ((Literal(1) + 'a) + 2)) + 4, 'b * 2 * 3 * 4, 'a + 1 + 'b + 2) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"), + ('b * 24).as("(((b * 2) * 3) * 4)"), + ('a + 'b + 3).as("(((a + 1) + b) + 2)")) + .analyze + + comparePlans(optimized, correctAnswer) + } } From 526ee9424efd48b85b3bfa276cdec0b6a4703637 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 31 May 2016 15:33:47 -0700 Subject: [PATCH 2/5] Make new optimizer `ReorderAssociativeOperator`. 
--- .../sql/catalyst/optimizer/Optimizer.scala | 37 +++++++----- .../optimizer/ConstantFoldingSuite.scala | 28 ++------- .../ReorderAssociativeOperatorSuite.scala | 57 +++++++++++++++++++ 3 files changed, 85 insertions(+), 37 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7e975f3f8e702..828a3ada87a62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -94,6 +94,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) FoldablePropagation, OptimizeIn(conf), ConstantFolding, + ReorderAssociativeOperator, LikeSimplification, BooleanSimplification, SimplifyConditionals, @@ -738,10 +739,9 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } /** - * Replaces [[Expression Expressions]] that can be statically evaluated with - * equivalent [[Literal]] values. + * Reorder associative integral-type operators and fold all constants into one. */ -object ConstantFolding extends Rule[LogicalPlan] { +object ReorderAssociativeOperator extends Rule[LogicalPlan] { private def isAssociativelyFoldable(e: Expression): Boolean = e.isInstanceOf[BinaryArithmetic] && e.dataType.isInstanceOf[IntegralType] && @@ -761,15 +761,6 @@ object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - // Skip redundant folding of literals. This rule is technically not necessary. Placing this - // here avoids running the next rule for Literal values, which would create a new Literal - // object and running eval unnecessarily. - case l: Literal => l - - // Fold expressions that are foldable. - case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) - - // Use associative property for integral type case e if isAssociativelyFoldable(e) => val (foldables, others) = getOperandList(e).partition(_.foldable) if (foldables.size > 1) { @@ -777,11 +768,11 @@ object ConstantFolding extends Rule[LogicalPlan] { case a: Add => val foldableExpr = foldables.reduce((x, y) => Add(x, y)) val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType) - Add(others.reduce((x, y) => Add(x, y)), c) + if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c) case m: Multiply => val foldableExpr = foldables.reduce((x, y) => Multiply(x, y)) val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType) - Multiply(others.reduce((x, y) => Multiply(x, y)), c) + if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c) case _ => e } } else { @@ -791,6 +782,24 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. + */ +object ConstantFolding extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + // Skip redundant folding of literals. This rule is technically not necessary. Placing this + // here avoids running the next rule for Literal values, which would create a new Literal + // object and running eval unnecessarily. 
+ case l: Literal => l + + // Fold expressions that are foldable. + case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + } + } +} + /** * Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]] * which is much faster diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 70c2b5c17caab..d9655bbcc2ce1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -33,8 +33,8 @@ class ConstantFoldingSuite extends PlanTest { val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: - Batch("ConstantFolding", FixedPoint(10), - OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true)), + Batch("ConstantFolding", Once, + OptimizeIn(SimpleCatalystConf(true)), ConstantFolding, BooleanSimplification) :: Nil } @@ -104,8 +104,8 @@ class ConstantFoldingSuite extends PlanTest { val correctAnswer = testRelation .select( - 'a + Literal(5) as Symbol("c1"), - 'a + Literal(5) as Symbol("c2"), + Literal(5) + 'a as Symbol("c1"), + 'a + Literal(2) + Literal(3) as Symbol("c2"), Literal(2) * 'a + Literal(4) as Symbol("c3"), 'a * Literal(7) as Symbol("c4")) .analyze @@ -149,7 +149,7 @@ class ConstantFoldingSuite extends PlanTest { val correctAnswer = testRelation .select( - 'a + 5 as Symbol("c1"), + Literal(5) + 'a as Symbol("c1"), Literal(3) as Symbol("c2")) .analyze @@ -264,22 +264,4 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - - test("Constant folding test: associative property") { - val originalQuery = - testRelation - .select((Literal(3) + ((Literal(1) + 'a) + 2)) + 4, 'b * 2 * 3 * 4, 'a + 1 + 'b + 2) - - val optimized = Optimize.execute(originalQuery.analyze) - - val correctAnswer = - testRelation - .select( - ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"), - ('b * 24).as("(((b * 2) * 3) * 4)"), - ('a + 'b + 3).as("(((a + 1) + b) + 2)")) - .analyze - - comparePlans(optimized, correctAnswer) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala new file mode 100644 index 0000000000000..9554f915e67b4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReorderAssociativeOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("ReorderAssociativeOperator", Once, + ReorderAssociativeOperator) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("Reorder associative operators") { + val originalQuery = + testRelation + .select( + (Literal(3) + ((Literal(1) + 'a) + 2)) + 4, + 'b * 1 * 2 * 3 * 4, + 'a + 1 + 'b + 2 + 'c + 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"), + ('b * 24).as("((((b * 1) * 2) * 3) * 4)"), + ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)")) + .analyze + + comparePlans(optimized, correctAnswer) + } +} From 37bfa88943ca22a3ee8f48f22630935b0ef8d5f6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Jun 2016 16:13:19 -0700 Subject: [PATCH 3/5] Add deterministic check and testcase. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 5 ++--- .../optimizer/ReorderAssociativeOperatorSuite.scala | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 828a3ada87a62..4d1fa78052407 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -743,9 +743,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe */ object ReorderAssociativeOperator extends Rule[LogicalPlan] { private def isAssociativelyFoldable(e: Expression): Boolean = - e.isInstanceOf[BinaryArithmetic] && - e.dataType.isInstanceOf[IntegralType] && - isSingleOperatorExpr(e) + e.deterministic && e.isInstanceOf[BinaryArithmetic] && e.dataType.isInstanceOf[IntegralType] && + isSingleOperatorExpr(e) private def isSingleOperatorExpr(e: Expression): Boolean = e.find { case a: Add if a.getClass == e.getClass => false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala index 9554f915e67b4..3af1b997f6791 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -40,7 +40,8 @@ class ReorderAssociativeOperatorSuite extends PlanTest { .select( (Literal(3) + ((Literal(1) + 'a) + 2)) + 4, 'b * 1 * 2 * 3 * 4, - 'a + 1 + 'b + 2 + 'c + 3) + 'a + 1 + 'b + 2 + 'c + 3, + Rand(0) * 1 * 2 * 3 * 4) val optimized = Optimize.execute(originalQuery.analyze) @@ -49,7 +50,8 @@ class ReorderAssociativeOperatorSuite extends PlanTest { .select( ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"), ('b * 24).as("((((b * 1) * 2) * 3) * 4)"), - ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 
3)")) + ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"), + Rand(0) * 1 * 2 * 3 * 4) .analyze comparePlans(optimized, correctAnswer) From 0acb157111c5557ecdc8ce2189c7f0005431316a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Jun 2016 17:33:54 -0700 Subject: [PATCH 4/5] Improve the code according to the comments. --- .../sql/catalyst/optimizer/Optimizer.scala | 57 ++++++++----------- .../ReorderAssociativeOperatorSuite.scala | 4 ++ 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4d1fa78052407..c4fabbb65f3c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -742,42 +742,35 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe * Reorder associative integral-type operators and fold all constants into one. */ object ReorderAssociativeOperator extends Rule[LogicalPlan] { - private def isAssociativelyFoldable(e: Expression): Boolean = - e.deterministic && e.isInstanceOf[BinaryArithmetic] && e.dataType.isInstanceOf[IntegralType] && - isSingleOperatorExpr(e) - - private def isSingleOperatorExpr(e: Expression): Boolean = e.find { - case a: Add if a.getClass == e.getClass => false - case m: Multiply if m.getClass == e.getClass => false - case _: BinaryArithmetic => true - case _ => false - }.isEmpty + private def flattenAdd(e: Expression): Seq[Expression] = e match { + case Add(l, r) => flattenAdd(l) ++ flattenAdd(r) + case other => other :: Nil + } - private def getOperandList(e: Expression): Seq[Expression] = e match { - case BinaryArithmetic(a, b) => getOperandList(a) ++ getOperandList(b) + private def flattenMultiply(e: Expression): Seq[Expression] = e match { + case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r) case other => other :: Nil } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsDown { - case e if isAssociativelyFoldable(e) => - val (foldables, others) = getOperandList(e).partition(_.foldable) - if (foldables.size > 1) { - e match { - case a: Add => - val foldableExpr = foldables.reduce((x, y) => Add(x, y)) - val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType) - if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c) - case m: Multiply => - val foldableExpr = foldables.reduce((x, y) => Multiply(x, y)) - val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType) - if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c) - case _ => e - } - } else { - e - } - } + def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressionsDown { + case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] => + val (foldables, others) = flattenAdd(a).partition(_.foldable) + if (foldables.size > 1) { + val foldableExpr = foldables.reduce((x, y) => Add(x, y)) + val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType) + if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c) + } else { + a + } + case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] => + val (foldables, others) = flattenMultiply(m).partition(_.foldable) + if (foldables.size > 1) { + val foldableExpr = foldables.reduce((x, y) => Multiply(x, y)) + val c = 
Literal.create(foldableExpr.eval(EmptyRow), m.dataType) + if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c) + } else { + m + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala index 3af1b997f6791..05e15e9ec4728 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -40,7 +40,9 @@ class ReorderAssociativeOperatorSuite extends PlanTest { .select( (Literal(3) + ((Literal(1) + 'a) + 2)) + 4, 'b * 1 * 2 * 3 * 4, + ('b + 1) * 2 * 3 * 4, 'a + 1 + 'b + 2 + 'c + 3, + 'a + 1 + 'b * 2 + 'c + 3, Rand(0) * 1 * 2 * 3 * 4) val optimized = Optimize.execute(originalQuery.analyze) @@ -50,7 +52,9 @@ class ReorderAssociativeOperatorSuite extends PlanTest { .select( ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"), ('b * 24).as("((((b * 1) * 2) * 3) * 4)"), + (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"), ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"), + ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"), Rand(0) * 1 * 2 * 3 * 4) .analyze From 3959d57258ed03bf6e9b845569d420c40b70e5c4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Jun 2016 22:23:30 -0700 Subject: [PATCH 5/5] Fix apply. --- .../sql/catalyst/optimizer/Optimizer.scala | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c4fabbb65f3c5..11cd84b396ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -752,25 +752,27 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { case other => other :: Nil } - def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressionsDown { - case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] => - val (foldables, others) = flattenAdd(a).partition(_.foldable) - if (foldables.size > 1) { - val foldableExpr = foldables.reduce((x, y) => Add(x, y)) - val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType) - if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c) - } else { - a - } - case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] => - val (foldables, others) = flattenMultiply(m).partition(_.foldable) - if (foldables.size > 1) { - val foldableExpr = foldables.reduce((x, y) => Multiply(x, y)) - val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType) - if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c) - } else { - m - } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] => + val (foldables, others) = flattenAdd(a).partition(_.foldable) + if (foldables.size > 1) { + val foldableExpr = foldables.reduce((x, y) => Add(x, y)) + val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType) + if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c) + } else { + a + } + case m: Multiply if m.deterministic && 
m.dataType.isInstanceOf[IntegralType] => + val (foldables, others) = flattenMultiply(m).partition(_.foldable) + if (foldables.size > 1) { + val foldableExpr = foldables.reduce((x, y) => Multiply(x, y)) + val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType) + if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c) + } else { + m + } + } } }
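
Illustrative sketch (an editor's addition, not part of the patches above): the final shape of ReorderAssociativeOperator is "flatten a chain of one associative operator, fold every foldable operand into a single literal, and rebuild the chain with the constant last". The standalone Scala below re-implements that idea on a toy expression tree so the rewrite can be run outside Spark; the names Lit, Col, AddExpr, MulExpr and ReorderSketch are invented for illustration and are not Catalyst classes.

// Toy expression tree -- illustrative names only, not Catalyst's classes.
sealed trait Expr
case class Lit(value: Long) extends Expr
case class Col(name: String) extends Expr
case class AddExpr(left: Expr, right: Expr) extends Expr
case class MulExpr(left: Expr, right: Expr) extends Expr

object ReorderSketch {
  // Flatten a chain of one associative operator into its operand list,
  // mirroring flattenAdd/flattenMultiply in the final patch.
  private def flattenAdd(e: Expr): Seq[Expr] = e match {
    case AddExpr(l, r) => flattenAdd(l) ++ flattenAdd(r)
    case other         => Seq(other)
  }
  private def flattenMul(e: Expr): Seq[Expr] = e match {
    case MulExpr(l, r) => flattenMul(l) ++ flattenMul(r)
    case other         => Seq(other)
  }

  // One application of the rewrite: partition operands into literals and the
  // rest, fold the literals into a single constant, and rebuild the chain.
  // (Catalyst applies this to every node via transformExpressionsDown.)
  def reorderOnce(e: Expr): Expr = e match {
    case a: AddExpr =>
      val (lits, others) = flattenAdd(a).partition(_.isInstanceOf[Lit])
      if (lits.size > 1) {
        val c = Lit(lits.collect { case Lit(v) => v }.sum)
        if (others.isEmpty) c else AddExpr(others.reduce((x, y) => AddExpr(x, y)), c)
      } else a
    case m: MulExpr =>
      val (lits, others) = flattenMul(m).partition(_.isInstanceOf[Lit])
      if (lits.size > 1) {
        val c = Lit(lits.collect { case Lit(v) => v }.product)
        if (others.isEmpty) c else MulExpr(others.reduce((x, y) => MulExpr(x, y)), c)
      } else m
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // ((3 + ((1 + a) + 2)) + 4) becomes (a + 10), matching the first test case.
    val input = AddExpr(AddExpr(Lit(3), AddExpr(AddExpr(Lit(1), Col("a")), Lit(2))), Lit(4))
    println(reorderOnce(input)) // AddExpr(Col(a),Lit(10))
  }
}

The restriction to IntegralType in the rule itself is what makes this reordering safe: two's-complement integer addition and multiplication remain associative even under overflow, whereas reordering floating-point operands can change the result.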