From 2fb7a1cd42664c281bfc64bf584b8f762f828b4d Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 24 Nov 2015 21:41:15 -0800
Subject: [PATCH 1/9] push filter through aggregation with alias and literals

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 12 +++++++++--
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 21 +++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f4dba67f13b5..ad87bce073a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -649,7 +649,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe

   // Substitute any attributes that are produced by the child projection, so that we safely
   // eliminate it.
-  private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = {
+  private[sql] def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = {
     condition.transform {
       case a: Attribute => sourceAliases.getOrElse(a, a)
     }
@@ -690,7 +690,15 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case filter @ Filter(condition,
         aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+
+      // Create a map of Alias for grouping keys or literals
+      val aliasMap = AttributeMap(aggregateExpressions.collect {
+        case a: Alias if groupingExpressions.contains(a.child) || a.child.foldable =>
+          (a.toAttribute, a.child)
+      })
+      val newCond = PushPredicateThroughProject.replaceAlias(condition, aliasMap)
+
+      val (pushDown, stayUp) = splitConjunctivePredicates(newCond).partition {
         conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions)
       }
       if (pushDown.nonEmpty) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bb82b562aaaa..580c5a9868e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,4 +2028,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }

+  test("push filter through aggregation with alias and literals") {
+    withTempTable("src") {
+      Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
+      val q1 = sql(
+        """
+          | SELECT k, v from (
+          |   SELECT key k, sum(value) v, 3 c FROM src GROUP BY key
+          | ) t WHERE k = 1 and v > 0 and c = 3
+        """.stripMargin)
+      val q2 = sql(
+        """
+          | SELECT k, v from (
+          |   SELECT key k, sum(value) v, 3 c FROM src WHERE key = 1 GROUP BY key
+          | ) t WHERE v > 0
+        """.stripMargin)
+      comparePlans(q1.queryExecution.optimizedPlan, q2.queryExecution.optimizedPlan)
+      checkAnswer(q1, q2)
+    }
+
+  }
+
 }

From 162268c2410d19caf7032e17ccccac1aecc1237c Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 01:00:03 -0800
Subject: [PATCH 2/9] improve performance of cartesian product

---
 .../unsafe/sort/UnsafeExternalSorter.java     | 57 +++++++++++++++
 .../unsafe/sort/UnsafeInMemorySorter.java     |  4 ++
 .../execution/joins/CartesianProduct.scala    | 70 ++++++++++++++++---
 3 files changed, 123 insertions(+), 8 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 9a7b2ad06cab..cde24a9046da 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -21,6 +21,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import java.util.Queue;

 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -519,4 +520,60 @@ public long getKeyPrefix() {
       return upstream.getKeyPrefix();
     }
   }
