Closed
Changes from all commits (28 commits)
b566818
[SPARK-25951][SQL] Remove Alias when canonicalize
mgaido91 Nov 6, 2018
2b00f35
introduce sameResult
mgaido91 Nov 7, 2018
0491249
address comment: add comments
mgaido91 Nov 28, 2018
3831be0
address comment
mgaido91 Nov 28, 2018
6c93e70
improve comments
mgaido91 Nov 28, 2018
a306465
fix recursive aliases
mgaido91 Nov 29, 2018
13aef71
fix trimAliases
mgaido91 Nov 29, 2018
5c6b9fc
address comments
mgaido91 Nov 30, 2018
0aaedd8
fix import
mgaido91 Nov 30, 2018
6eee1e4
Merge branch 'master' of github.com:apache/spark into SPARK-25951
mgaido91 Nov 30, 2018
5bca5e3
fix
mgaido91 Nov 30, 2018
1f797df
add tests
mgaido91 Dec 3, 2018
bf1d04a
switch approach: update outputPartitioning
mgaido91 Dec 4, 2018
e4f617f
fix ut error
mgaido91 Dec 5, 2018
bca3e87
Add test suite dedicated to partitioning
mgaido91 Dec 7, 2018
0f68a41
fix UT failures
mgaido91 Dec 7, 2018
952a2c2
Merge branch 'master' into SPARK-25951
mgaido91 Jan 30, 2019
9af290e
fix merge
mgaido91 Jan 30, 2019
778ede3
Merge branch 'master' of github.com:apache/spark into SPARK-25951
mgaido91 Jan 31, 2019
5904cf9
address comments
mgaido91 Jan 31, 2019
205f1b7
cleanup
mgaido91 Jan 31, 2019
f47d5df
address comment
mgaido91 Feb 1, 2019
09b9981
use maropu's approach
mgaido91 Feb 10, 2019
69f9d5e
fix
mgaido91 Feb 10, 2019
75ef545
fix ut failures
mgaido91 Feb 10, 2019
df3394c
adress comments
mgaido91 Feb 11, 2019
78d92bc
fix rangepartitioning
mgaido91 Feb 11, 2019
47a8f71
Merge branch 'master' into SPARK-25951
mgaido91 Apr 10, 2019
@@ -2616,8 +2616,8 @@ object EliminateUnions extends Rule[LogicalPlan] {
* rule can't work for those parameters.
*/
object CleanupAliases extends Rule[LogicalPlan] {
-  private def trimAliases(e: Expression): Expression = {
-    e.transformDown {
+  private[spark] def trimAliases(e: Expression): Expression = {
+    e.transformUp {
case Alias(child, _) => child
case MultiAlias(child, _) => child
}
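
The switch from `transformDown` to `transformUp` matters for nested aliases: with `transformDown`, stripping the outer `Alias` does not re-apply the rule to the inner `Alias` that becomes the new top-level node. A minimal sketch of the difference (the attribute and alias constructions below are illustrative, not part of the diff; `trimAliases` is `private[spark]`, so this only compiles inside an `org.apache.spark` package):

import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
// Nested aliases, e.g. ((a AS x) AS y), as produced by repeated renames.
val nested = Alias(Alias(a, "x")(), "y")()

// Bottom-up traversal strips both Alias layers, leaving the bare attribute.
assert(CleanupAliases.trimAliases(nested) == a)
// With the previous transformDown, the result would still be (a AS x): replacing the
// outer Alias does not re-visit the node that took its place.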
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.physical

import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -184,6 +185,11 @@ trait Partitioning {
case AllTuples => numPartitions == 1
case _ => false
}

/**
* Returns a version of this [[Partitioning]] adjusted for the given invalid [[Attribute]],
* i.e. an attribute which is no longer present in the output.
*/
private[spark] def pruneInvalidAttribute(invalidAttr: Attribute): Partitioning = this
}

case class UnknownPartitioning(numPartitions: Int) extends Partitioning
@@ -235,6 +241,21 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
* than numPartitions) based on hashing expressions.
*/
def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))

/**
* If the HashPartitioning contains an attribute which is not present in the output expressions,
* the returned partitioning is `UnknownPartitioning`, rather than a `HashPartitioning` over the
* remaining attributes, which would be wrong.
* E.g. `HashPartitioning('a, 'b)` with output expressions `'a as 'a1` should produce
* `UnknownPartitioning` instead of `HashPartitioning('a1)`.
*/
override private[spark] def pruneInvalidAttribute(invalidAttr: Attribute): Partitioning = {
if (this.references.contains(invalidAttr)) {
UnknownPartitioning(numPartitions)
> Review comment — @cloud-fan (Contributor), Feb 11, 2019:
> Let's add comments to explain it.
> "HashPartitioning('a, 'b) with output expressions 'a as 'a1 should produce UnknownPartitioning
> instead of HashPartitioning('a1), which is wrong."

} else {
this
}
}
}
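
A quick sketch of the pruning behaviour (the attribute definitions below are illustrative, and `pruneInvalidAttribute` is `private[spark]`): rows co-located by `hash(a, b)` are not co-located by `hash(a)` alone, so the partitioning must degrade to `UnknownPartitioning` rather than silently drop the missing attribute:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val hp = HashPartitioning(Seq(a, b), numPartitions = 10)

// 'b is no longer available in the output (e.g. after `SELECT a AS a1`):
// pmod(hash(a, b), 10) != pmod(hash(a), 10) in general, so keeping a pruned
// HashPartitioning('a1) would be an incorrect claim about the data layout.
assert(hp.pruneInvalidAttribute(b) == UnknownPartitioning(10))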

/**
@@ -284,6 +305,19 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
}
}
}

/**
* Returns `UnknownPartitioning` if the first ordering expression is no longer valid;
* otherwise it performs no modification, because pruning the invalid expressions may cause
* errors when comparing with `ClusteredDistribution`s.
*/
override private[spark] def pruneInvalidAttribute(invalidAttr: Attribute): Partitioning = {
if (ordering.headOption.forall(_.references.contains(invalidAttr))) {
UnknownPartitioning(numPartitions)
} else {
this
}
}
}
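
The asymmetry with `HashPartitioning` can be sketched as follows (attribute definitions are again illustrative): only an invalid leading sort expression invalidates the partitioning, while a pruned `RangePartitioning('a)` could wrongly claim to satisfy `ClusteredDistribution('a)`, so trailing expressions are left untouched:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val rp = RangePartitioning(Seq(SortOrder(a, Ascending), SortOrder(b, Ascending)), 10)

// Losing the trailing 'b keeps the partitioning unchanged (and it still does not
// satisfy ClusteredDistribution('a), which is the safe answer).
assert(rp.pruneInvalidAttribute(b) == rp)
// Losing the leading 'a invalidates the whole range partitioning.
assert(rp.pruneInvalidAttribute(a) == UnknownPartitioning(10))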

/**
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}

/**
* Trait for plans which can produce an output partitioned by aliased attributes of their child.
* It rewrites the partitioning attributes of the child with the corresponding new ones which are
* exposed in the output of this plan. This avoids redundant shuffles in queries that would otherwise
* be caused by the renaming of an attribute used in the partitioning, e.g.
*
* spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
* spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
* sql("set spark.sql.autoBroadcastJoinThreshold=-1")
* sql("""
* SELECT * FROM
* (SELECT key AS k from df1) t1
* INNER JOIN
* (SELECT key AS k from df2) t2
* ON t1.k = t2.k
* """).explain
*
* == Physical Plan ==
* *SortMergeJoin [k#21L], [k#22L], Inner
* :- *Sort [k#21L ASC NULLS FIRST], false, 0
* : +- Exchange hashpartitioning(k#21L, 200) // <--- Unnecessary shuffle operation
* : +- *Project [key#2L AS k#21L]
* : +- Exchange hashpartitioning(key#2L, 200)
* : +- *Project [id#0L AS key#2L]
* : +- *Range (0, 10, step=1, splits=Some(4))
* +- *(4) Sort [k#22L ASC NULLS FIRST], false, 0
* +- *(4) Project [key#8L AS k#22L]
* +- ReusedExchange [key#8L], Exchange hashpartitioning(key#2L, 200)
*/
trait AliasAwareOutputPartitioning extends UnaryExecNode {

/**
* `Seq` of `NamedExpression`s which define the output of the node.
*/
protected def outputExpressions: Seq[NamedExpression]

/**
* Returns the valid `Partitioning`s of this node w.r.t. its output and its expressions.
*/
final override def outputPartitioning: Partitioning = {
child.outputPartitioning match {
case partitioning: Expression =>
// Creates a sequence of tuples where the first element is an `Attribute` referenced in the
// partitioning expression of the child and the second is a sequence of all its aliased
// occurrences in the node output. If there is no occurrence of an attribute in the output,
// the second element of the tuple for it will be an empty `Seq`. If the attribute,
// instead, is only present as is in the output, there will be no entry for it.
// Eg. if the partitioning is RangePartitioning('a) and the node output is "'a, 'a as a1,
// 'a as a2", then exprToEquiv will contain the tuple ('a, Seq('a, 'a as a1, 'a as a2)).
val exprToEquiv = partitioning.references.map { attr =>
> Review comment (Contributor): can you explain what's going on here? The code is a little hard to follow.
> Reply (Author): sure, let me add some comments. Thanks.

attr -> outputExpressions.filter(e =>
CleanupAliases.trimAliases(e).semanticEquals(attr))
}.filterNot { case (attr, exprs) =>
exprs.size == 1 && exprs.forall(_ == attr)
}
val initValue = partitioning match {
case PartitioningCollection(partitionings) => partitionings
case other => Seq(other)
}
// Replace all the aliased expressions detected earlier with all their corresponding
// occurrences. This may produce many valid partitioning expressions from a single one.
// Eg. in the example above, this would produce a `Seq` of 3 `RangePartitioning`, namely:
// `RangePartitioning('a)`, `RangePartitioning('a1)`, `RangePartitioning('a2)`.
val validPartitionings = exprToEquiv.foldLeft(initValue) {
case (partitionings, (toReplace, equivalents)) =>
if (equivalents.isEmpty) {
// Remove from the partitioning expression the attribute which is not present in the
// node output
partitionings.map(_.pruneInvalidAttribute(toReplace))
} else {
partitionings.flatMap {
case p: Expression if p.references.contains(toReplace) =>
equivalents.map { equiv =>
p.transformDown {
case e if e == toReplace => equiv.toAttribute
}.asInstanceOf[Partitioning]
}
case other => Seq(other)
}
}
}.distinct
if (validPartitionings.size == 1) {
validPartitionings.head
} else {
validPartitionings.filterNot(_.isInstanceOf[UnknownPartitioning]) match {
case Seq() => PartitioningCollection(validPartitionings)
case Seq(knownPartitioning) => knownPartitioning
case knownPartitionings => PartitioningCollection(knownPartitionings)
}

}
case other => other
}
}
}
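
A rough end-to-end sketch of the effect (assuming a Spark session with `spark.implicits._` imported, e.g. inside a `SharedSQLContext` test): a projection that merely renames the partitioning column now reports the child's `HashPartitioning` rewritten onto the new name, so a downstream requirement on the renamed column needs no extra exchange:

import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.ProjectExec

val renamed = spark.range(10)
  .selectExpr("id AS key")
  .repartition($"key")
  .selectExpr("key AS k")

// The top-most ProjectExec only renames `key` to `k`.
val project = renamed.queryExecution.executedPlan.collectFirst {
  case p: ProjectExec => p
}.get

// With AliasAwareOutputPartitioning it reports hashpartitioning(k) instead of
// forwarding hashpartitioning(key), whose attribute is no longer in the output.
assert(project.outputPartitioning.isInstanceOf[HashPartitioning])
assert(project.outputPartitioning.asInstanceOf[HashPartitioning]
  .references.forall(_.name == "k"))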
@@ -50,7 +50,7 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -72,7 +72,7 @@

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
@@ -90,11 +90,15 @@
// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
-    sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
-      case null | "" => None
-      case fallbackStartsAt =>
-        val splits = fallbackStartsAt.split(",").map(_.trim)
-        Some((splits.head.toInt, splits.last.toInt))
+    if (Utils.isTesting && sqlContext == null) {
+      None
+    } else {
+      sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+        case null | "" => None
+        case fallbackStartsAt =>
+          val splits = fallbackStartsAt.split(",").map(_.trim)
+          Some((splits.head.toInt, splits.last.toInt))
+      }
+    }
}
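
The conf value is a pair of comma-separated row counts, as implied by the `split(",")` above; a hypothetical test usage, shown only to illustrate the format (values are arbitrary):

// Forces TungstenAggregationIterator to fall back after processing the given
// numbers of input rows (see the comment above).
spark.conf.set("spark.sql.TungstenAggregate.testFallbackStartsAt", "2, 3")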

@@ -67,7 +67,7 @@ case class ObjectHashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -97,7 +97,7 @@
}
}

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
@@ -38,7 +38,7 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,7 +66,7 @@
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions

override def outputOrdering: Seq[SortOrder] = {
groupingExpressions.map(SortOrder(_, Ascending))
@@ -36,7 +36,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/** Physical plan for Project. */
case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {

override def output: Seq[Attribute] = projectList.map(_.toAttribute)

@@ -79,7 +79,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputExpressions: Seq[NamedExpression] = projectList
}


34 changes: 34 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -25,7 +25,9 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -1595,6 +1597,38 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq(Row("Amsterdam")))
}

test("SPARK-25951: avoid redundant shuffle on rename") {
> Review comment (Contributor): I'm a little worried about introducing such a big change for such a
> corner case. Looking at this test case, it seems updating the outputPartitioning according to the
> project list is good enough? cc @maryannxue
>
> Reply (Author): Basically, that is what is done. The diff is quite big, that's true, but it is
> mostly tests. Is there something specific you are worried about?
>
> Reviewer: I'm worried about Expression.sameResult. Is it still useful after updating the
> outputPartitioning according to the project list?
>
> Author: It is still useful because a user may write something like `select ... group by a as b, ...`
> (even though it makes little sense written in SQL, it can easily happen with other APIs). In
> particular, it is useful when updating the outputPartitioning in order to check whether an
> expression produces the same result as another.

withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val N = 10
val t1 = spark.range(N).selectExpr("floor(id/4) as k1")
val t2 = spark.range(N).selectExpr("floor(id/4) as k2")

val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")
val finalPlan = agg1.join(agg2, $"k1" === $"k3")
val exchanges = finalPlan.queryExecution.executedPlan.collect {
case se: ShuffleExchangeExec => se
}
assert(exchanges.size == 2)
assert(!exchanges.exists(_.newPartitioning match {
case HashPartitioning(Seq(a: AttributeReference), _) => a.name == "k3"
case _ => false
}))

// In this case the requirement is not satisfied
val agg3 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumn("k3", $"k2" + 1)
val finalPlan2 = agg1.join(agg3, $"k1" === $"k3")
val exchanges2 = finalPlan2.queryExecution.executedPlan.collect {
case se: ShuffleExchangeExec => se
}
assert(exchanges2.size == 3)
assert(exchanges2.exists(_.newPartitioning match {
case HashPartitioning(Seq(a: AttributeReference), _) => a.name == "k3"
case _ => false
}))
}
}
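
As a side note on the `group by a as b` case discussed in the review thread above, the DataFrame API makes an aliased grouping expression easy to produce; a hedged illustration (not part of the test suite, column names are arbitrary):

// The grouping column itself carries an Alias here, which is why aliases are
// stripped (trimAliases) before expressions are matched against the child's
// partitioning attributes.
val grouped = spark.range(10)
  .selectExpr("id % 3 AS a")
  .groupBy($"a".as("b"))
  .count()
grouped.explain() // the aggregate groups on `a AS b`, not on a bare attribute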

test("SPARK-24762: Enable top-level Option of Product encoders") {
val data = Seq(Some((1, "a")), Some((2, "b")), None)
val ds = data.toDS()