From 220112906737b3db668513a024423b35a2c2f32a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 23 Jul 2015 12:21:08 -0700 Subject: [PATCH 01/15] Filter out rows that will not be joined in equal joins early. --- .../catalyst/expressions/nullFunctions.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 47 ++++++------ .../scala/org/apache/spark/sql/SQLConf.scala | 6 ++ .../org/apache/spark/sql/SQLContext.scala | 5 +- .../extendedOperatorOptimizations.scala | 72 +++++++++++++++++++ .../optimizer/AdvancedOptimizationSuite.scala | 32 +++++++++ 6 files changed, 142 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 287718fab7f0d..577fd93a44e2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -217,7 +217,7 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 813c62009666c..0dfd4ef4a411b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,8 +31,14 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -object DefaultOptimizer extends Optimizer { - val batches = +class DefaultOptimizer extends Optimizer { + + /** + * Override to provide additional rules for the "Operator Optimizations" batch. + */ + val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches = // SubQueries are only needed for analysis and can be removed before execution. 
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown, - SamplePushDown, - PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - ColumnPruning, + SetOperationPushDown :: + SamplePushDown :: + PushPredicateThroughJoin :: + PushPredicateThroughProject :: + PushPredicateThroughGenerate :: + ColumnPruning :: // Operator combine - ProjectCollapsing, - CombineFilters, - CombineLimits, + ProjectCollapsing :: + CombineFilters :: + CombineLimits :: // Constant folding - NullPropagation, - OptimizeIn, - ConstantFolding, - LikeSimplification, - BooleanSimplification, - RemovePositive, - SimplifyFilters, - SimplifyCasts, - SimplifyCaseConversionExpressions) :: + NullPropagation :: + OptimizeIn :: + ConstantFolding :: + LikeSimplification :: + BooleanSimplification :: + RemovePositive :: + SimplifyFilters :: + SimplifyCasts :: + SimplifyCaseConversionExpressions :: + extendedOperatorOptimizationRules.toList : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9b2dbd7442f5c..8364d4b17d862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -401,6 +401,10 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) + val ADVANCED_SQL_OPTIMIZATION = booleanConf( + "spark.sql.advancedOptimization", + defaultValue = Some(true), isPublic = false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -470,6 +474,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) + private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a09846548..31e2b508d485e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { + override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil + } @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala new file mode 100644 index 0000000000000..c2a5579887123 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.AtLeastNNonNulls +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule + +case class FilterNullsInJoinKey( + sqlContext: SQLContext) + extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!sqlContext.conf.advancedSqlOptimizations) { + plan + } else { + plan transform { + case join: Join => join match { + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => + // If any left key is null, the join condition will not be true. + // So, we can filter those rows out. 
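+              // Editor's note (illustrative annotation, not part of the original commit): for an
+              // inner join such as
+              //   SELECT * FROM t1 JOIN t2 ON t1.key = t2.key
+              // a row whose key is null can never satisfy the equality, so this rewrite produces
+              // the plan a user would get by writing
+              //   SELECT * FROM (SELECT * FROM t1 WHERE key IS NOT NULL) a
+              //     JOIN (SELECT * FROM t2 WHERE key IS NOT NULL) b ON a.key = b.key
+              // without requiring the IS NOT NULL predicates to be added by hand.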
+ val leftCondition = AtLeastNNonNulls(leftKeys.length, leftKeys) + val leftFilter = Filter(leftCondition, left) + val rightCondition = AtLeastNNonNulls(rightKeys.length, rightKeys) + val rightFilter = Filter(rightCondition, right) + + Join(leftFilter, rightFilter, Inner, join.condition) + + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) => + val rightCondition = AtLeastNNonNulls(rightKeys.length, rightKeys) + val rightFilter = Filter(rightCondition, right) + + Join(left, rightFilter, LeftOuter, join.condition) + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) => + val leftCondition = AtLeastNNonNulls(leftKeys.length, leftKeys) + val leftFilter = Filter(leftCondition, left) + + Join(leftFilter, right, RightOuter, join.condition) + + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => + val leftCondition = AtLeastNNonNulls(leftKeys.length, leftKeys) + val leftFilter = Filter(leftCondition, left) + val rightCondition = AtLeastNNonNulls(rightKeys.length, rightKeys) + val rightFilter = Filter(rightCondition, right) + + Join(leftFilter, rightFilter, LeftSemi, join.condition) + + case other => other + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala new file mode 100644 index 0000000000000..9c13989128028 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SQLTestUtils +import org.scalatest.BeforeAndAfterAll + +class AdvancedOptimizationSuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { + val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + import sqlContext.sql + + override def beforeAll(): Unit = { + + } +} From d5b84c399c6966ad509276b3f146948ff06e5ca4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 23 Jul 2015 18:48:47 -0700 Subject: [PATCH 02/15] Do not add unnessary filters. 
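
Editor's note (annotation added to this series, not part of the original commit message):
this revision makes the rule fire only when a null filter can actually remove rows. A new
needsFilter check skips join keys that are already non-nullable attributes; non-attribute
keys are first materialized with an Alias under a Project so that both the null check and
the rewritten join condition refer to plain attributes; and Filter learns to narrow the
nullability of attributes guarded by an AtLeastNNonNulls condition, so downstream operators
see the tighter schema. A rough sketch of addFilter for keys Seq(a, b + 1) over some child:

    Filter(AtLeastNNonNulls(2, Seq(a, joinKey)),
      Project(child.output :+ Alias(b + 1, "joinKey"), child))

with the returned key attributes Seq(a, joinKey) substituted into the rewritten join
condition.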
--- .../sql/catalyst/optimizer/Optimizer.scala | 18 +++- .../plans/logical/basicOperators.scala | 21 +++- .../extendedOperatorOptimizations.scala | 98 ++++++++++++++----- 3 files changed, 112 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0dfd4ef4a411b..49afb53414c05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,8 +66,10 @@ class DefaultOptimizer extends Optimizer { RemovePositive :: SimplifyFilters :: SimplifyCasts :: - SimplifyCaseConversionExpressions :: + SimplifyCaseConversionExpressions :: // Nil : _*) :: extendedOperatorOptimizationRules.toList : _*) :: + // Batch("Extended Operator Optimizations", Once, + // extendedOperatorOptimizationRules : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -550,8 +552,20 @@ object SimplifyFilters extends Rule[LogicalPlan] { * This heuristic is valid assuming the expression evaluation cost is minimal. */ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { + + /** If the condition changes nullability of attributes. */ + private def preserveNullability(condition: Expression): Boolean = condition match { + // The condition is used to change nullability of attributes when + // - expressions of AtLeastNNonNulls are all attributes; and + // - AtLeastNNonNulls is used to make sure that there is no attribute is null. + case AtLeastNNonNulls(n, expressions) => + !(expressions.forall(_.isInstanceOf[Attribute]) && n == expressions.length) + case other => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, project @ Project(fields, grandChild)) => + case filter @ Filter(condition, project @ Project(fields, grandChild)) + if preserveNullability(condition) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). 
val aliasMap = AttributeMap(fields.collect { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index af68358daf5f1..a0debbffece3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -85,7 +85,26 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + private def shouldChangeNullable(atLeastNNonNulls: AtLeastNNonNulls): Boolean = { + val expressions = atLeastNNonNulls.children + val n = atLeastNNonNulls.n + if (expressions.length != n) { + false + } else { + expressions.forall(_.isInstanceOf[Attribute]) + } + } + + override def output: Seq[Attribute] = condition match { + case a: AtLeastNNonNulls if shouldChangeNullable(a) => + val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) + child.output.map { + case attr if nonNullableAttributes.contains(attr) => + attr.withNullability(false) + case attr => attr + } + case _ => child.output + } } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala index c2a5579887123..33582465ca7fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -18,51 +18,105 @@ package org.apache.spark.sql.optimizer import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.AtLeastNNonNulls +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule case class FilterNullsInJoinKey( sqlContext: SQLContext) extends Rule[LogicalPlan] { + private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { + if (keys.exists(!_.isInstanceOf[Attribute])) { + true + } else { + val keyAttributeSet = AttributeSet(keys.asInstanceOf[Seq[Attribute]]) + // If any key is still nullable, we need to add a Filter. 
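+      // Editor's note (illustrative): with a child schema of (a INT NOT NULL, b INT),
+      // keys Seq(a) make this return false (no extra Filter is added), keys Seq(b) make it
+      // return true, and a non-attribute key such as a + 1 already returned true above.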
+ plan.output.filter(keyAttributeSet.contains).exists(_.nullable) + } + } + + private def addFilter( + keys: Seq[Expression], + child: LogicalPlan): (Seq[Attribute], Filter) = { + val nonAttributes = keys.filterNot { + case attr: Attribute => true + case _ => false + } + + val materializedKeys = nonAttributes.map { expr => + expr -> Alias(expr, "joinKey")() + }.toMap + + val keyAttributes = keys.map { + case attr: Attribute => attr + case expr => materializedKeys(expr).toAttribute + } + + val project = Project(child.output ++ materializedKeys.map(_._2), child) + val filter = Filter(AtLeastNNonNulls(keyAttributes.length, keyAttributes), project) + + (keyAttributes, filter) + } + + private def rewriteJoinCondition( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + otherPredicate: Option[Expression]): Expression = { + val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { + case (l, r) => EqualTo(l, r) + }.reduce(And) + + val rewrittenJoinCondition = otherPredicate + .map(c => And(rewrittenEqualJoinCondition, c)) + .getOrElse(rewrittenEqualJoinCondition) + + rewrittenJoinCondition + } + def apply(plan: LogicalPlan): LogicalPlan = { if (!sqlContext.conf.advancedSqlOptimizations) { plan } else { plan transform { case join: Join => join match { - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => // If any left key is null, the join condition will not be true. // So, we can filter those rows out. - val leftCondition = AtLeastNNonNulls(leftKeys.length, leftKeys) - val leftFilter = Filter(leftCondition, left) - val rightCondition = AtLeastNNonNulls(rightKeys.length, rightKeys) - val rightFilter = Filter(rightCondition, right) + val (leftKeyAttributes, leftFilter) = addFilter(leftKeys, left) + val (rightKeyAttributes, rightFilter) = addFilter(rightKeys, right) + val rewrittenJoinCondition = + rewriteJoinCondition(leftKeyAttributes, rightKeyAttributes, condition) - Join(leftFilter, rightFilter, Inner, join.condition) + Join(leftFilter, rightFilter, Inner, Some(rewrittenJoinCondition)) - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) => - val rightCondition = AtLeastNNonNulls(rightKeys.length, rightKeys) - val rightFilter = Filter(rightCondition, right) + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(rightKeys, right) => + val (rightKeyAttributes, rightFilter) = addFilter(rightKeys, right) + val rewrittenJoinCondition = + rewriteJoinCondition(leftKeys, rightKeyAttributes, condition) - Join(left, rightFilter, LeftOuter, join.condition) + Join(left, rightFilter, LeftOuter, Some(rewrittenJoinCondition)) - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) => - val leftCondition = AtLeastNNonNulls(leftKeys.length, leftKeys) - val leftFilter = Filter(leftCondition, left) + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) => + val (leftKeyAttributes, leftFilter) = addFilter(leftKeys, left) + val rewrittenJoinCondition = + rewriteJoinCondition(leftKeyAttributes, rightKeys, condition) - Join(leftFilter, right, RightOuter, join.condition) + Join(leftFilter, right, RightOuter, Some(rewrittenJoinCondition)) - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val leftCondition = 
AtLeastNNonNulls(leftKeys.length, leftKeys) - val leftFilter = Filter(leftCondition, left) - val rightCondition = AtLeastNNonNulls(rightKeys.length, rightKeys) - val rightFilter = Filter(rightCondition, right) + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val (leftKeyAttributes, leftFilter) = addFilter(leftKeys, left) + val (rightKeyAttributes, rightFilter) = addFilter(rightKeys, right) + val rewrittenJoinCondition = + rewriteJoinCondition(leftKeyAttributes, rightKeyAttributes, condition) - Join(leftFilter, rightFilter, LeftSemi, join.condition) + Join(leftFilter, rightFilter, LeftSemi, Some(rewrittenJoinCondition)) case other => other } From 69bb0724eb1dd92d20afdde4b607d37bc4d5e4ca Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 23 Jul 2015 18:49:34 -0700 Subject: [PATCH 03/15] Introduce NullSafeHashPartitioning and NullUnsafePartitioning. --- .../spark/sql/catalyst/expressions/misc.scala | 20 +++++++ .../plans/physical/partitioning.scala | 58 ++++++++++++++++--- .../sql/catalyst/DistributionSuite.scala | 36 ++++++------ .../spark/sql/execution/Aggregate.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 31 ++++++++-- .../sql/execution/GeneratedAggregate.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../apache/spark/sql/execution/Window.scala | 2 +- .../aggregate/aggregateOperators.scala | 6 +- .../execution/joins/LeftSemiJoinHash.scala | 6 +- .../execution/joins/ShuffledHashJoin.scala | 6 +- .../joins/ShuffledHashOuterJoin.scala | 4 +- .../sql/execution/joins/SortMergeJoin.scala | 2 +- 13 files changed, 130 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8d8d66ddeb341..44fa2a0bce738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -160,3 +161,22 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + +/** An expression that returns the hashCode of the input row. */ +case object RowHashCode extends LeafExpression { + override def dataType: DataType = IntegerType + + /** hashCode will never be null. 
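+   * (Editor's note) Later in this patch, Exchange uses RowHashCode to turn the projected
+   * join-key row into a partition id when every key evaluates to a non-null value; see the
+   * NullUnsafeHashPartitioning case in Exchange.getPartitionKeyExtractor.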
*/ + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + input.hashCode + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = i.hashCode(); + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2dcfa19fec383..bcd3eb0010bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -47,9 +47,23 @@ case object AllTuples extends Distribution * Represents data where tuples that share the same values for the `clustering` * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. + * within a single partition. For two null values in two rows evaluated by `clustering`, + * we consider these two nulls are equal. */ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class NullSafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution { + require( + clustering != Nil, + "The clustering expressions of a ClusteredDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") +} + +/** + * It is basically the same as [[NullSafeClusteredDistribution]] except that + * for two null values in two rows evaluated by `clustering`, + * we consider these two nulls are not equal. + */ +case class NullUnsafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -60,7 +74,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for + * [[NullSafeClusteredDistribution]] as an ordering will ensure that tuples that share the same value for * the ordering expressions are contiguous and will never be split across partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { @@ -89,7 +103,7 @@ sealed trait Partitioning { /** * Returns true iff all distribution guarantees made by this partitioning can also be made * for the `other` specified partitioning. - * For example, two [[HashPartitioning HashPartitioning]]s are + * For example, two [[NullSafeHashPartitioning HashPartitioning]]s are * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean @@ -143,7 +157,34 @@ case object BroadcastPartitioning extends Partitioning { * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. 
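 * (Editor's note) In this patch series the class keeps the original HashPartitioning
 * semantics under the new name NullSafeHashPartitioning: null key values are hashed like
 * any other value, so rows sharing the same (possibly null) key values are co-located.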
*/ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) +case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends Expression with Partitioning with Unevaluable { + + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + private[this] lazy val clusteringSet = expressions.toSet + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case NullSafeClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case NullUnsafeClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case BroadcastPartitioning => true + case h: NullSafeHashPartitioning if h == this => true + case _ => false + } + + override def keyExpressions: Seq[Expression] = expressions +} + +case class NullUnsafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions @@ -154,14 +195,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => + case NullUnsafeClusteredDistribution(requiredClustering) => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true - case h: HashPartitioning if h == this => true + case h: NullUnsafeHashPartitioning if h == this => true case _ => false } @@ -194,14 +235,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => + case NullSafeClusteredDistribution(requiredClustering) => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c046dbf4dc2c9..5aa1ebf589a93 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -45,23 +45,23 @@ class DistributionSuite extends SparkFunSuite { test("HashPartitioning is the output partitioning") { // Cases which do not need an exchange between two data properties. 
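    // Editor's note: checkSatisfied(outputPartitioning, requiredDistribution, expected)
    // asserts whether the given partitioning satisfies the required distribution.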
checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + NullSafeHashPartitioning(Seq('b, 'c), 10), + NullSafeClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( SinglePartition, - ClusteredDistribution(Seq('a, 'b, 'c)), + NullSafeClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( @@ -71,27 +71,27 @@ class DistributionSuite extends SparkFunSuite { // Cases which need an exchange between two data properties. checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('b, 'c)), + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeClusteredDistribution(Seq('b, 'c)), false) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('d, 'e)), + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeClusteredDistribution(Seq('d, 'e)), false) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, false) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + NullSafeHashPartitioning(Seq('b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) @@ -128,17 +128,17 @@ class DistributionSuite extends SparkFunSuite { checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + NullSafeClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'b, 'a)), + NullSafeClusteredDistribution(Seq('c, 'b, 'a)), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('b, 'c, 'a, 'd)), + NullSafeClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) // Cases which need an exchange between two data properties. 
@@ -158,12 +158,12 @@ class DistributionSuite extends SparkFunSuite { checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + NullSafeClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + NullSafeClusteredDistribution(Seq('c, 'd)), false) checkSatisfied( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index e8c6a0f8f801d..3979e396008b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -52,7 +52,7 @@ case class Aggregate( if (groupingExpressions == Nil) { AllTuples :: Nil } else { - ClusteredDistribution(groupingExpressions) :: Nil + NullSafeClusteredDistribution(groupingExpressions) :: Nil } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..b45625651bab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -140,10 +141,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } + private val advancedSqlOptimizations = child.sqlContext.conf.advancedSqlOptimizations + protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { val rdd = child.execute() val part: Partitioner = newPartitioning match { - case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case NullSafeHashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case NullUnsafeHashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. @@ -162,7 +166,24 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // TODO: Handle BroadcastPartitioning. } def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { - case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case NullSafeHashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case NullUnsafeHashPartitioning(expressions, numPartition) if advancedSqlOptimizations => + // For NullUnsafeHashPartitioning, we do not want to send rows having any expression + // in `expressions` evaluated as null to the same node. 
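+        // Editor's note (illustrative): the partition id built below is effectively
+        //   if (every key is non-null) row.hashCode                      (RowHashCode)
+        //   else                       cast(rand(seed) * numPartitions, int)
+        // so rows whose keys contain a null are spread uniformly across partitions instead
+        // of all hashing to one partition, avoiding skew when a key column has many nulls.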
+ val materalizeExpressions = newMutableProjection(expressions, child.output)() + val partitionExpressionSchema = expressions.map { expr => + Alias(expr, "partitionExpr")().toAttribute + } + val partitionId = + If( + AtLeastNNonNulls(partitionExpressionSchema.length, partitionExpressionSchema), + RowHashCode, + Cast(Multiply(new Rand(numPartition), Literal(numPartition.toDouble)), IntegerType)) + val partitionIdExtractor = + newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() + (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) + case NullUnsafeHashPartitioning(expressions, numPartition) => + newMutableProjection(expressions, child.output)() case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } @@ -276,8 +297,10 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ val fixedChildren = requirements.zipped.map { case (AllTuples, rowOrdering, child) => addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (NullSafeClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(NullSafeHashPartitioning(clustering, numPartitions), rowOrdering, child) + case (NullUnsafeClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(NullUnsafeHashPartitioning(clustering, numPartitions), rowOrdering, child) case (OrderedDistribution(ordering), rowOrdering, child) => addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 5ad4691a5ca07..6c390fcb0b343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -61,7 +61,7 @@ case class GeneratedAggregate( if (groupingExpressions == Nil) { AllTuples :: Nil } else { - ClusteredDistribution(groupingExpressions) :: Nil + NullSafeClusteredDistribution(groupingExpressions) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 306bbfec624c0..56489c1436a05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -403,7 +403,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange(NullSafeHashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 91c8a02e2b5bc..bdba16037f208 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -92,7 +92,7 @@ case class Window( logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil + } else NullSafeClusteredDistribution(windowSpec.partitionSpec) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala index 0c9082897f390..f4f4eb7fafdeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, NullSafeClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} case class Aggregate2Sort( @@ -49,7 +49,7 @@ case class Aggregate2Sort( override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.length > 0 => NullSafeClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -144,7 +144,7 @@ case class FinalAndCompleteAggregate2Sort( if (groupingExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(groupingExpressions) :: Nil + NullSafeClusteredDistribution(groupingExpressions) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 874712a4e739f..da9e3886d88ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -37,8 +37,8 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def requiredChildDistribution: Seq[Distribution] = + NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil protected override def 
doExecute(): RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 948d0ccebceb0..d972d96a4a7cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -40,8 +40,8 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def requiredChildDistribution: Seq[Distribution] = + NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index f54f1edd38ec8..51ee7edfd7a60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{NullUnsafeClusteredDistribution, Distribution, NullSafeClusteredDistribution} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -42,7 +42,7 @@ case class ShuffledHashOuterJoin( right: SparkPlan) extends BinaryNode with HashOuterJoin { override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index bb18b5403f8e8..81a04cc6e2d02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -43,7 +43,7 @@ case class SortMergeJoin( override def outputPartitioning: Partitioning = left.outputPartitioning override def 
requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) From 7c2d2d87a7182fbc9fc8b35fd75db64e147f0ff7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 26 Jul 2015 15:51:38 -0700 Subject: [PATCH 04/15] Bug fix and refactoring. --- .../catalyst/expressions/nullFunctions.scala | 48 +++- .../sql/catalyst/optimizer/Optimizer.scala | 28 +-- .../plans/logical/basicOperators.scala | 21 +- .../plans/physical/partitioning.scala | 14 +- .../sql/catalyst/DistributionSuite.scala | 129 +++++++++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/MathFunctionsSuite.scala | 3 +- .../expressions/NullFunctionsSuite.scala | 49 +++- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 11 +- .../extendedOperatorOptimizations.scala | 117 ++++++---- .../optimizer/AdvancedOptimizationSuite.scala | 32 --- .../optimizer/FilterNullsInJoinKeySuite.scala | 221 ++++++++++++++++++ 13 files changed, 566 insertions(+), 113 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 577fd93a44e2e..d58c4756938c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } +/** + * A predicate that is evaluated to be true if there are at least `n` null values. + */ +case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: InternalRow): Boolean = { + var numNulls = 0 + var i = 0 + while (i < childrenArray.length && numNulls < n) { + val evalC = childrenArray(i).eval(input) + if (evalC == null) { + numNulls += 1 + } + i += 1 + } + numNulls >= n + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numNulls = ctx.freshName("numNulls") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($numNulls < $n) { + ${eval.code} + if (${eval.isNull}) { + $numNulls += 1; + } + } + """ + }.mkString("\n") + s""" + int $numNulls = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $numNulls >= $n; + """ + } +} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. 
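 * (Editor's note, example) With children (1, null, Double.NaN) and n = 2 this evaluates to
 * false, since only the literal 1 is both non-null and non-NaN; with n = 1 it is true.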
*/ -case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 49afb53414c05..41ee3eb141bb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,10 +66,8 @@ class DefaultOptimizer extends Optimizer { RemovePositive :: SimplifyFilters :: SimplifyCasts :: - SimplifyCaseConversionExpressions :: // Nil : _*) :: + SimplifyCaseConversionExpressions :: extendedOperatorOptimizationRules.toList : _*) :: - // Batch("Extended Operator Optimizations", Once, - // extendedOperatorOptimizationRules : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -231,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + // We need to preserve the nullability of c's output. + // So, we first create a outputMap and if a reference is from the output of + // c, we use that output attribute from c. + val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) + val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq + Project(projectList, c) } else { c } + } } /** @@ -552,20 +556,8 @@ object SimplifyFilters extends Rule[LogicalPlan] { * This heuristic is valid assuming the expression evaluation cost is minimal. */ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { - - /** If the condition changes nullability of attributes. */ - private def preserveNullability(condition: Expression): Boolean = condition match { - // The condition is used to change nullability of attributes when - // - expressions of AtLeastNNonNulls are all attributes; and - // - AtLeastNNonNulls is used to make sure that there is no attribute is null. - case AtLeastNNonNulls(n, expressions) => - !(expressions.forall(_.isInstanceOf[Attribute]) && n == expressions.length) - case other => true - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, project @ Project(fields, grandChild)) - if preserveNullability(condition) => + case filter @ Filter(condition, project @ Project(fields, grandChild)) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). 
val aliasMap = AttributeMap(fields.collect { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a0debbffece3a..77f294bc43a19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -85,18 +85,29 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - private def shouldChangeNullable(atLeastNNonNulls: AtLeastNNonNulls): Boolean = { - val expressions = atLeastNNonNulls.children - val n = atLeastNNonNulls.n - if (expressions.length != n) { + /** + * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children + * have at least one null value and atLeastNNulls.children are all attributes. + */ + private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { + val expressions = atLeastNNulls.children + val n = atLeastNNulls.n + if (n != 1) { + // AtLeastNNulls is not used to check if atLeastNNulls.children have + // at least one null value. false } else { + // AtLeastNNulls is used to check if atLeastNNulls.children have + // at least one null value. We need to make sure all atLeastNNulls.children + // are attributes. expressions.forall(_.isInstanceOf[Attribute]) } } override def output: Seq[Attribute] = condition match { - case a: AtLeastNNonNulls if shouldChangeNullable(a) => + case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => + // The condition is used to make sure that there is no null value in + // a.children. val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) child.output.map { case attr if nonNullableAttributes.contains(attr) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index bcd3eb0010bb4..b96dd11e398da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -60,8 +60,7 @@ case class NullSafeClusteredDistribution(clustering: Seq[Expression]) extends Di /** * It is basically the same as [[NullSafeClusteredDistribution]] except that - * for two null values in two rows evaluated by `clustering`, - * we consider these two nulls are not equal. + * it does not require that evaluated rows having any null values to be clustered. */ case class NullUnsafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution { require( @@ -184,6 +183,15 @@ case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: override def keyExpressions: Seq[Expression] = expressions } +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same non-null values are + * guaranteed to be in the same partition. For `expressions`, if a evaluated row + * has any null value, it is not guaranteed to be in the same partition with other + * rows having the same values. + * + * For example, Row(1, null) and Row(1, null) may not be in the same partition. 
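+ * (Editor's note) This weaker guarantee is still sufficient for the equi-join operators,
+ * which only need matching (hence non-null) keys to be co-located; a row whose keys
+ * contain a null can never match and may therefore land on any partition.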
+ */ case class NullUnsafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning with Unevaluable { @@ -237,6 +245,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) requiredOrdering.take(minSize) == ordering.take(minSize) case NullSafeClusteredDistribution(requiredClustering) => clusteringSet.subsetOf(requiredClustering.toSet) + case NullUnsafeClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 5aa1ebf589a93..07fe8780145d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning is the output partitioning") { + test("NullSafeHashPartitioning is the output partitioning") { // Cases which do not need an exchange between two data properties. checkSatisfied( NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), @@ -64,6 +64,21 @@ class DistributionSuite extends SparkFunSuite { NullSafeClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + NullSafeHashPartitioning(Seq('b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -80,6 +95,16 @@ class DistributionSuite extends SparkFunSuite { NullSafeClusteredDistribution(Seq('d, 'e)), false) + checkSatisfied( + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('b, 'c)), + false) + + checkSatisfied( + NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('d, 'e)), + false) + checkSatisfied( NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, @@ -104,6 +129,82 @@ class DistributionSuite extends SparkFunSuite { */ } + + + test("NullUnsafeHashPartitioning is the output partitioning") { + // Cases which do not need an exchange between two data properties. + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + true) + + // Cases which need an exchange between two data properties. 
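+    // Editor's note: unlike NullSafeHashPartitioning above, a null-unsafe partitioning never
+    // satisfies a NullSafeClusteredDistribution, because rows with null keys may have been
+    // scattered across different partitions.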
+ checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('b, 'c), 10), + NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeClusteredDistribution(Seq('b, 'c)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullSafeClusteredDistribution(Seq('d, 'e)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('b, 'c)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + NullUnsafeClusteredDistribution(Seq('d, 'e)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + AllTuples, + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + + checkSatisfied( + NullUnsafeHashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + } + test("RangePartitioning is the output partitioning") { // Cases which do not need an exchange between two data properties. checkSatisfied( @@ -141,6 +242,22 @@ class DistributionSuite extends SparkFunSuite { NullSafeClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + NullUnsafeClusteredDistribution(Seq('c, 'b, 'a)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + NullUnsafeClusteredDistribution(Seq('b, 'c, 'a, 'd)), + true) + + // Cases which need an exchange between two data properties. // TODO: We can have an optimization to first sort the dataset // by a.asc and then sort b, and c in a partition. 
This optimization @@ -161,6 +278,16 @@ class DistributionSuite extends SparkFunSuite { NullSafeClusteredDistribution(Seq('a, 'b)), false) + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + NullUnsafeClusteredDistribution(Seq('c, 'd)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + NullUnsafeClusteredDistribution(Seq('a, 'b)), + false) + checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), NullSafeClusteredDistribution(Seq('c, 'd)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index ab0cdc857c80e..9211341faf4f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => + protected val defaultOptimizer = new DefaultOptimizer + protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -177,7 +179,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 21459a7c69838..c7ee172c0f252 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -168,7 +167,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ace6c15dc8418..bf197124d8dbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 
} } - test("AtLeastNNonNulls") { + test("AtLeastNNonNullNans") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,11 +96,46 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) + } + + test("AtLeastNNull") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) + + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) + + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.USER_DEFAULT), + Literal(Float.MaxValue), + Literal(false)) + + checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index a4fd4cf3b330b..ea85f0657a726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. 
- val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index b45625651bab2..9997925b28aa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -171,18 +171,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // For NullUnsafeHashPartitioning, we do not want to send rows having any expression // in `expressions` evaluated as null to the same node. val materalizeExpressions = newMutableProjection(expressions, child.output)() - val partitionExpressionSchema = expressions.map { expr => - Alias(expr, "partitionExpr")().toAttribute + val partitionExpressionSchema = expressions.map { + case ne: NamedExpression => ne.toAttribute + case expr => Alias(expr, "partitionExpr")().toAttribute } val partitionId = If( - AtLeastNNonNulls(partitionExpressionSchema.length, partitionExpressionSchema), + Not(AtLeastNNulls(1, partitionExpressionSchema)), + // There is no null value in the partition expressions, we can just get the + // hashCode of the input row. RowHashCode, Cast(Multiply(new Rand(numPartition), Literal(numPartition.toDouble)), IntegerType)) val partitionIdExtractor = newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) case NullUnsafeHashPartitioning(expressions, numPartition) => + // If spark.sql.advancedOptimization is not enabled, we will just do the same thing + // as NullSafeHashPartitioning. newMutableProjection(expressions, child.output)() case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala index 33582465ca7fb..929d6b429aaf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -24,51 +24,84 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSe import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +/** + * An optimization rule used to insert Filters to filter out rows whose equal join keys + * have at least one null values. For this kind of rows, they will not contribute to + * the join results of equal joins because a null does not equal another null. We can + * filter them out before shuffling join input rows. For example, we have two tables + * + * table1(key String, value Int) + * "str1"|1 + * null |2 + * + * table2(key String, value Int) + * "str1"|3 + * null |4 + * + * For a inner equal join, the result will be + * "str1"|1|"str1"|3 + * + * those two rows having null as the value of key will not contribute to the result. + * So, we can filter them out early. + * + * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. 
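To make the rewrite concrete, here is a minimal sketch written with the Catalyst test DSL that the new FilterNullsInJoinKeySuite below also uses. The relation and column names are illustrative only, and the real rule operates on analyzed plans via ExtractEquiJoinKeys rather than on this simplified form.

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.AtLeastNNulls
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val table1 = LocalRelation('key.string, 'value.int)
val table2 = LocalRelation('key2.string, 'value2.int)

// Original inner equi-join: rows whose join key is null can never find a match.
val original = table1.join(table2, Inner, Some('key === 'key2))

// After FilterNullsInJoinKey fires, each side keeps only rows whose join keys
// contain no null, expressed as Not(AtLeastNNulls(1, keys)).
val rewritten =
  table1.where(!(AtLeastNNulls(1, 'key.expr :: Nil)))
    .join(table2.where(!(AtLeastNNulls(1, 'key2.expr :: Nil))), Inner, Some('key === 'key2))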
+ * + */ case class FilterNullsInJoinKey( sqlContext: SQLContext) extends Rule[LogicalPlan] { + /** + * Checks if we need to add a Filter operator. We will add a Filter when + * there is any attribute in `keys` whose corresponding attribute of `keys` + * in `plan.output` is still nullable (`nullable` field is `true`). + */ private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { - if (keys.exists(!_.isInstanceOf[Attribute])) { - true - } else { - val keyAttributeSet = AttributeSet(keys.asInstanceOf[Seq[Attribute]]) - // If any key is still nullable, we need to add a Filter. - plan.output.filter(keyAttributeSet.contains).exists(_.nullable) - } + val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) + plan.output.filter(keyAttributeSet.contains).exists(_.nullable) } - private def addFilter( + /** + * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. + */ + private def addFilterIfNecessary( keys: Seq[Expression], - child: LogicalPlan): (Seq[Attribute], Filter) = { - val nonAttributes = keys.filterNot { + child: LogicalPlan): LogicalPlan = { + // We get all attributes from keys. + val attributes = keys.filter { case attr: Attribute => true case _ => false } - val materializedKeys = nonAttributes.map { expr => - expr -> Alias(expr, "joinKey")() - }.toMap - - val keyAttributes = keys.map { - case attr: Attribute => attr - case expr => materializedKeys(expr).toAttribute - } - - val project = Project(child.output ++ materializedKeys.map(_._2), child) - val filter = Filter(AtLeastNNonNulls(keyAttributes.length, keyAttributes), project) + // Then, we create a Filter to make sure these attributes are non-nullable. + val filter = + if (attributes.nonEmpty) { + Filter(Not(AtLeastNNulls(1, attributes)), child) + } else { + child + } - (keyAttributes, filter) + // We return attributes representing keys (keyAttributes) and the filter. + // keyAttributes will be used to rewrite the join condition. + filter } - private def rewriteJoinCondition( + /** + * We reconstruct the join condition. + */ + private def reconstructJoinCondition( leftKeys: Seq[Expression], rightKeys: Seq[Expression], otherPredicate: Option[Expression]): Expression = { + // First, we rewrite the equal condition part. When we extract those keys, + // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { case (l, r) => EqualTo(l, r) }.reduce(And) + // Then, we add otherPredicate. When we extract those equal condition part, + // we use splitConjunctivePredicates. So, it is safe to use + // And(rewrittenEqualJoinCondition, c). val rewrittenJoinCondition = otherPredicate .map(c => And(rewrittenEqualJoinCondition, c)) .getOrElse(rewrittenEqualJoinCondition) @@ -82,41 +115,47 @@ case class FilterNullsInJoinKey( } else { plan transform { case join: Join => join match { + // For a inner join having equal join condition part, we can add filters + // to both sides of the join operator. case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - // If any left key is null, the join condition will not be true. - // So, we can filter those rows out. 
- val (leftKeyAttributes, leftFilter) = addFilter(leftKeys, left) - val (rightKeyAttributes, rightFilter) = addFilter(rightKeys, right) + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) val rewrittenJoinCondition = - rewriteJoinCondition(leftKeyAttributes, rightKeyAttributes, condition) + reconstructJoinCondition(leftKeys, rightKeys, condition) - Join(leftFilter, rightFilter, Inner, Some(rewrittenJoinCondition)) + Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) + // For a left outer join having equal join condition part, we can add a filter + // to the right side of the join operator. case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) if needsFilter(rightKeys, right) => - val (rightKeyAttributes, rightFilter) = addFilter(rightKeys, right) + val withRightFilter = addFilterIfNecessary(rightKeys, right) val rewrittenJoinCondition = - rewriteJoinCondition(leftKeys, rightKeyAttributes, condition) + reconstructJoinCondition(leftKeys, rightKeys, condition) - Join(left, rightFilter, LeftOuter, Some(rewrittenJoinCondition)) + Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) + // For a right outer join having equal join condition part, we can add a filter + // to the left side of the join operator. case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) if needsFilter(leftKeys, left) => - val (leftKeyAttributes, leftFilter) = addFilter(leftKeys, left) + val withLeftFilter = addFilterIfNecessary(leftKeys, left) val rewrittenJoinCondition = - rewriteJoinCondition(leftKeyAttributes, rightKeys, condition) + reconstructJoinCondition(leftKeys, rightKeys, condition) - Join(leftFilter, right, RightOuter, Some(rewrittenJoinCondition)) + Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) + // For a left semi join having equal join condition part, we can add filters + // to both sides of the join operator. case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val (leftKeyAttributes, leftFilter) = addFilter(leftKeys, left) - val (rightKeyAttributes, rightFilter) = addFilter(rightKeys, right) + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) val rewrittenJoinCondition = - rewriteJoinCondition(leftKeyAttributes, rightKeyAttributes, condition) + reconstructJoinCondition(leftKeys, rightKeys, condition) - Join(leftFilter, rightFilter, LeftSemi, Some(rewrittenJoinCondition)) + Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala deleted file mode 100644 index 9c13989128028..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/AdvancedOptimizationSuite.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.test.SQLTestUtils -import org.scalatest.BeforeAndAfterAll - -class AdvancedOptimizationSuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql - - override def beforeAll(): Unit = { - - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala new file mode 100644 index 0000000000000..2a6b29f6e96f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.AtLeastNNulls +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.test.TestSQLContext + +class FilterNullsInJoinKeySuite extends PlanTest { + + // We add predicate pushdown rules at here to make sure we do not + // create redundant + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Operator Optimizations", FixedPoint(100), + FilterNullsInJoinKey(TestSQLContext), + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning, + ProjectCollapsing) :: Nil + } + + val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) + + val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) + + test("inner join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For an inner join, FilterNullsInJoinKey add filter to both side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + .select('a, 'b, 'd) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join (partially optimized)") { + val joinCondition = + ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // We cannot extract attribute from the left join key. + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .select('a, 'b, 'd) + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join (not optimized)") { + val nonOptimizedJoinConditions = + Some('c - 100 + 'd === 'g + 1 - 'h) :: + Some('d > 'h || 'c === 'g) :: + Some('d + 'g + 'c > 'd - 'h) :: Nil + + nonOptimizedJoinConditions.foreach { joinCondition => + val joinedPlan = + leftRelation + .select('a, 'c, 'd) + .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) + .select('a, 'c, 'f, 'd, 'h, 'g) + + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, joinedPlan.analyze) + } + } + + test("left outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left outer join, FilterNullsInJoinKey add filter to the right side. 
+ val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .select('a, 'b, 'd) + .join(correctRight, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("right outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a right outer join, FilterNullsInJoinKey add filter to the left side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + .select('a, 'b, 'd) + + val correctAnswer = + correctLeft + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("full outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .select('a, 'b, 'd) + .join(rightRelation, FullOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + // FilterNullsInJoinKey does not fire for a full outer join. + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, joinedPlan.analyze) + } + + test("left semi join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftSemi, Some(joinCondition)) + .select('a, 'd) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left semi join, FilterNullsInJoinKey add filter to both side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + .select('a, 'b, 'd) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, LeftSemi, Some(joinCondition)) + .select('a, 'd) + .analyze + + comparePlans(optimized, correctAnswer) + } +} From e616d3b0a2fa5836956c15b9f64410683a3ef9db Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 26 Jul 2015 20:28:49 -0700 Subject: [PATCH 05/15] wip --- .../plans/physical/partitioning.scala | 52 ++++++++++++++++++- .../apache/spark/sql/execution/Exchange.scala | 33 ++++++++---- .../joins/BroadcastHashOuterJoin.scala | 4 +- .../sql/execution/joins/HashOuterJoin.scala | 8 --- .../execution/joins/LeftSemiJoinHash.scala | 6 ++- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../joins/ShuffledHashOuterJoin.scala | 15 +++++- 7 files changed, 94 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index b96dd11e398da..ff3979f738185 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -107,8 +107,12 @@ sealed trait Partitioning { */ def compatibleWith(other: Partitioning): Boolean + def guarantees(other: Partitioning): Boolean + /** Returns the expressions that are used to key the partitioning. 
*/ def keyExpressions: Seq[Expression] + + def toNullUnsafePartitioning: Partitioning } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -122,7 +126,11 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } + override def guarantees(other: Partitioning): Boolean = false + override def keyExpressions: Seq[Expression] = Nil + + override def toNullUnsafePartitioning: Partitioning = this } case object SinglePartition extends Partitioning { @@ -135,7 +143,14 @@ case object SinglePartition extends Partitioning { case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case SinglePartition => true + case _ => false + } + override def keyExpressions: Seq[Expression] = Nil + + override def toNullUnsafePartitioning: Partitioning = this } case object BroadcastPartitioning extends Partitioning { @@ -148,7 +163,14 @@ case object BroadcastPartitioning extends Partitioning { case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case BroadcastPartitioning => true + case _ => false + } + override def keyExpressions: Seq[Expression] = Nil + + override def toNullUnsafePartitioning: Partitioning = this } /** @@ -163,7 +185,7 @@ case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = expressions.toSet + lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -180,7 +202,18 @@ case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case o: NullSafeHashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: NullUnsafeHashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } + override def keyExpressions: Seq[Expression] = expressions + + override def toNullUnsafePartitioning: Partitioning = + NullUnsafeHashPartitioning(expressions, numPartitions) } /** @@ -199,7 +232,7 @@ case class NullUnsafeHashPartitioning(expressions: Seq[Expression], numPartition override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = expressions.toSet + lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -214,7 +247,15 @@ case class NullUnsafeHashPartitioning(expressions: Seq[Expression], numPartition case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case o: NullUnsafeHashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } + override def keyExpressions: Seq[Expression] = expressions + + override def toNullUnsafePartitioning: Partitioning = this } /** @@ -255,5 +296,12 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o + case _ => false + } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + + override def toNullUnsafePartitioning: Partitioning = this } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 9997925b28aa8..65cccbd15879b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -141,8 +141,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } - private val advancedSqlOptimizations = child.sqlContext.conf.advancedSqlOptimizations - protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { val rdd = child.execute() val part: Partitioner = newPartitioning match { @@ -167,7 +165,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { case NullSafeHashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() - case NullUnsafeHashPartitioning(expressions, numPartition) if advancedSqlOptimizations => + case NullUnsafeHashPartitioning(expressions, numPartition) => // For NullUnsafeHashPartitioning, we do not want to send rows having any expression // in `expressions` evaluated as null to the same node. val materalizeExpressions = newMutableProjection(expressions, child.output)() @@ -185,10 +183,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una val partitionIdExtractor = newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) - case NullUnsafeHashPartitioning(expressions, numPartition) => - // If spark.sql.advancedOptimization is not enabled, we will just do the same thing - // as NullSafeHashPartitioning. - newMutableProjection(expressions, child.output)() case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } @@ -221,6 +215,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // TODO: Determine the number of partitions. 
def numPartitions: Int = sqlContext.conf.numShufflePartitions + def advancedSqlOptimizations = sqlContext.conf.advancedSqlOptimizations + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => // True iff every child's outputPartitioning satisfies the corresponding @@ -265,7 +261,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (child.outputPartitioning != partitioning) { + if (child.outputPartitioning.guarantees(partitioning)) { Exchange(partitioning, child) } else { child @@ -302,15 +298,32 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ val fixedChildren = requirements.zipped.map { case (AllTuples, rowOrdering, child) => addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (NullSafeClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(NullSafeHashPartitioning(clustering, numPartitions), rowOrdering, child) + addOperatorsIfNecessary( + NullSafeHashPartitioning(clustering, numPartitions), + rowOrdering, + child) + case (NullUnsafeClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(NullUnsafeHashPartitioning(clustering, numPartitions), rowOrdering, child) + if (advancedSqlOptimizations) { + addOperatorsIfNecessary( + NullUnsafeHashPartitioning(clustering, numPartitions), + rowOrdering, + child) + } else { + addOperatorsIfNecessary( + NullSafeHashPartitioning(clustering, numPartitions), + rowOrdering, + child) + } + case (OrderedDistribution(ordering), rowOrdering, child) => addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) case (UnspecifiedDistribution, Seq(), child) => child + case (UnspecifiedDistribution, rowOrdering, child) => sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c9d1a880f4ef4..c30ad3708e3e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 
6bf2f82954046..d0cee21fbb840 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -38,14 +38,6 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - override def output: Seq[Attribute] = { joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index da9e3886d88ff..18d32343494a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -37,8 +37,10 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override def outputPartitioning: Partitioning = left.outputPartitioning + override def requiredChildDistribution: Seq[Distribution] = - NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil + NullSafeClusteredDistribution(leftKeys) :: NullSafeClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index d972d96a4a7cc..01d200e31b4d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -41,7 +41,7 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = - NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil + NullSafeClusteredDistribution(leftKeys) :: NullSafeClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 51ee7edfd7a60..7151e93a571e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{NullUnsafeClusteredDistribution, Distribution, NullSafeClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -41,8 +41,19 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + // It is a heuristic. We use NullUnsafeClusteredDistribution to + // let input rows that will have a match distributed evenly. override def requiredChildDistribution: Seq[Distribution] = - NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil + NullUnsafeClusteredDistribution(leftKeys) :: + NullUnsafeClusteredDistribution(rightKeys) :: Nil + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => left.outputPartitioning.toNullUnsafePartitioning + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") + } protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() From c6667e745b0ce0c24dccd419d8fea10e21d24290 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 26 Jul 2015 22:03:46 -0700 Subject: [PATCH 06/15] Add PartitioningCollection. --- .../plans/physical/partitioning.scala | 40 +++++++++++++++++++ .../apache/spark/sql/execution/Exchange.scala | 17 +++++++- .../execution/joins/ShuffledHashJoin.scala | 5 ++- .../joins/ShuffledHashOuterJoin.scala | 16 ++++++-- .../sql/execution/joins/SortMergeJoin.scala | 3 +- 5 files changed, 73 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ff3979f738185..a3f8229b6eca7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -305,3 +305,43 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def toNullUnsafePartitioning: Partitioning = this } + +/** + * A collection of [[Partitioning]]s. 
+ */ +case class PartitioningCollection(partitionings: Seq[Partitioning]) + extends Expression with Partitioning with Unevaluable { + + require( + partitionings.map(_.numPartitions).distinct.length == 1, + s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + + override def children: Seq[Expression] = partitionings.collect { + case expr: Expression => expr + } + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override val numPartitions = partitionings.map(_.numPartitions).distinct.head + + override def satisfies(required: Distribution): Boolean = + partitionings.exists(_.satisfies(required)) + + override def compatibleWith(other: Partitioning): Boolean = + partitionings.exists(_.compatibleWith(other)) + + override def guarantees(other: Partitioning): Boolean = + partitionings.exists(_.guarantees(other)) + + override def keyExpressions: Seq[Expression] = partitionings.head.keyExpressions + + override def toNullUnsafePartitioning: Partitioning = { + PartitioningCollection(partitionings.map(_.toNullUnsafePartitioning)) + } + + override def toString: String = { + partitionings.map(_.toString).mkString("(", " or ", ")") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 65cccbd15879b..c92b202d2a7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -164,7 +164,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // TODO: Handle BroadcastPartitioning. } def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { - case NullSafeHashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case NullSafeHashPartitioning(expressions, _) => + // Since NullSafeHashPartitioning and NullUnsafeHashPartitioning may be used together + // for a join operator. We need to make sure they calculate the partition id with + // the same way. + val materalizeExpressions = newMutableProjection(expressions, child.output)() + val partitionExpressionSchema = expressions.map { + case ne: NamedExpression => ne.toAttribute + case expr => Alias(expr, "partitionExpr")().toAttribute + } + val partitionId = RowHashCode + val partitionIdExtractor = + newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() + (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) + // newMutableProjection(expressions, child.output)() case NullUnsafeHashPartitioning(expressions, numPartition) => // For NullUnsafeHashPartitioning, we do not want to send rows having any expression // in `expressions` evaluated as null to the same node. 
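The null-unsafe extractor above can be hard to read in Catalyst expression form, so here is a plain-Scala sketch of the behaviour it implements. The helper name and signature are illustrative and not part of the patch.

import scala.util.Random

// Rows whose partitioning keys contain a null can never join, so they are spread
// randomly instead of being hashed (hashing would pile them all onto one node);
// every other row is routed by the hash of its keys so that equal keys meet in
// the same partition.
def nullUnsafePartitionId(keys: Seq[Any], numPartitions: Int, rng: Random): Int =
  if (keys.forall(_ != null)) {
    java.lang.Math.floorMod(keys.hashCode(), numPartitions)
  } else {
    rng.nextInt(numPartitions)
  }

// Example: Seq("str1", 1) always maps to the same partition id for a fixed
// numPartitions, while Seq(null, 2) is scattered across partitions.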
@@ -261,7 +274,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (child.outputPartitioning.guarantees(partitioning)) { + if (!child.outputPartitioning.guarantees(partitioning)) { Exchange(partitioning, child) } else { child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 01d200e31b4d8..ca36aba4dd8f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -38,7 +38,8 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = NullSafeClusteredDistribution(leftKeys) :: NullSafeClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 7151e93a571e3..579266e49aeeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -48,9 +48,19 @@ case class ShuffledHashOuterJoin( NullUnsafeClusteredDistribution(rightKeys) :: Nil override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => left.outputPartitioning.toNullUnsafePartitioning + case LeftOuter => + val partitions = + Seq(left.outputPartitioning, right.outputPartitioning.toNullUnsafePartitioning) + PartitioningCollection(partitions) + case RightOuter => + val partitions = + Seq(right.outputPartitioning, left.outputPartitioning.toNullUnsafePartitioning) + PartitioningCollection(partitions) + case FullOuter => + val partitions = + Seq(left.outputPartitioning.toNullUnsafePartitioning, + right.outputPartitioning.toNullUnsafePartitioning) + PartitioningCollection(partitions) case x => throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 81a04cc6e2d02..6dd7364f271f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -40,7 +40,8 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - 
override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil From f9516b0687a90713f2b401d49418ec8ee081f457 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 26 Jul 2015 22:29:48 -0700 Subject: [PATCH 07/15] Style --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 5 +++-- .../scala/org/apache/spark/sql/execution/Exchange.scala | 8 +++++--- .../org/apache/spark/sql/execution/SparkStrategies.scala | 3 ++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index a3f8229b6eca7..a1cfa540bd415 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -73,8 +73,9 @@ case class NullUnsafeClusteredDistribution(clustering: Seq[Expression]) extends /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[NullSafeClusteredDistribution]] as an ordering will ensure that tuples that share the same value for - * the ordering expressions are contiguous and will never be split across partitions. + * [[NullSafeClusteredDistribution]] as an ordering will ensure that tuples that share the + * same value for the ordering expressions are contiguous and will never be split across + * partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index c92b202d2a7b5..987d85bc3c05e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -144,8 +144,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { val rdd = child.execute() val part: Partitioner = newPartitioning match { - case NullSafeHashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) - case NullUnsafeHashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case NullSafeHashPartitioning(expressions, numPartitions) => + new HashPartitioner(numPartitions) + case NullUnsafeHashPartitioning(expressions, numPartitions) => + new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. @@ -228,7 +230,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // TODO: Determine the number of partitions. 
def numPartitions: Int = sqlContext.conf.numShufflePartitions - def advancedSqlOptimizations = sqlContext.conf.advancedSqlOptimizations + def advancedSqlOptimizations: Boolean = sqlContext.conf.advancedSqlOptimizations def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 56489c1436a05..ffbbac3137b93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -403,7 +403,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => - execution.Exchange(NullSafeHashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + NullSafeHashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil From d3d2e646d525cc9c6e425ae99020d26bbaab10dc Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 27 Jul 2015 14:14:34 -0700 Subject: [PATCH 08/15] First round of cleanup. --- .../plans/physical/partitioning.scala | 104 +++++----------- .../sql/catalyst/DistributionSuite.scala | 112 +++++++++--------- .../spark/sql/execution/Aggregate.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 18 ++- .../sql/execution/GeneratedAggregate.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../apache/spark/sql/execution/Window.scala | 2 +- .../aggregate/aggregateOperators.scala | 6 +- .../execution/joins/LeftSemiJoinHash.scala | 4 +- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../joins/ShuffledHashOuterJoin.scala | 12 +- .../sql/execution/joins/SortMergeJoin.scala | 2 +- 12 files changed, 111 insertions(+), 157 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index a1cfa540bd415..9c04d55f6fb5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -47,10 +47,13 @@ case object AllTuples extends Distribution * Represents data where tuples that share the same values for the `clustering` * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. For two null values in two rows evaluated by `clustering`, + * within a single partition. When `nullSafe` is true, + * for two null values in two rows evaluated by `clustering`, * we consider these two nulls are equal. */ -case class NullSafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + nullSafe: Boolean) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. 
" + @@ -58,22 +61,15 @@ case class NullSafeClusteredDistribution(clustering: Seq[Expression]) extends Di "a single partition.") } -/** - * It is basically the same as [[NullSafeClusteredDistribution]] except that - * it does not require that evaluated rows having any null values to be clustered. - */ -case class NullUnsafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution { - require( - clustering != Nil, - "The clustering expressions of a ClusteredDistribution should not be Nil. " + - "An AllTuples should be used to represent a distribution that only has " + - "a single partition.") +object ClusteredDistribution { + def apply(clustering: Seq[Expression]): ClusteredDistribution = + ClusteredDistribution(clustering, nullSafe = true) } /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[NullSafeClusteredDistribution]] as an ordering will ensure that tuples that share the + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the * same value for the ordering expressions are contiguous and will never be split across * partitions. */ @@ -103,7 +99,7 @@ sealed trait Partitioning { /** * Returns true iff all distribution guarantees made by this partitioning can also be made * for the `other` specified partitioning. - * For example, two [[NullSafeHashPartitioning HashPartitioning]]s are + * For example, two [[HashPartitioning HashPartitioning]]s are * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean @@ -113,7 +109,7 @@ sealed trait Partitioning { /** Returns the expressions that are used to key the partitioning. */ def keyExpressions: Seq[Expression] - def toNullUnsafePartitioning: Partitioning + def withNullSafeSetting(newNullSafe: Boolean): Partitioning } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -131,7 +127,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { override def keyExpressions: Seq[Expression] = Nil - override def toNullUnsafePartitioning: Partitioning = this + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } case object SinglePartition extends Partitioning { @@ -151,7 +147,7 @@ case object SinglePartition extends Partitioning { override def keyExpressions: Seq[Expression] = Nil - override def toNullUnsafePartitioning: Partitioning = this + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } case object BroadcastPartitioning extends Partitioning { @@ -171,15 +167,16 @@ case object BroadcastPartitioning extends Partitioning { override def keyExpressions: Seq[Expression] = Nil - override def toNullUnsafePartitioning: Partitioning = this + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. + * in the same partition. When `nullSafe` is true, for two null values in two rows evaluated + * by `clustering`, we consider these two nulls are equal. 
*/ -case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int) +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int, nullSafe: Boolean) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions @@ -190,73 +187,36 @@ case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case NullSafeClusteredDistribution(requiredClustering) => + case ClusteredDistribution(requiredClustering, _) if nullSafe => clusteringSet.subsetOf(requiredClustering.toSet) - case NullUnsafeClusteredDistribution(requiredClustering) => + case ClusteredDistribution(requiredClustering, false) if !nullSafe => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true - case h: NullSafeHashPartitioning if h == this => true + case h: HashPartitioning if h == this => true case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: NullSafeHashPartitioning => + case o: HashPartitioning if nullSafe => this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions - case o: NullUnsafeHashPartitioning => + case o: HashPartitioning if !nullSafe && !o.nullSafe => this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions case _ => false } override def keyExpressions: Seq[Expression] = expressions - override def toNullUnsafePartitioning: Partitioning = - NullUnsafeHashPartitioning(expressions, numPartitions) + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = + HashPartitioning(expressions, numPartitions, newNullSafe) } -/** - * Represents a partitioning where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same non-null values are - * guaranteed to be in the same partition. For `expressions`, if a evaluated row - * has any null value, it is not guaranteed to be in the same partition with other - * rows having the same values. - * - * For example, Row(1, null) and Row(1, null) may not be in the same partition. 
- */ -case class NullUnsafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression with Partitioning with Unevaluable { - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: DataType = IntegerType - - lazy val clusteringSet = expressions.toSet - - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case NullUnsafeClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case h: NullUnsafeHashPartitioning if h == this => true - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: NullUnsafeHashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions - case _ => false - } - - override def keyExpressions: Seq[Expression] = expressions - - override def toNullUnsafePartitioning: Partitioning = this +object HashPartitioning { + def apply(expressions: Seq[Expression], numPartitions: Int): HashPartitioning = + HashPartitioning(expressions, numPartitions, nullSafe = true) } /** @@ -285,9 +245,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case NullSafeClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case NullUnsafeClusteredDistribution(requiredClustering) => + case ClusteredDistribution(requiredClustering, _) => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } @@ -304,7 +262,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def keyExpressions: Seq[Expression] = ordering.map(_.child) - override def toNullUnsafePartitioning: Partitioning = this + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } /** @@ -338,8 +296,8 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def keyExpressions: Seq[Expression] = partitionings.head.keyExpressions - override def toNullUnsafePartitioning: Partitioning = { - PartitioningCollection(partitionings.map(_.toNullUnsafePartitioning)) + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = { + PartitioningCollection(partitionings.map(_.withNullSafeSetting(newNullSafe))) } override def toString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 07fe8780145d3..c9e60988aa9e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,41 +42,41 @@ class DistributionSuite extends SparkFunSuite { } } - test("NullSafeHashPartitioning is the output partitioning") { + test("HashPartitioning (with nullSafe = true) is the output partitioning") { // Cases which do not need an exchange between two data properties. 
checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 10), UnspecifiedDistribution, true) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( - NullSafeHashPartitioning(Seq('b, 'c), 10), - NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( SinglePartition, - NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + ClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( - NullSafeHashPartitioning(Seq('b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( SinglePartition, - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( @@ -86,37 +86,37 @@ class DistributionSuite extends SparkFunSuite { // Cases which need an exchange between two data properties. checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullSafeClusteredDistribution(Seq('b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('b, 'c)), false) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullSafeClusteredDistribution(Seq('d, 'e)), + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('d, 'e)), false) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('b, 'c), false), false) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('d, 'e)), + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('d, 'e), false), false) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, false) checkSatisfied( - NullSafeHashPartitioning(Seq('a, 'b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - NullSafeHashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) @@ -129,28 +129,26 @@ class DistributionSuite extends SparkFunSuite { */ } - - - test("NullUnsafeHashPartitioning is the output partitioning") { + test("HashPartitioning (with nullSafe = false) is the output partitioning") { // Cases which do not need an exchange between two data properties. 
checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 10, false), UnspecifiedDistribution, true) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( - NullUnsafeHashPartitioning(Seq('b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( SinglePartition, - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( @@ -160,47 +158,47 @@ class DistributionSuite extends SparkFunSuite { // Cases which need an exchange between two data properties. checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c)), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('b, 'c), 10), - NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + HashPartitioning(Seq('b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c)), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullSafeClusteredDistribution(Seq('b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('b, 'c)), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullSafeClusteredDistribution(Seq('d, 'e)), + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('d, 'e)), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('b, 'c)), + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('b, 'c), false), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), - NullUnsafeClusteredDistribution(Seq('d, 'e)), + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('d, 'e), false), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 10, false), AllTuples, false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('a, 'b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 10, false), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - NullUnsafeHashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('b, 'c), 10, false), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) } @@ -229,32 +227,32 @@ class DistributionSuite extends SparkFunSuite { checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullSafeClusteredDistribution(Seq('a, 'b, 'c)), + ClusteredDistribution(Seq('a, 'b, 'c)), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullSafeClusteredDistribution(Seq('c, 'b, 'a)), + ClusteredDistribution(Seq('c, 'b, 'a)), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullSafeClusteredDistribution(Seq('b, 'c, 'a, 'd)), + ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullUnsafeClusteredDistribution(Seq('a, 'b, 'c)), + ClusteredDistribution(Seq('a, 'b, 'c), false), true) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullUnsafeClusteredDistribution(Seq('c, 'b, 'a)), + ClusteredDistribution(Seq('c, 'b, 'a), false), true) checkSatisfied( 
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullUnsafeClusteredDistribution(Seq('b, 'c, 'a, 'd)), + ClusteredDistribution(Seq('b, 'c, 'a, 'd), false), true) @@ -275,22 +273,22 @@ class DistributionSuite extends SparkFunSuite { checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullSafeClusteredDistribution(Seq('a, 'b)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullUnsafeClusteredDistribution(Seq('c, 'd)), + ClusteredDistribution(Seq('c, 'd), false), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullUnsafeClusteredDistribution(Seq('a, 'b)), + ClusteredDistribution(Seq('a, 'b), false), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - NullSafeClusteredDistribution(Seq('c, 'd)), + ClusteredDistribution(Seq('c, 'd)), false) checkSatisfied( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 3979e396008b2..e8c6a0f8f801d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -52,7 +52,7 @@ case class Aggregate( if (groupingExpressions == Nil) { AllTuples :: Nil } else { - NullSafeClusteredDistribution(groupingExpressions) :: Nil + ClusteredDistribution(groupingExpressions) :: Nil } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 987d85bc3c05e..d5a240a72f97f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -144,9 +144,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { val rdd = child.execute() val part: Partitioner = newPartitioning match { - case NullSafeHashPartitioning(expressions, numPartitions) => - new HashPartitioner(numPartitions) - case NullUnsafeHashPartitioning(expressions, numPartitions) => + case HashPartitioning(expressions, numPartitions, _) => new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute @@ -166,7 +164,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // TODO: Handle BroadcastPartitioning. } def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { - case NullSafeHashPartitioning(expressions, _) => + case HashPartitioning(expressions, _, true) => // Since NullSafeHashPartitioning and NullUnsafeHashPartitioning may be used together // for a join operator. We need to make sure they calculate the partition id with // the same way. 
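// For illustration only: the two HashPartitioning branches above both have to route a row from
// its join-key values to a partition id, and they must agree whenever null-safe routing is used on
// both sides of a join. A simplified standalone sketch of that routing idea follows (plain Scala;
// `NullAwareRouter` and `routeRow` are illustrative names, not Spark APIs, and a key is modeled
// here as a plain Seq[Any]).
import scala.util.Random

object NullAwareRouter {
  def routeRow(key: Seq[Any], numPartitions: Int, nullSafe: Boolean): Int = {
    if (!nullSafe && key.contains(null)) {
      // Null-unsafe routing: a row whose key contains a null can never match in an equi-join,
      // so it is sent to a random partition to keep partition sizes even.
      Random.nextInt(numPartitions)
    } else {
      // Null-safe routing: rows with equal key values always land in the same partition.
      ((key.hashCode % numPartitions) + numPartitions) % numPartitions
    }
  }
}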
@@ -180,7 +178,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) // newMutableProjection(expressions, child.output)() - case NullUnsafeHashPartitioning(expressions, numPartition) => + case HashPartitioning(expressions, numPartition, false) => // For NullUnsafeHashPartitioning, we do not want to send rows having any expression // in `expressions` evaluated as null to the same node. val materalizeExpressions = newMutableProjection(expressions, child.output)() @@ -314,21 +312,21 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (AllTuples, rowOrdering, child) => addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (NullSafeClusteredDistribution(clustering), rowOrdering, child) => + case (ClusteredDistribution(clustering, true), rowOrdering, child) => addOperatorsIfNecessary( - NullSafeHashPartitioning(clustering, numPartitions), + HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (NullUnsafeClusteredDistribution(clustering), rowOrdering, child) => + case (ClusteredDistribution(clustering, false), rowOrdering, child) => if (advancedSqlOptimizations) { addOperatorsIfNecessary( - NullUnsafeHashPartitioning(clustering, numPartitions), + HashPartitioning(clustering, numPartitions, false), rowOrdering, child) } else { addOperatorsIfNecessary( - NullSafeHashPartitioning(clustering, numPartitions), + HashPartitioning(clustering, numPartitions), rowOrdering, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 6c390fcb0b343..5ad4691a5ca07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -61,7 +61,7 @@ case class GeneratedAggregate( if (groupingExpressions == Nil) { AllTuples :: Nil } else { - NullSafeClusteredDistribution(groupingExpressions) :: Nil + ClusteredDistribution(groupingExpressions) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ffbbac3137b93..924f34614b3fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -404,7 +404,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => execution.Exchange( - NullSafeHashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index bdba16037f208..91c8a02e2b5bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -92,7 +92,7 @@ case class Window( 
logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else NullSafeClusteredDistribution(windowSpec.partitionSpec) :: Nil + } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala index f4f4eb7fafdeb..0c9082897f390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, NullSafeClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} case class Aggregate2Sort( @@ -49,7 +49,7 @@ case class Aggregate2Sort( override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => NullSafeClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -144,7 +144,7 @@ case class FinalAndCompleteAggregate2Sort( if (groupingExpressions.isEmpty) { AllTuples :: Nil } else { - NullSafeClusteredDistribution(groupingExpressions) :: Nil + ClusteredDistribution(groupingExpressions) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 18d32343494a3..965cc65b2c0e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -40,7 +40,7 @@ case class LeftSemiJoinHash( override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = - NullSafeClusteredDistribution(leftKeys) :: NullSafeClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index ca36aba4dd8f4..8a07200b9928d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -42,7 +42,7 @@ case class ShuffledHashJoin( PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = - NullSafeClusteredDistribution(leftKeys) :: NullSafeClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 579266e49aeeb..64f68511fb179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -44,22 +44,22 @@ case class ShuffledHashOuterJoin( // It is a heuristic. We use NullUnsafeClusteredDistribution to // let input rows that will have a match distributed evenly. override def requiredChildDistribution: Seq[Distribution] = - NullUnsafeClusteredDistribution(leftKeys) :: - NullUnsafeClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, nullSafe = false) :: + ClusteredDistribution(rightKeys, nullSafe = false) :: Nil override def outputPartitioning: Partitioning = joinType match { case LeftOuter => val partitions = - Seq(left.outputPartitioning, right.outputPartitioning.toNullUnsafePartitioning) + Seq(left.outputPartitioning, right.outputPartitioning.withNullSafeSetting(true)) PartitioningCollection(partitions) case RightOuter => val partitions = - Seq(right.outputPartitioning, left.outputPartitioning.toNullUnsafePartitioning) + Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(true)) PartitioningCollection(partitions) case FullOuter => val partitions = - Seq(left.outputPartitioning.toNullUnsafePartitioning, - right.outputPartitioning.toNullUnsafePartitioning) + Seq(left.outputPartitioning.withNullSafeSetting(true), + right.outputPartitioning.withNullSafeSetting(true)) PartitioningCollection(partitions) case x => throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 6dd7364f271f4..41be78afd37e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -44,7 +44,7 @@ case class SortMergeJoin( PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = - NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) From 
c57a95465a2410fa515d6bbcf3dd0276a19f1d21 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 27 Jul 2015 16:39:49 -0700 Subject: [PATCH 09/15] Bug fix. --- .../spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 64f68511fb179..bb6cfa40194eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,16 +50,16 @@ case class ShuffledHashOuterJoin( override def outputPartitioning: Partitioning = joinType match { case LeftOuter => val partitions = - Seq(left.outputPartitioning, right.outputPartitioning.withNullSafeSetting(true)) + Seq(left.outputPartitioning, right.outputPartitioning.withNullSafeSetting(false)) PartitioningCollection(partitions) case RightOuter => val partitions = - Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(true)) + Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(false)) PartitioningCollection(partitions) case FullOuter => val partitions = - Seq(left.outputPartitioning.withNullSafeSetting(true), - right.outputPartitioning.withNullSafeSetting(true)) + Seq(left.outputPartitioning.withNullSafeSetting(false), + right.outputPartitioning.withNullSafeSetting(false)) PartitioningCollection(partitions) case x => throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") From 303236bed06817befc2786f3c01e34b071f819f1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Jul 2015 19:15:55 -0700 Subject: [PATCH 10/15] Revert changes that are unrelated to null join key filtering --- .../spark/sql/catalyst/expressions/misc.scala | 20 --- .../plans/physical/partitioning.scala | 117 ++-------------- .../sql/catalyst/DistributionSuite.scala | 127 +----------------- .../apache/spark/sql/execution/Exchange.scala | 64 +-------- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../joins/BroadcastHashOuterJoin.scala | 4 +- .../sql/execution/joins/HashOuterJoin.scala | 8 ++ .../execution/joins/LeftSemiJoinHash.scala | 6 +- .../execution/joins/ShuffledHashJoin.scala | 7 +- .../joins/ShuffledHashOuterJoin.scala | 25 +--- .../sql/execution/joins/SortMergeJoin.scala | 3 +- 11 files changed, 34 insertions(+), 350 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 44fa2a0bce738..8d8d66ddeb341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,7 +21,6 @@ import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -161,22 +160,3 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } - -/** An expression that returns the hashCode of the input row. */ -case object RowHashCode extends LeafExpression { - override def dataType: DataType = IntegerType - - /** hashCode will never be null. 
*/ - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - input.hashCode - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - s""" - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = i.hashCode(); - """ - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 9c04d55f6fb5d..2dcfa19fec383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -47,13 +47,9 @@ case object AllTuples extends Distribution * Represents data where tuples that share the same values for the `clustering` * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. When `nullSafe` is true, - * for two null values in two rows evaluated by `clustering`, - * we consider these two nulls are equal. + * within a single partition. */ -case class ClusteredDistribution( - clustering: Seq[Expression], - nullSafe: Boolean) extends Distribution { +case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -61,17 +57,11 @@ case class ClusteredDistribution( "a single partition.") } -object ClusteredDistribution { - def apply(clustering: Seq[Expression]): ClusteredDistribution = - ClusteredDistribution(clustering, nullSafe = true) -} - /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the - * same value for the ordering expressions are contiguous and will never be split across - * partitions. + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for + * the ordering expressions are contiguous and will never be split across partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -104,12 +94,8 @@ sealed trait Partitioning { */ def compatibleWith(other: Partitioning): Boolean - def guarantees(other: Partitioning): Boolean - /** Returns the expressions that are used to key the partitioning. 
*/ def keyExpressions: Seq[Expression] - - def withNullSafeSetting(newNullSafe: Boolean): Partitioning } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -123,11 +109,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def guarantees(other: Partitioning): Boolean = false - override def keyExpressions: Seq[Expression] = Nil - - override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } case object SinglePartition extends Partitioning { @@ -140,14 +122,7 @@ case object SinglePartition extends Partitioning { case _ => false } - override def guarantees(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil - - override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } case object BroadcastPartitioning extends Partitioning { @@ -160,36 +135,26 @@ case object BroadcastPartitioning extends Partitioning { case _ => false } - override def guarantees(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil - - override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. When `nullSafe` is true, for two null values in two rows evaluated - * by `clustering`, we consider these two nulls are equal. + * in the same partition. */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int, nullSafe: Boolean) +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet + private[this] lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering, _) if nullSafe => - clusteringSet.subsetOf(requiredClustering.toSet) - case ClusteredDistribution(requiredClustering, false) if !nullSafe => + case ClusteredDistribution(requiredClustering) => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } @@ -200,23 +165,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int, nu case _ => false } - override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning if nullSafe => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions - case o: HashPartitioning if !nullSafe && !o.nullSafe => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions - case _ => false - } - override def keyExpressions: Seq[Expression] = expressions - - override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = - HashPartitioning(expressions, numPartitions, newNullSafe) -} - -object HashPartitioning { - def apply(expressions: Seq[Expression], numPartitions: Int): HashPartitioning = - HashPartitioning(expressions, numPartitions, nullSafe = true) } /** @@ -245,62 +194,16 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case 
OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, _) => + case ClusteredDistribution(requiredClustering) => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case r: RangePartitioning if r == this => true case _ => false } override def keyExpressions: Seq[Expression] = ordering.map(_.child) - - override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this -} - -/** - * A collection of [[Partitioning]]s. - */ -case class PartitioningCollection(partitionings: Seq[Partitioning]) - extends Expression with Partitioning with Unevaluable { - - require( - partitionings.map(_.numPartitions).distinct.length == 1, - s"PartitioningCollection requires all of its partitionings have the same numPartitions.") - - override def children: Seq[Expression] = partitionings.collect { - case expr: Expression => expr - } - - override def nullable: Boolean = false - - override def dataType: DataType = IntegerType - - override val numPartitions = partitionings.map(_.numPartitions).distinct.head - - override def satisfies(required: Distribution): Boolean = - partitionings.exists(_.satisfies(required)) - - override def compatibleWith(other: Partitioning): Boolean = - partitionings.exists(_.compatibleWith(other)) - - override def guarantees(other: Partitioning): Boolean = - partitionings.exists(_.guarantees(other)) - - override def keyExpressions: Seq[Expression] = partitionings.head.keyExpressions - - override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = { - PartitioningCollection(partitionings.map(_.withNullSafeSetting(newNullSafe))) - } - - override def toString: String = { - partitionings.map(_.toString).mkString("(", " or ", ")") - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c9e60988aa9e0..c046dbf4dc2c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { + test("HashPartitioning is the output partitioning") { // Cases which do not need an exchange between two data properties. 
checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), @@ -64,21 +64,6 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('a, 'b, 'c)), true) - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - - checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - - checkSatisfied( - SinglePartition, - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -95,16 +80,6 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('b, 'c), false), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('d, 'e), false), - false) - checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, @@ -129,80 +104,6 @@ class DistributionSuite extends SparkFunSuite { */ } - test("HashPartitioning (with nullSafe = false) is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - UnspecifiedDistribution, - true) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - - checkSatisfied( - HashPartitioning(Seq('b, 'c), 10, false), - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - - checkSatisfied( - SinglePartition, - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - - checkSatisfied( - SinglePartition, - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - true) - - // Cases which need an exchange between two data properties. - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - ClusteredDistribution(Seq('a, 'b, 'c)), - false) - - checkSatisfied( - HashPartitioning(Seq('b, 'c), 10, false), - ClusteredDistribution(Seq('a, 'b, 'c)), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - ClusteredDistribution(Seq('b, 'c)), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - ClusteredDistribution(Seq('d, 'e)), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - ClusteredDistribution(Seq('b, 'c), false), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - ClusteredDistribution(Seq('d, 'e), false), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - AllTuples, - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10, false), - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) - - checkSatisfied( - HashPartitioning(Seq('b, 'c), 10, false), - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) - } - test("RangePartitioning is the output partitioning") { // Cases which do not need an exchange between two data properties. 
checkSatisfied( @@ -240,22 +141,6 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b, 'c), false), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'b, 'a), false), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('b, 'c, 'a, 'd), false), - true) - - // Cases which need an exchange between two data properties. // TODO: We can have an optimization to first sort the dataset // by a.asc and then sort b, and c in a partition. This optimization @@ -276,16 +161,6 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('a, 'b)), false) - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd), false), - false) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b), false), - false) - checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('c, 'd)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d5a240a72f97f..41a0c519ba527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -144,8 +143,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { val rdd = child.execute() val part: Partitioner = newPartitioning match { - case HashPartitioning(expressions, numPartitions, _) => - new HashPartitioner(numPartitions) + case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. @@ -164,38 +162,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // TODO: Handle BroadcastPartitioning. } def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { - case HashPartitioning(expressions, _, true) => - // Since NullSafeHashPartitioning and NullUnsafeHashPartitioning may be used together - // for a join operator. We need to make sure they calculate the partition id with - // the same way. 
- val materalizeExpressions = newMutableProjection(expressions, child.output)() - val partitionExpressionSchema = expressions.map { - case ne: NamedExpression => ne.toAttribute - case expr => Alias(expr, "partitionExpr")().toAttribute - } - val partitionId = RowHashCode - val partitionIdExtractor = - newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() - (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) - // newMutableProjection(expressions, child.output)() - case HashPartitioning(expressions, numPartition, false) => - // For NullUnsafeHashPartitioning, we do not want to send rows having any expression - // in `expressions` evaluated as null to the same node. - val materalizeExpressions = newMutableProjection(expressions, child.output)() - val partitionExpressionSchema = expressions.map { - case ne: NamedExpression => ne.toAttribute - case expr => Alias(expr, "partitionExpr")().toAttribute - } - val partitionId = - If( - Not(AtLeastNNulls(1, partitionExpressionSchema)), - // There is no null value in the partition expressions, we can just get the - // hashCode of the input row. - RowHashCode, - Cast(Multiply(new Rand(numPartition), Literal(numPartition.toDouble)), IntegerType)) - val partitionIdExtractor = - newMutableProjection(partitionId :: Nil, partitionExpressionSchema)() - (row: InternalRow) => partitionIdExtractor(materalizeExpressions(row)) + case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } @@ -228,8 +195,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // TODO: Determine the number of partitions. 
def numPartitions: Int = sqlContext.conf.numShufflePartitions - def advancedSqlOptimizations: Boolean = sqlContext.conf.advancedSqlOptimizations - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => // True iff every child's outputPartitioning satisfies the corresponding @@ -274,7 +239,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (!child.outputPartitioning.guarantees(partitioning)) { + if (child.outputPartitioning != partitioning) { Exchange(partitioning, child) } else { child @@ -311,32 +276,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ val fixedChildren = requirements.zipped.map { case (AllTuples, rowOrdering, child) => addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - - case (ClusteredDistribution(clustering, true), rowOrdering, child) => - addOperatorsIfNecessary( - HashPartitioning(clustering, numPartitions), - rowOrdering, - child) - - case (ClusteredDistribution(clustering, false), rowOrdering, child) => - if (advancedSqlOptimizations) { - addOperatorsIfNecessary( - HashPartitioning(clustering, numPartitions, false), - rowOrdering, - child) - } else { - addOperatorsIfNecessary( - HashPartitioning(clustering, numPartitions), - rowOrdering, - child) - } - + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) case (OrderedDistribution(ordering), rowOrdering, child) => addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) case (UnspecifiedDistribution, Seq(), child) => child - case (UnspecifiedDistribution, rowOrdering, child) => sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ac91abafeb734..f3ef066528ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -411,8 +411,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => - execution.Exchange( - HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 309716a0efcc0..77e7fe71009b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -57,8 +57,6 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 438f5874ca5f8..7e671e7914f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -38,6 +38,14 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") + } + override def output: Seq[Attribute] = { joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 68ccd34d8ed9b..26a664104d6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -37,9 +37,7 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index fc6efe87bceb5..5439e10a60b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -38,10 +38,9 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution: Seq[Distribution] = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 1ffdb570e8c6b..d29b593207c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -41,29 +41,8 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { - // It is a heuristic. We use NullUnsafeClusteredDistribution to - // let input rows that will have a match distributed evenly. 
override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, nullSafe = false) :: - ClusteredDistribution(rightKeys, nullSafe = false) :: Nil - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => - val partitions = - Seq(left.outputPartitioning, right.outputPartitioning.withNullSafeSetting(false)) - PartitioningCollection(partitions) - case RightOuter => - val partitions = - Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(false)) - PartitioningCollection(partitions) - case FullOuter => - val partitions = - Seq(left.outputPartitioning.withNullSafeSetting(false), - right.outputPartitioning.withNullSafeSetting(false)) - PartitioningCollection(partitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 41be78afd37e6..bb18b5403f8e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -40,8 +40,7 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil From 8bb39adfd63060fcfaae0ddfcd7ccb1b4def3ab4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 29 Jul 2015 20:28:28 -0700 Subject: [PATCH 11/15] Fix non-deterministic tests. --- .../optimizer/FilterNullsInJoinKeySuite.scala | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala index 2a6b29f6e96f5..a6039939cf11b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -30,7 +30,11 @@ import org.apache.spark.sql.test.TestSQLContext class FilterNullsInJoinKeySuite extends PlanTest { // We add predicate pushdown rules at here to make sure we do not - // create redundant + // create redundant Filter operators. Also, because the attribute ordering of + // the Project operator added by ColumnPruning may not be deterministic + // (the ordering may depend on the testing environment), + // we first construct the plan with expected Filter operators and then + // run the optimizer to add the Project for column pruning.
object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, @@ -65,7 +69,6 @@ class FilterNullsInJoinKeySuite extends PlanTest { val correctLeft = leftRelation .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - .select('a, 'b, 'd) val correctRight = rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) @@ -74,9 +77,8 @@ class FilterNullsInJoinKeySuite extends PlanTest { correctLeft .join(correctRight, Inner, Some(joinCondition)) .select('a, 'f, 'd, 'h) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) } test("inner join (partially optimized)") { @@ -96,12 +98,10 @@ class FilterNullsInJoinKeySuite extends PlanTest { val correctAnswer = leftRelation - .select('a, 'b, 'd) .join(correctRight, Inner, Some(joinCondition)) .select('a, 'f, 'd, 'h) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) } test("inner join (not optimized)") { @@ -113,13 +113,12 @@ class FilterNullsInJoinKeySuite extends PlanTest { nonOptimizedJoinConditions.foreach { joinCondition => val joinedPlan = leftRelation - .select('a, 'c, 'd) .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) .select('a, 'c, 'f, 'd, 'h, 'g) val optimized = Optimize.execute(joinedPlan.analyze) - comparePlans(optimized, joinedPlan.analyze) + comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) } } @@ -140,12 +139,10 @@ class FilterNullsInJoinKeySuite extends PlanTest { val correctAnswer = leftRelation - .select('a, 'b, 'd) .join(correctRight, LeftOuter, Some(joinCondition)) .select('a, 'f, 'd, 'h) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) } test("right outer join") { @@ -163,15 +160,14 @@ class FilterNullsInJoinKeySuite extends PlanTest { val correctLeft = leftRelation .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - .select('a, 'b, 'd) val correctAnswer = correctLeft .join(rightRelation, RightOuter, Some(joinCondition)) .select('a, 'f, 'd, 'h) - .analyze - comparePlans(optimized, correctAnswer) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) } test("full outer join") { @@ -180,14 +176,13 @@ class FilterNullsInJoinKeySuite extends PlanTest { val joinedPlan = leftRelation - .select('a, 'b, 'd) .join(rightRelation, FullOuter, Some(joinCondition)) .select('a, 'f, 'd, 'h) // FilterNullsInJoinKey does not fire for a full outer join. val optimized = Optimize.execute(joinedPlan.analyze) - comparePlans(optimized, joinedPlan.analyze) + comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) } test("left semi join") { @@ -205,7 +200,6 @@ class FilterNullsInJoinKeySuite extends PlanTest { val correctLeft = leftRelation .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - .select('a, 'b, 'd) val correctRight = rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) @@ -214,8 +208,7 @@ class FilterNullsInJoinKeySuite extends PlanTest { correctLeft .join(correctRight, LeftSemi, Some(joinCondition)) .select('a, 'd) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) } } From be887607cab2d389c6d1c9ad16e9b6e4c84b6651 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 29 Jul 2015 20:31:17 -0700 Subject: [PATCH 12/15] Make it clear that FilterNullsInJoinKeySuite.scala is used to test FilterNullsInJoinKey. 
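The tests updated in the previous patch all rely on the same comparison pattern: the expected plan is written with only the null-rejecting Filter operators, and both plans are then run through the same rule executor so that the Project added by ColumnPruning appears on both sides. A rough sketch of that pattern, as a hypothetical helper inside the suite (illustrative only; checkJoinOptimization and expectedWithFilters are placeholder names, not part of the patch):

    // Hypothetical helper (not in the patch) showing the comparison pattern.
    // Optimizing the hand-written expected plan as well means the Project that
    // ColumnPruning inserts (whose attribute ordering can vary with the test
    // environment) shows up identically in both plans, so comparePlans is stable.
    private def checkJoinOptimization(
        joinedPlan: LogicalPlan,
        expectedWithFilters: LogicalPlan): Unit = {
      val optimized = Optimize.execute(joinedPlan.analyze)
      comparePlans(optimized, Optimize.execute(expectedWithFilters.analyze))
    }
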
--- .../apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala index a6039939cf11b..26c8cc3f73ce2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Logic import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.test.TestSQLContext +/** This is the test suite for FilterNullsInJoinKey optimziation rule. */ class FilterNullsInJoinKeySuite extends PlanTest { // We add predicate pushdown rules at here to make sure we do not @@ -40,7 +41,7 @@ class FilterNullsInJoinKeySuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Operator Optimizations", FixedPoint(100), - FilterNullsInJoinKey(TestSQLContext), + FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite. CombineFilters, PushPredicateThroughProject, BooleanSimplification, From ea7d5a62c7be65d7acf5a714ece61a75a9783778 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 29 Jul 2015 23:20:35 -0700 Subject: [PATCH 13/15] Make sure we do not keep adding filters. --- .../sql/catalyst/optimizer/Optimizer.scala | 6 +++++ .../optimizer/FilterNullsInJoinKeySuite.scala | 23 ++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 41ee3eb141bb2..04c35c7f15005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -529,6 +529,12 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => + // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and + // Not(AtLeastNNulls(1, e2)) + // (this is used to make sure there is no null in the result of e1 and e2), we can + // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). 
+ Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild)
case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
index 26c8cc3f73ce2..0b312a8946136 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.AtLeastNNulls
+import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls}
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
@@ -82,6 +82,27 @@ class FilterNullsInJoinKeySuite extends PlanTest {
comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
}
+ test("make sure we do not keep adding filters") {
+ val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int)
+ val joinedPlan =
+ leftRelation
+ .join(rightRelation, Inner, Some('a === 'e))
+ .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j))
+
+ val optimized = Optimize.execute(joinedPlan.analyze)
+ val conditions = optimized.collect {
+ case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs
+ }
+
+ // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables.
+ assert(conditions.length === 3)
+
+ // Make sure attribtues are indded a, b, e, i, and j.
+ assert(
+ conditions.flatMap(exprs => exprs).toSet ===
+ joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet)
+ }
+
test("inner join (partially optimized)") {
val joinCondition =
('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
From 0a8e096df0cb7af6ba5e42d54f45fd8014a29ebf Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Wed, 29 Jul 2015 23:21:45 -0700
Subject: [PATCH 14/15] Update comments.
---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++-
 .../apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 04c35c7f15005..e0ef8f50981c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -532,7 +532,8 @@ object CombineFilters extends Rule[LogicalPlan] {
case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) =>
// If we are combining two expressions Not(AtLeastNNulls(1, e1)) and
// Not(AtLeastNNulls(1, e2))
- // (this is used to make sure there is no null in the result of e1 and e2), we can
+ // (this is used to make sure there is no null in the result of e1 and e2 and
+ // they are added by the FilterNullsInJoinKey optimization rule), we can
// just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)).
Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala index 0b312a8946136..4161927e396ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Logic import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.test.TestSQLContext -/** This is the test suite for FilterNullsInJoinKey optimziation rule. */ +/** This is the test suite for FilterNullsInJoinKey optimization rule. */ class FilterNullsInJoinKeySuite extends PlanTest { // We add predicate pushdown rules at here to make sure we do not From c02fc3f4179df860ebb8c24614247c016c3603e6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 2 Aug 2015 21:30:31 -0700 Subject: [PATCH 15/15] Address Josh's comments. --- .../optimizer/extendedOperatorOptimizations.scala | 13 ++++--------- .../sql/optimizer/FilterNullsInJoinKeySuite.scala | 2 +- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala index 929d6b429aaf5..5a4dde5756964 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -68,10 +68,7 @@ case class FilterNullsInJoinKey( keys: Seq[Expression], child: LogicalPlan): LogicalPlan = { // We get all attributes from keys. - val attributes = keys.filter { - case attr: Attribute => true - case _ => false - } + val attributes = keys.filter(_.isInstanceOf[Attribute]) // Then, we create a Filter to make sure these attributes are non-nullable. val filter = @@ -81,8 +78,6 @@ case class FilterNullsInJoinKey( child } - // We return attributes representing keys (keyAttributes) and the filter. - // keyAttributes will be used to rewrite the join condition. filter } @@ -90,9 +85,9 @@ case class FilterNullsInJoinKey( * We reconstruct the join condition. */ private def reconstructJoinCondition( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - otherPredicate: Option[Expression]): Expression = { + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + otherPredicate: Option[Expression]): Expression = { // First, we rewrite the equal condition part. When we extract those keys, // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala index 4161927e396ef..f98e4acafbf2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -97,7 +97,7 @@ class FilterNullsInJoinKeySuite extends PlanTest { // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. 
assert(conditions.length === 3)
- // Make sure attribtues are indded a, b, e, i, and j.
+ // Make sure attributes are indeed a, b, e, i, and j.
assert(
conditions.flatMap(exprs => exprs).toSet ===
joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet)
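
To make the end-to-end effect of the series concrete: for an inner equi-join, FilterNullsInJoinKey is meant to put a null-rejecting filter over each input's join keys, and the new CombineFilters case keeps those filters from accumulating when the same plan is optimized repeatedly. The following sketch uses the same Catalyst test DSL as the suite; the relations t1 and t2 are hypothetical, and the actual optimized plan additionally contains the Project nodes added by ColumnPruning:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.AtLeastNNulls
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // Two hypothetical inputs joined on a single equality key.
    val t1 = LocalRelation('a.int, 'b.int)
    val t2 = LocalRelation('e.int, 'f.int)

    // Rows whose join key is null can never match in an inner join.
    val original = t1.join(t2, Inner, Some('a === 'e))

    // Shape the rule aims for: each input drops its null-keyed rows before the join,
    // expressed as a Not(AtLeastNNulls(1, ...)) filter over that input's join key.
    val expected =
      t1.where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
        .join(t2.where(!(AtLeastNNulls(1, 'e.expr :: Nil))), Inner, Some('a === 'e))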