+
+  /**
+   * Returns an iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
+  public UnsafeSorterIterator getIterator() throws IOException {
+    if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      return inMemSorter.getIterator();
+    } else {
+      Queue<UnsafeSorterIterator> queue = new LinkedList<>();
+      for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+        queue.add(spillWriter.getReader(blockManager));
+      }
+      if (inMemSorter != null) {
+        queue.add(inMemSorter.getIterator());
+      }
+      return new ChainedIterator(queue);
+    }
+  }
+
+  class ChainedIterator extends UnsafeSorterIterator {
+    private final Queue<UnsafeSorterIterator> iterators;
+    private UnsafeSorterIterator current = null;
+    public ChainedIterator(Queue<UnsafeSorterIterator> iters) {
+      this.iterators = iters;
+      this.current = iters.remove();
+    }
+
+    @Override
+    public boolean hasNext() {
+      if (!current.hasNext()) {
+        if (!iterators.isEmpty()) {
+          current = iterators.remove();
+        }
+      }
+      return current.hasNext();
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      current.loadNext();
+    }
+
+    @Override
+    public Object getBaseObject() { return current.getBaseObject(); }
+
+    @Override
+    public long getBaseOffset() { return current.getBaseOffset(); }
+
+    @Override
+    public int getRecordLength() { return current.getRecordLength(); }
+
+    @Override
+    public long getKeyPrefix() { return current.getKeyPrefix(); }
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index a218ad4623f4..2ea6fc1de8fa 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -231,4 +231,8 @@ public SortedIterator getSortedIterator() {
     sorter.sort(array, 0, pos / 2, sortComparator);
     return new SortedIterator(memoryManager, pos, array);
   }
+
+  public SortedIterator getIterator() {
+    return new SortedIterator(memoryManager, pos, array);
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index f467519b802a..9ccebad5ba99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -17,16 +17,69 @@

 package org.apache.spark.sql.execution.joins

-import org.apache.spark.rdd.RDD
+import org.apache.spark._
+import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+
+private[spark]
+class UnsafeCartesianRDD(rdd1 : RDD[UnsafeRow], rdd2 : RDD[UnsafeRow])
+  extends CartesianRDD[UnsafeRow, UnsafeRow](rdd1.sparkContext, rdd1, rdd2) {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
+    val sorter = UnsafeExternalSorter.create(
+      context.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      context,
+      null,
+      null,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes)
+
+    val currSplit = split.asInstanceOf[CartesianPartition]
+    var numFields = 0
+    for (y <- rdd2.iterator(currSplit.s2, context)) {
+      numFields = y.numFields()
+      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+    }
+
+    def createIter(): Iterator[UnsafeRow] = {
+      val iter = sorter.getIterator
+      val unsafeRow = new UnsafeRow
+      new Iterator[UnsafeRow] {
+        override def hasNext: Boolean = {
+          iter.hasNext
+        }
+        override def next(): UnsafeRow = {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFields, iter.getRecordLength)
+          unsafeRow
+        }
+      }
+    }
+
+    val resultIter =
+      for (x <- rdd1.iterator(currSplit.s1, context);
+           y <- createIter()) yield (x, y)
+    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
+      resultIter, sorter.cleanupResources)
+  }
+}

 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output ++ right.output

+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = true
+
   override private[sql] lazy val metrics = Map(
     "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
     "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
@@ -39,18 +92,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
     val leftResults = left.execute().map { row =>
       numLeftRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
     val rightResults = right.execute().map { row =>
       numRightRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }

-    leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
-      val joinedRow = new JoinedRow
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults)
+    pair.mapPartitionsInternal { iter =>
+      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       iter.map { r =>
         numOutputRows += 1
-        joinedRow(r._1, r._2)
+        joiner.join(r._1, r._2)
       }
     }
   }

From 0f5d7ba4942d8d125c201718bab720a354a842aa Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 10:46:14 -0800
Subject: [PATCH 3/9] address comments

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 38 ++++++++-----
 .../optimizer/FilterPushdownSuite.scala       | 53 +++++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  6 +--
 3 files changed, 80 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ad87bce073a1..bfcc05c2fb96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -608,6 +608,19 @@ object SimplifyFilters extends Rule[LogicalPlan] {
   }
 }

+/**
+ * Helper functions for Predicate push down.
+ */
+object PredicateHelper {
+
+  // Substitute any known alias from a map.
+  def replaceAlias(condition: Expression, aliases: AttributeMap[Expression]): Expression = {
+    condition.transform {
+      case a: Attribute => aliases.getOrElse(a, a)
+    }
+  }
+}
+
 /**
  * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
  * that were defined in the projection.
@@ -633,27 +646,21 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
     // If there is no nondeterministic conditions, push down the whole condition.
     if (nondeterministic.isEmpty) {
-      project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
+      project.copy(child = Filter(PredicateHelper.replaceAlias(condition, aliasMap), grandChild))
     } else {
       // If they are all nondeterministic conditions, leave it un-changed.
       if (deterministic.isEmpty) {
         filter
       } else {
         // Push down the small conditions without nondeterministic expressions.
-        val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And)
+        val pushedCondition =
+          deterministic.map(PredicateHelper.replaceAlias(_, aliasMap)).reduce(And)
         Filter(nondeterministic.reduce(And),
           project.copy(child = Filter(pushedCondition, grandChild)))
       }
     }
   }

-  // Substitute any attributes that are produced by the child projection, so that we safely
-  // eliminate it.
-  private[sql] def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = {
-    condition.transform {
-      case a: Attribute => sourceAliases.getOrElse(a, a)
-    }
-  }
 }
@@ -691,15 +698,18 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
     case filter @ Filter(condition,
         aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>

-      // Create a map of Alias for grouping keys or literals
+      def hasAggregate(expression: Expression): Boolean = expression match {
+        case agg: AggregateExpression => true
+        case other => expression.children.exists(hasAggregate)
+      }
+      // Create a map of Alias for expressions that do not have AggregateExpression
       val aliasMap = AttributeMap(aggregateExpressions.collect {
-        case a: Alias if groupingExpressions.contains(a.child) || a.child.foldable =>
-          (a.toAttribute, a.child)
+        case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child)
       })
-      val newCond = PushPredicateThroughProject.replaceAlias(condition, aliasMap)
+      val newCond = PredicateHelper.replaceAlias(condition, aliasMap)

       val (pushDown, stayUp) = splitConjunctivePredicates(newCond).partition {
-        conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions)
+        conjunct => conjunct.references.subsetOf(grandChild.outputSet) && conjunct.deterministic
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0290fafe879f..f37aee588658 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -697,4 +697,57 @@ class FilterPushdownSuite extends PlanTest {

     comparePlans(optimized, correctAnswer)
   }
+
+  test("aggregate: push down filters with alias") {
+    val originalQuery = testRelation
+      .select('a, 'b)
+      .groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
+      .where('c === 2L && 'aa === 3)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .select('a, 'b)
+      .where('a + 1 === 3)
+      .groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
+      .where('c === 2L)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("aggregate: push down filters with literal") {
+    val originalQuery = testRelation
+      .select('a, 'b)
+      .groupBy('a)('a, count('b) as 'c, "s" as 'd)
+      .where('c === 2L && 'd === "s")
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .select('a, 'b)
+      .where("s" === "s")
+      .groupBy('a)('a, count('b) as 'c, "s" as 'd)
+      .where('c === 2L)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("aggregate: don't push down filters which is nondeterministic") {
+    val originalQuery = testRelation
+      .select('a, 'b)
+      .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))
+      .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .select('a, 'b)
+      .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))
+      .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 580c5a9868e3..4926dfa7753f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2034,13 +2034,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       val q1 = sql(
         """
           | SELECT k, v from (
-          |   SELECT key k, sum(value) v, 3 c FROM src GROUP BY key
-          | ) t WHERE k = 1 and v > 0 and c = 3
+          |   SELECT key + 1 AS k, sum(value) + 2 AS v, 3 c FROM src GROUP BY key
+          | ) t WHERE k = 0 and v > 0 and c = 3
         """.stripMargin)
       val q2 = sql(
         """
           | SELECT k, v from (
-          |   SELECT key k, sum(value) v, 3 c FROM src WHERE key = 1 GROUP BY key
+          |   SELECT key + 1 AS k, sum(value) + 2 AS v, 3 c FROM src WHERE key + 1 = 0 GROUP BY key
           | ) t WHERE v > 0
         """.stripMargin)
       comparePlans(q1.queryExecution.optimizedPlan, q2.queryExecution.optimizedPlan)

From 951fe7a781a8887e6c96c409b4e3515dae621e06 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 11:05:18 -0800
Subject: [PATCH 4/9] fix tests

---
 .../apache/spark/sql/catalyst/optimizer/Optimizer.scala  | 9 +++++----
 .../sql/catalyst/optimizer/FilterPushdownSuite.scala     | 6 +++---
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index bfcc05c2fb96..82508f7f72ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -706,14 +706,15 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
       val aliasMap = AttributeMap(aggregateExpressions.collect {
         case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child)
       })
-      val newCond = PredicateHelper.replaceAlias(condition, aliasMap)

-      val (pushDown, stayUp) = splitConjunctivePredicates(newCond).partition {
-        conjunct => conjunct.references.subsetOf(grandChild.outputSet) && conjunct.deterministic
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
+        val replaced = PredicateHelper.replaceAlias(conjunct, aliasMap)
+        replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
-        val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild))
+        val replaced = PredicateHelper.replaceAlias(pushDownPredicate, aliasMap)
+        val withPushdown = aggregate.copy(child = Filter(replaced, grandChild))
         stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
       } else {
         filter
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index f37aee588658..0128c220baac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -702,15 +702,15 @@ class FilterPushdownSuite extends PlanTest {
     val originalQuery = testRelation
       .select('a, 'b)
       .groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
-      .where('c === 2L && 'aa === 3)
+      .where(('c === 2L || 'aa > 4) && 'aa < 3)

     val optimized = Optimize.execute(originalQuery.analyze)

     val correctAnswer = testRelation
       .select('a, 'b)
-      .where('a + 1 === 3)
+      .where('a + 1 < 3)
       .groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
-      .where('c === 2L)
+      .where('c === 2L || 'aa > 4)
       .analyze

     comparePlans(optimized, correctAnswer)

From a94204bd5e6af9a8ae8df810f897b5818d45173b Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 11:18:32 -0800
Subject: [PATCH 5/9] fix build

---
 .../spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 97b24799d980..e127539b751b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -228,6 +228,6 @@ public SortedIterator getSortedIterator() {
   }

   public SortedIterator getIterator() {
-    return new SortedIterator(memoryManager, pos, array);
+    return new SortedIterator(pos / 2);
   }
 }

From 37b308863d656509add3f6357775148ca9cf5b0f Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 22:18:43 -0800
Subject: [PATCH 6/9] address comments

---
 .../sql/catalyst/expressions/predicates.scala | 9 ++++++++
 .../sql/catalyst/optimizer/Optimizer.scala    | 21 ++++---------------
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 21 -------------------
 3 files changed, 13 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 68557479a959..304b438c84ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -65,6 +65,15 @@ trait PredicateHelper {
     }
   }

+  // Substitute any known alias from a map.
+  protected def replaceAlias(
+      condition: Expression,
+      aliases: AttributeMap[Expression]): Expression = {
+    condition.transform {
+      case a: Attribute => aliases.getOrElse(a, a)
+    }
+  }
+
   /**
    * Returns true if `expr` can be evaluated using only the output of `plan`. This method
    * can be used to determine when it is acceptable to move expression evaluation within a query
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 82508f7f72ce..52f609bc158c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -608,19 +608,6 @@ object SimplifyFilters extends Rule[LogicalPlan] {
   }
 }

-/**
- * Helper functions for Predicate push down.
- */
-object PredicateHelper {
-
-  // Substitute any known alias from a map.
-  def replaceAlias(condition: Expression, aliases: AttributeMap[Expression]): Expression = {
-    condition.transform {
-      case a: Attribute => aliases.getOrElse(a, a)
-    }
-  }
-}
-
 /**
  * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
  * that were defined in the projection.
@@ -646,7 +633,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
     // If there is no nondeterministic conditions, push down the whole condition.
     if (nondeterministic.isEmpty) {
-      project.copy(child = Filter(PredicateHelper.replaceAlias(condition, aliasMap), grandChild))
+      project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
     } else {
       // If they are all nondeterministic conditions, leave it un-changed.
       if (deterministic.isEmpty) {
         filter
       } else {
         // Push down the small conditions without nondeterministic expressions.
         val pushedCondition =
-          deterministic.map(PredicateHelper.replaceAlias(_, aliasMap)).reduce(And)
+          deterministic.map(replaceAlias(_, aliasMap)).reduce(And)
         Filter(nondeterministic.reduce(And),
           project.copy(child = Filter(pushedCondition, grandChild)))
       }
@@ -708,12 +695,12 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
       })

       val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
-        val replaced = PredicateHelper.replaceAlias(conjunct, aliasMap)
+        val replaced = replaceAlias(conjunct, aliasMap)
         replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
-        val replaced = PredicateHelper.replaceAlias(pushDownPredicate, aliasMap)
+        val replaced = replaceAlias(pushDownPredicate, aliasMap)
         val withPushdown = aggregate.copy(child = Filter(replaced, grandChild))
         stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
       } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4926dfa7753f..bb82b562aaaa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,25 +2028,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }

-  test("push filter through aggregation with alias and literals") {
-    withTempTable("src") {
-      Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
-      val q1 = sql(
-        """
-          | SELECT k, v from (
-          |   SELECT key + 1 AS k, sum(value) + 2 AS v, 3 c FROM src GROUP BY key
-          | ) t WHERE k = 0 and v > 0 and c = 3
-        """.stripMargin)
-      val q2 = sql(
-        """
-          | SELECT k, v from (
-          |   SELECT key + 1 AS k, sum(value) + 2 AS v, 3 c FROM src WHERE key + 1 = 0 GROUP BY key
-          | ) t WHERE v > 0
-        """.stripMargin)
-      comparePlans(q1.queryExecution.optimizedPlan, q2.queryExecution.optimizedPlan)
-      checkAnswer(q1, q2)
-    }
-
-  }
-
 }

From 074f2a738d3f095380caca9a2beaef2421810d6e Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 23:05:48 -0800
Subject: [PATCH 7/9] add comments

---
 .../unsafe/sort/UnsafeExternalSorter.java     | 26 ++++++++++++-------
 .../unsafe/sort/UnsafeInMemorySorter.java     |  3 +++
 .../execution/joins/CartesianProduct.scala    | 18 ++++++++-----
 3 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index a49afcd8c92f..8e19a5669ff6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -524,7 +524,9 @@ public long getKeyPrefix() {
   }

   /**
-   * Returns an iterator. It is the caller's responsibility to call `cleanupResources()`
+   * Returns an iterator, which will return the rows in the order they were inserted.
+   *
+   * It is the caller's responsibility to call `cleanupResources()`
    * after consuming this iterator.
    */
   public UnsafeSorterIterator getIterator() throws IOException {
@@ -532,7 +534,7 @@ public UnsafeSorterIterator getIterator() throws IOException {
     if (spillWriters.isEmpty()) {
       assert(inMemSorter != null);
       return inMemSorter.getIterator();
     } else {
-      Queue<UnsafeSorterIterator> queue = new LinkedList<>();
+      LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
         queue.add(spillWriter.getReader(blockManager));
       }
@@ -543,20 +545,24 @@ public UnsafeSorterIterator getIterator() throws IOException {
     }
   }

+  /**
+   * Chains multiple UnsafeSorterIterators together as a single one.
+   */
   class ChainedIterator extends UnsafeSorterIterator {
+
     private final Queue<UnsafeSorterIterator> iterators;
-    private UnsafeSorterIterator current = null;
-    public ChainedIterator(Queue<UnsafeSorterIterator> iters) {
-      this.iterators = iters;
-      this.current = iters.remove();
+    private UnsafeSorterIterator current;
+
+    public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
+      assert iterators.size() > 0;
+      this.iterators = iterators;
+      this.current = iterators.remove();
     }

     @Override
     public boolean hasNext() {
-      if (!current.hasNext()) {
-        if (!iterators.isEmpty()) {
-          current = iterators.remove();
-        }
+      if (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
       }
       return current.hasNext();
     }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index e127539b751b..c91e88f31bf9 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -227,6 +227,9 @@ public SortedIterator getSortedIterator() {
     return new SortedIterator(pos / 2);
   }

+  /**
+   * Returns an iterator over record pointers in the original (insertion) order.
+   */
   public SortedIterator getIterator() {
     return new SortedIterator(pos / 2);
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index 9ccebad5ba99..9ddae6b5b02b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -28,11 +28,17 @@
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

+/**
+ * An optimized CartesianRDD for UnsafeRow, which will cache the rows from the second child RDD.
+ * This will be much faster than building the right partition for every row in the left RDD; it
+ * also materializes the right RDD (in case the right RDD is nondeterministic).
+ */
 private[spark]
-class UnsafeCartesianRDD(rdd1 : RDD[UnsafeRow], rdd2 : RDD[UnsafeRow])
-  extends CartesianRDD[UnsafeRow, UnsafeRow](rdd1.sparkContext, rdd1, rdd2) {
+class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+  extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {

   override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
+    // We will not sort the rows, so prefixComparator and recordComparator are null.
     val sorter = UnsafeExternalSorter.create(
       context.taskMemoryManager(),
       SparkEnv.get.blockManager,
@@ -43,12 +49,11 @@ class UnsafeCartesianRDD(rdd1 : RDD[UnsafeRow], rdd2 : RDD[UnsafeRow])
       SparkEnv.get.memoryManager.pageSizeBytes)

     val currSplit = split.asInstanceOf[CartesianPartition]
-    var numFields = 0
     for (y <- rdd2.iterator(currSplit.s2, context)) {
-      numFields = y.numFields()
       sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
     }

+    // Create an iterator from the sorter and wrap it as Iterator[UnsafeRow]
     def createIter(): Iterator[UnsafeRow] = {
       val iter = sorter.getIterator
       val unsafeRow = new UnsafeRow
@@ -58,7 +63,8 @@ class UnsafeCartesianRDD(rdd1 : RDD[UnsafeRow], rdd2 : RDD[UnsafeRow])
         }
         override def next(): UnsafeRow = {
           iter.loadNext()
-          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFields, iter.getRecordLength)
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight,
+            iter.getRecordLength)
           unsafeRow
         }
       }
@@ -99,7 +105,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
       row.asInstanceOf[UnsafeRow]
     }

-    val pair = new UnsafeCartesianRDD(leftResults, rightResults)
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
     pair.mapPartitionsInternal { iter =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       iter.map { r =>
         numOutputRows += 1
         joiner.join(r._1, r._2)
       }
     }
   }

From 99bb8ef9f63c5a6fcecfd0e62bc7c131f13d14c7 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 25 Nov 2015 23:25:13 -0800
Subject: [PATCH 8/9] fix test

---
 .../apache/spark/sql/execution/joins/CartesianProduct.scala  | 6 +++---
 .../apache/spark/sql/execution/metric/SQLMetricsSuite.scala  | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index 9ddae6b5b02b..fa2bc7672131 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -48,8 +48,8 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
       1024,
       SparkEnv.get.memoryManager.pageSizeBytes)

