diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8d8d66ddeb341..44fa2a0bce738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -160,3 +161,22 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + +/** An expression that returns the hashCode of the input row. */ +case object RowHashCode extends LeafExpression { + override def dataType: DataType = IntegerType + + /** hashCode will never be null. */ + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + input.hashCode + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = i.hashCode(); + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 287718fab7f0d..d58c4756938c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } +/** + * A predicate that is evaluated to be true if there are at least `n` null values. + */ +case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: InternalRow): Boolean = { + var numNulls = 0 + var i = 0 + while (i < childrenArray.length && numNulls < n) { + val evalC = childrenArray(i).eval(input) + if (evalC == null) { + numNulls += 1 + } + i += 1 + } + numNulls >= n + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numNulls = ctx.freshName("numNulls") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($numNulls < $n) { + ${eval.code} + if (${eval.isNull}) { + $numNulls += 1; + } + } + """ + }.mkString("\n") + s""" + int $numNulls = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $numNulls >= $n; + """ + } +} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. 
*/ -case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 813c62009666c..41ee3eb141bb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,8 +31,14 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -object DefaultOptimizer extends Optimizer { - val batches = +class DefaultOptimizer extends Optimizer { + + /** + * Override to provide additional rules for the "Operator Optimizations" batch. + */ + val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches = // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown, - SamplePushDown, - PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - ColumnPruning, + SetOperationPushDown :: + SamplePushDown :: + PushPredicateThroughJoin :: + PushPredicateThroughProject :: + PushPredicateThroughGenerate :: + ColumnPruning :: // Operator combine - ProjectCollapsing, - CombineFilters, - CombineLimits, + ProjectCollapsing :: + CombineFilters :: + CombineLimits :: // Constant folding - NullPropagation, - OptimizeIn, - ConstantFolding, - LikeSimplification, - BooleanSimplification, - RemovePositive, - SimplifyFilters, - SimplifyCasts, - SimplifyCaseConversionExpressions) :: + NullPropagation :: + OptimizeIn :: + ConstantFolding :: + LikeSimplification :: + BooleanSimplification :: + RemovePositive :: + SimplifyFilters :: + SimplifyCasts :: + SimplifyCaseConversionExpressions :: + extendedOperatorOptimizationRules.toList : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + // We need to preserve the nullability of c's output. + // So, we first create a outputMap and if a reference is from the output of + // c, we use that output attribute from c. 
+ val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) + val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq + Project(projectList, c) } else { c } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index af68358daf5f1..77f294bc43a19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -85,7 +85,37 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + /** + * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children + * have at least one null value and atLeastNNulls.children are all attributes. + */ + private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { + val expressions = atLeastNNulls.children + val n = atLeastNNulls.n + if (n != 1) { + // AtLeastNNulls is not used to check if atLeastNNulls.children have + // at least one null value. + false + } else { + // AtLeastNNulls is used to check if atLeastNNulls.children have + // at least one null value. We need to make sure all atLeastNNulls.children + // are attributes. + expressions.forall(_.isInstanceOf[Attribute]) + } + } + + override def output: Seq[Attribute] = condition match { + case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => + // The condition is used to make sure that there is no null value in + // a.children. + val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) + child.output.map { + case attr if nonNullableAttributes.contains(attr) => + attr.withNullability(false) + case attr => attr + } + case _ => child.output + } } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2dcfa19fec383..9c04d55f6fb5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -47,9 +47,13 @@ case object AllTuples extends Distribution * Represents data where tuples that share the same values for the `clustering` * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. + * within a single partition. When `nullSafe` is true, + * for two null values in two rows evaluated by `clustering`, + * we consider these two nulls are equal. */ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + nullSafe: Boolean) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. 
" + @@ -57,11 +61,17 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi "a single partition.") } +object ClusteredDistribution { + def apply(clustering: Seq[Expression]): ClusteredDistribution = + ClusteredDistribution(clustering, nullSafe = true) +} + /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for - * the ordering expressions are contiguous and will never be split across partitions. + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the + * same value for the ordering expressions are contiguous and will never be split across + * partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -94,8 +104,12 @@ sealed trait Partitioning { */ def compatibleWith(other: Partitioning): Boolean + def guarantees(other: Partitioning): Boolean + /** Returns the expressions that are used to key the partitioning. */ def keyExpressions: Seq[Expression] + + def withNullSafeSetting(newNullSafe: Boolean): Partitioning } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -109,7 +123,11 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } + override def guarantees(other: Partitioning): Boolean = false + override def keyExpressions: Seq[Expression] = Nil + + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } case object SinglePartition extends Partitioning { @@ -122,7 +140,14 @@ case object SinglePartition extends Partitioning { case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case SinglePartition => true + case _ => false + } + override def keyExpressions: Seq[Expression] = Nil + + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } case object BroadcastPartitioning extends Partitioning { @@ -135,26 +160,36 @@ case object BroadcastPartitioning extends Partitioning { case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case BroadcastPartitioning => true + case _ => false + } + override def keyExpressions: Seq[Expression] = Nil + + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this } /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. + * in the same partition. When `nullSafe` is true, for two null values in two rows evaluated + * by `clustering`, we consider these two nulls are equal. 
*/ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int, nullSafe: Boolean) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = expressions.toSet + lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => + case ClusteredDistribution(requiredClustering, _) if nullSafe => + clusteringSet.subsetOf(requiredClustering.toSet) + case ClusteredDistribution(requiredClustering, false) if !nullSafe => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } @@ -165,7 +200,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def guarantees(other: Partitioning): Boolean = other match { + case o: HashPartitioning if nullSafe => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning if !nullSafe && !o.nullSafe => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } + override def keyExpressions: Seq[Expression] = expressions + + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = + HashPartitioning(expressions, numPartitions, newNullSafe) +} + +object HashPartitioning { + def apply(expressions: Seq[Expression], numPartitions: Int): HashPartitioning = + HashPartitioning(expressions, numPartitions, nullSafe = true) } /** @@ -194,16 +245,62 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => + case ClusteredDistribution(requiredClustering, _) => clusteringSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true + case _ => false + } + + override def guarantees(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o case _ => false } override def keyExpressions: Seq[Expression] = ordering.map(_.child) + + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this +} + +/** + * A collection of [[Partitioning]]s. 
+ */ +case class PartitioningCollection(partitionings: Seq[Partitioning]) + extends Expression with Partitioning with Unevaluable { + + require( + partitionings.map(_.numPartitions).distinct.length == 1, + s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + + override def children: Seq[Expression] = partitionings.collect { + case expr: Expression => expr + } + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override val numPartitions = partitionings.map(_.numPartitions).distinct.head + + override def satisfies(required: Distribution): Boolean = + partitionings.exists(_.satisfies(required)) + + override def compatibleWith(other: Partitioning): Boolean = + partitionings.exists(_.compatibleWith(other)) + + override def guarantees(other: Partitioning): Boolean = + partitionings.exists(_.guarantees(other)) + + override def keyExpressions: Seq[Expression] = partitionings.head.keyExpressions + + override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = { + PartitioningCollection(partitionings.map(_.withNullSafeSetting(newNullSafe))) + } + + override def toString: String = { + partitionings.map(_.toString).mkString("(", " or ", ")") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c046dbf4dc2c9..c9e60988aa9e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning is the output partitioning") { + test("HashPartitioning (with nullSafe = true) is the output partitioning") { // Cases which do not need an exchange between two data properties. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), @@ -64,6 +64,21 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + + checkSatisfied( + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -80,6 +95,16 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('b, 'c), false), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('d, 'e), false), + false) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, @@ -104,6 +129,80 @@ class DistributionSuite extends SparkFunSuite { */ } + test("HashPartitioning (with nullSafe = false) is the output partitioning") { + // Cases which do not need an exchange between two data properties. 
+ checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + + checkSatisfied( + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + + checkSatisfied( + SinglePartition, + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + true) + + // Cases which need an exchange between two data properties. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c)), + false) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10, false), + ClusteredDistribution(Seq('a, 'b, 'c)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('b, 'c)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('d, 'e)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('b, 'c), false), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + ClusteredDistribution(Seq('d, 'e), false), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + AllTuples, + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10, false), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10, false), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + } + test("RangePartitioning is the output partitioning") { // Cases which do not need an exchange between two data properties. checkSatisfied( @@ -141,6 +240,22 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('a, 'b, 'c), false), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('c, 'b, 'a), false), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('b, 'c, 'a, 'd), false), + true) + + // Cases which need an exchange between two data properties. // TODO: We can have an optimization to first sort the dataset // by a.asc and then sort b, and c in a partition. 
This optimization @@ -161,6 +276,16 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('a, 'b)), false) + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('c, 'd), false), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('a, 'b), false), + false) + checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('c, 'd)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index ab0cdc857c80e..9211341faf4f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => + protected val defaultOptimizer = new DefaultOptimizer + protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -177,7 +179,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 21459a7c69838..c7ee172c0f252 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -168,7 +167,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ace6c15dc8418..bf197124d8dbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - 
test("AtLeastNNonNulls") { + test("AtLeastNNonNullNans") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,11 +96,46 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) + } + + test("AtLeastNNull") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) + + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) + + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.USER_DEFAULT), + Literal(Float.MaxValue), + Literal(false)) + + checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index a4fd4cf3b330b..ea85f0657a726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. 
- val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9b2dbd7442f5c..8364d4b17d862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -401,6 +401,10 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) + val ADVANCED_SQL_OPTIMIZATION = booleanConf( + "spark.sql.advancedOptimization", + defaultValue = Some(true), isPublic = false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -470,6 +474,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) + private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a09846548..31e2b508d485e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { + override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil + } @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..d5a240a72f97f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -143,7 +144,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { val rdd = child.execute() val part: Partitioner = newPartitioning match { - case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(expressions, numPartitions, _) => + new 
HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
         // partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -162,7 +164,38 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
       // TODO: Handle BroadcastPartitioning.
     }
     def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
-      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      case HashPartitioning(expressions, _, true) =>
+        // A null-safe and a null-unsafe HashPartitioning may be used together for the two
+        // sides of a join operator, so both branches need to calculate the partition id in
+        // the same way.
+        val materializeExpressions = newMutableProjection(expressions, child.output)()
+        val partitionExpressionSchema = expressions.map {
+          case ne: NamedExpression => ne.toAttribute
+          case expr => Alias(expr, "partitionExpr")().toAttribute
+        }
+        val partitionId = RowHashCode
+        val partitionIdExtractor =
+          newMutableProjection(partitionId :: Nil, partitionExpressionSchema)()
+        (row: InternalRow) => partitionIdExtractor(materializeExpressions(row))
+      case HashPartitioning(expressions, numPartition, false) =>
+        // When nullSafe is false, we do not want to send all rows whose partition expressions
+        // evaluate to at least one null to the same node, so such rows get a random
+        // partition id instead.
+        val materializeExpressions = newMutableProjection(expressions, child.output)()
+        val partitionExpressionSchema = expressions.map {
+          case ne: NamedExpression => ne.toAttribute
+          case expr => Alias(expr, "partitionExpr")().toAttribute
+        }
+        val partitionId =
+          If(
+            Not(AtLeastNNulls(1, partitionExpressionSchema)),
+            // There is no null value in the partition expressions, so we can just use the
+            // hashCode of the input row.
+            RowHashCode,
+            Cast(Multiply(new Rand(numPartition), Literal(numPartition.toDouble)), IntegerType))
+        val partitionIdExtractor =
+          newMutableProjection(partitionId :: Nil, partitionExpressionSchema)()
+        (row: InternalRow) => partitionIdExtractor(materializeExpressions(row))
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
@@ -195,6 +228,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
   // TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions + def advancedSqlOptimizations: Boolean = sqlContext.conf.advancedSqlOptimizations + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => // True iff every child's outputPartitioning satisfies the corresponding @@ -239,7 +274,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (child.outputPartitioning != partitioning) { + if (!child.outputPartitioning.guarantees(partitioning)) { Exchange(partitioning, child) } else { child @@ -276,13 +311,32 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ val fixedChildren = requirements.zipped.map { case (AllTuples, rowOrdering, child) => addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + + case (ClusteredDistribution(clustering, true), rowOrdering, child) => + addOperatorsIfNecessary( + HashPartitioning(clustering, numPartitions), + rowOrdering, + child) + + case (ClusteredDistribution(clustering, false), rowOrdering, child) => + if (advancedSqlOptimizations) { + addOperatorsIfNecessary( + HashPartitioning(clustering, numPartitions, false), + rowOrdering, + child) + } else { + addOperatorsIfNecessary( + HashPartitioning(clustering, numPartitions), + rowOrdering, + child) + } + case (OrderedDistribution(ordering), rowOrdering, child) => addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) case (UnspecifiedDistribution, Seq(), child) => child + case (UnspecifiedDistribution, rowOrdering, child) => sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 306bbfec624c0..924f34614b3fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -403,7 +403,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c9d1a880f4ef4..c30ad3708e3e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6bf2f82954046..d0cee21fbb840 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -38,14 +38,6 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - override def output: Seq[Attribute] = { joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 874712a4e739f..965cc65b2c0e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -37,7 +37,9 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - override def requiredChildDistribution: Seq[ClusteredDistribution] = + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 948d0ccebceb0..8a07200b9928d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
 /**
@@ -38,9 +38,10 @@ case class ShuffledHashJoin(
     right: SparkPlan)
   extends BinaryNode with HashJoin {
 
-  override def outputPartitioning: Partitioning = left.outputPartitioning
+  override def outputPartitioning: Partitioning =
+    PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
 
-  override def requiredChildDistribution: Seq[ClusteredDistribution] =
+  override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index f54f1edd38ec8..bb6cfa40194eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
@@ -41,8 +41,29 @@ case class ShuffledHashOuterJoin(
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode with HashOuterJoin {
 
+  // This is a heuristic: we require a null-unsafe ClusteredDistribution (nullSafe = false)
+  // so that input rows that can actually find a match are distributed evenly.
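+  // For example (illustrative, not part of the original comment): rows whose join keys contain
+  // a null can never match; hashing them null-safely would send all of them to a single
+  // partition, while the null-unsafe distribution lets Exchange spread them randomly.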
override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, nullSafe = false) :: + ClusteredDistribution(rightKeys, nullSafe = false) :: Nil + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => + val partitions = + Seq(left.outputPartitioning, right.outputPartitioning.withNullSafeSetting(false)) + PartitioningCollection(partitions) + case RightOuter => + val partitions = + Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(false)) + PartitioningCollection(partitions) + case FullOuter => + val partitions = + Seq(left.outputPartitioning.withNullSafeSetting(false), + right.outputPartitioning.withNullSafeSetting(false)) + PartitioningCollection(partitions) + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") + } protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index bb18b5403f8e8..41be78afd37e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -40,7 +40,8 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala new file mode 100644 index 0000000000000..929d6b429aaf5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * An optimization rule used to insert Filters to filter out rows whose equal join keys + * have at least one null values. For this kind of rows, they will not contribute to + * the join results of equal joins because a null does not equal another null. We can + * filter them out before shuffling join input rows. For example, we have two tables + * + * table1(key String, value Int) + * "str1"|1 + * null |2 + * + * table2(key String, value Int) + * "str1"|3 + * null |4 + * + * For a inner equal join, the result will be + * "str1"|1|"str1"|3 + * + * those two rows having null as the value of key will not contribute to the result. + * So, we can filter them out early. + * + * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. + * + */ +case class FilterNullsInJoinKey( + sqlContext: SQLContext) + extends Rule[LogicalPlan] { + + /** + * Checks if we need to add a Filter operator. We will add a Filter when + * there is any attribute in `keys` whose corresponding attribute of `keys` + * in `plan.output` is still nullable (`nullable` field is `true`). + */ + private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { + val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) + plan.output.filter(keyAttributeSet.contains).exists(_.nullable) + } + + /** + * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. + */ + private def addFilterIfNecessary( + keys: Seq[Expression], + child: LogicalPlan): LogicalPlan = { + // We get all attributes from keys. + val attributes = keys.filter { + case attr: Attribute => true + case _ => false + } + + // Then, we create a Filter to make sure these attributes are non-nullable. + val filter = + if (attributes.nonEmpty) { + Filter(Not(AtLeastNNulls(1, attributes)), child) + } else { + child + } + + // We return attributes representing keys (keyAttributes) and the filter. + // keyAttributes will be used to rewrite the join condition. + filter + } + + /** + * We reconstruct the join condition. + */ + private def reconstructJoinCondition( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + otherPredicate: Option[Expression]): Expression = { + // First, we rewrite the equal condition part. When we extract those keys, + // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). + val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { + case (l, r) => EqualTo(l, r) + }.reduce(And) + + // Then, we add otherPredicate. When we extract those equal condition part, + // we use splitConjunctivePredicates. So, it is safe to use + // And(rewrittenEqualJoinCondition, c). 
+ val rewrittenJoinCondition = otherPredicate + .map(c => And(rewrittenEqualJoinCondition, c)) + .getOrElse(rewrittenEqualJoinCondition) + + rewrittenJoinCondition + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!sqlContext.conf.advancedSqlOptimizations) { + plan + } else { + plan transform { + case join: Join => join match { + // For a inner join having equal join condition part, we can add filters + // to both sides of the join operator. + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) + + // For a left outer join having equal join condition part, we can add a filter + // to the right side of the join operator. + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(rightKeys, right) => + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) + + // For a right outer join having equal join condition part, we can add a filter + // to the left side of the join operator. + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) + + // For a left semi join having equal join condition part, we can add filters + // to both sides of the join operator. + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) + + case other => other + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala new file mode 100644 index 0000000000000..2a6b29f6e96f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.AtLeastNNulls +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.test.TestSQLContext + +class FilterNullsInJoinKeySuite extends PlanTest { + + // We add predicate pushdown rules at here to make sure we do not + // create redundant + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Operator Optimizations", FixedPoint(100), + FilterNullsInJoinKey(TestSQLContext), + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning, + ProjectCollapsing) :: Nil + } + + val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) + + val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) + + test("inner join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For an inner join, FilterNullsInJoinKey add filter to both side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + .select('a, 'b, 'd) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join (partially optimized)") { + val joinCondition = + ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // We cannot extract attribute from the left join key. 
+ val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .select('a, 'b, 'd) + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join (not optimized)") { + val nonOptimizedJoinConditions = + Some('c - 100 + 'd === 'g + 1 - 'h) :: + Some('d > 'h || 'c === 'g) :: + Some('d + 'g + 'c > 'd - 'h) :: Nil + + nonOptimizedJoinConditions.foreach { joinCondition => + val joinedPlan = + leftRelation + .select('a, 'c, 'd) + .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) + .select('a, 'c, 'f, 'd, 'h, 'g) + + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, joinedPlan.analyze) + } + } + + test("left outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left outer join, FilterNullsInJoinKey add filter to the right side. + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .select('a, 'b, 'd) + .join(correctRight, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("right outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a right outer join, FilterNullsInJoinKey add filter to the left side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + .select('a, 'b, 'd) + + val correctAnswer = + correctLeft + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("full outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .select('a, 'b, 'd) + .join(rightRelation, FullOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + // FilterNullsInJoinKey does not fire for a full outer join. + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, joinedPlan.analyze) + } + + test("left semi join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftSemi, Some(joinCondition)) + .select('a, 'd) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left semi join, FilterNullsInJoinKey add filter to both side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + .select('a, 'b, 'd) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, LeftSemi, Some(joinCondition)) + .select('a, 'd) + .analyze + + comparePlans(optimized, correctAnswer) + } +}
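
For reference, a minimal sketch (not part of this patch) of the rewrite FilterNullsInJoinKey is expected to perform on a simple inner equi-join, written against the Catalyst DSL used in the suite above; the relation and attribute names are illustrative only, and column pruning is ignored here:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.AtLeastNNulls
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    object FilterNullsInJoinKeyDemo {
      val demoLeft = LocalRelation('a.int, 'b.int)
      val demoRight = LocalRelation('e.int, 'f.int)

      // Original query: an inner equi-join on a = e.
      val original = demoLeft.join(demoRight, Inner, Some('a === 'e))

      // After the rule fires, each side should be wrapped in a Filter that drops rows
      // whose join key is null, since such rows can never satisfy the equality.
      val expected =
        demoLeft.where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
          .join(demoRight.where(!(AtLeastNNulls(1, 'e.expr :: Nil))), Inner, Some('a === 'e))

      // To actually exercise the rule, resolve the plans first (e.g. original.analyze),
      // as the suite above does, because ExtractEquiJoinKeys needs resolved expressions.
    }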