diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 55adc06320a58..892d80fecdc37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -71,6 +71,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ColumnPruning, // Operator combine CollapseRepartition, + CollapseSorts, CollapseProject, CombineFilters, CombineLimits, @@ -482,6 +483,25 @@ object CollapseRepartition extends Rule[LogicalPlan] { } } +/** + * Collapse two adjacent [[Sort]] operators into one if possible. Keep the last sort + * This rule applies to the scenario where the global is same for the Sort nodes and then + * either a) The sorts are adjacent or b) In between two Sort nodes, there is a Filter or + * a Project or a Limit + */ +object CollapseSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ss @ Sort( _, globalOrder, ns @ Sort ( _, g, grandChild)) + if globalOrder == g => ss.copy( child = grandChild) + case ss @ Sort( _, globalOrder, p @ Project( _, c @ Sort( _, g, ggchild))) + if globalOrder == g => ss.copy( child = p.copy( child = ggchild)) + case ss @ Sort( _, globalOrder, f @ Filter( _, c @ Sort( _, g, ggchild))) + if globalOrder == g => ss.copy( child = f.copy( child = ggchild)) + case ss @ Sort( _, globalOrder, l @ Limit( _, c @ Sort( _, g, ggchild))) + if globalOrder == g => ss.copy( child = l.copy( child = ggchild)) + } +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseSortsSuite.scala new file mode 100644 index 0000000000000..2bbc256a8ee09 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseSortsSuite.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +/** + * Test class to test CollapseSorts rule + * For adjacent sorts, collapse the sort if possible + */ +class CollapseSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Collapse Sorts", FixedPoint(10), + CollapseSorts, + CollapseProject, + CombineLimits) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("collapsesorts: select has all columns used in sort") { + val originalQuery = + testRelation + .select('a, 'b) + .orderBy('b.asc) + .orderBy('a.asc) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select('a, 'b) + .orderBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + + test("collapsesorts: combines two sorts project subset") { + val originalQuery = + testRelation + .select('a, 'b, 'c) + .orderBy('b.asc) + .orderBy('a.asc) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select('a, 'b, 'c) + .orderBy('a.asc).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapsesorts: select has all columns used in sort, desc") { + val originalQuery = + testRelation + .select('a, 'b) + .orderBy('b.desc) + .orderBy('a.asc) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select('a, 'b) + .orderBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("collapsesorts: multiple sorts") { + val originalQuery = + testRelation + .select('a, 'b) + .orderBy('a.asc) + .orderBy('b.desc, 'a.asc) + .orderBy('a.asc) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select('a, 'b) + .orderBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + // Project will be introduced as part of Analyzer ResolveSortReferences. Test to ensure + // that sorts are collapsed. + test("collapsesorts: sorts will be collapsed even with project introduced in between") { + val originalQuery = + testRelation + .select('a) + .orderBy('b.desc, 'a.asc) + .orderBy('a.asc) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select('a) + .orderBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("collapsesorts: test collapsesorts in sort <- project <- sort scenario") { + val originalQuery = + testRelation + .orderBy('b.desc, 'a.asc) + .select('a, 'c) + .orderBy('c.asc) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select('a, 'c) + .orderBy('c.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("collapsesorts: test collapsesorts in sort <- limit <- sort scenario") { + val originalQuery = + testRelation + .orderBy('b.desc, 'a.asc) + .limit(2) + .orderBy('c.asc) + .select('a) + val optimized = Optimize.execute(originalQuery.analyze) + // Check there is only one Sort + assert(optimized.toString.split("Sort").length == 2) + } + + test("collapsesorts: test collapsesorts in sort <- filter <- sort scenario") { + val originalQuery = + testRelation + .orderBy('b.desc, 'a.asc) + .where('c > 1) + .orderBy('c.asc) + .select('a, 'b, 'c, 'd) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .where('c > 1) + .orderBy('c.asc) + .select('a, 'b, 'c, 'd).analyze + comparePlans(optimized, correctAnswer) + } + + test("collapsesorts: collapsesorts will not be exercised, global in sortBy is false") { + val originalQuery = + testRelation + .sortBy('b.desc, 'a.asc) + .where('c > 1) + .orderBy('c.asc) + .select('a, 'b, 'c, 'd) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .sortBy('b.desc, 'a.asc) + .where('c > 1) + .orderBy('c.asc) + .select('a, 'b, 'c, 'd).analyze + comparePlans(optimized, correctAnswer) + } + +}