-    val currSplit = split.asInstanceOf[CartesianPartition]
-    for (y <- rdd2.iterator(currSplit.s2, context)) {
+    val partition = split.asInstanceOf[CartesianPartition]
+    for (y <- rdd2.iterator(partition.s2, context)) {
       sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
     }
@@ -71,7 +71,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
     }

     val resultIter =
-      for (x <- rdd1.iterator(currSplit.s1, context);
+      for (x <- rdd1.iterator(partition.s1, context);
            y <- createIter()) yield (x, y)
     CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
       resultIter, sorter.cleanupResources)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 5e2b4154dd7c..82867ab4967b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -315,7 +315,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     testSparkPlanMetrics(df, 1, Map(
       1L -> ("CartesianProduct", Map(
scanned twice - "number of right rows" -> 12L, // right is read 6 times + "number of right rows" -> 4L, // right is read twice "number of output rows" -> 12L))) ) } From fbd7dfdd9d07d778e3aa87477dbdd868f556d755 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 11:07:45 -0800 Subject: [PATCH 9/9] defend empty iterator --- .../spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8e19a5669ff6..5a97f4f11340 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -561,7 +561,7 @@ public ChainedIterator(Queue iterators) { @Override public boolean hasNext() { - if (!current.hasNext() && !iterators.isEmpty()) { + while (!current.hasNext() && !iterators.isEmpty()) { current = iterators.remove(); } return current.hasNext();