From be19a0f21187c364ee3eaee5f0a310ece874c66c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 13:43:55 -0700 Subject: [PATCH 01/56] [SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in more places. --- .../sql/catalyst/expressions/arithmetic.scala | 4 +-- .../catalyst/expressions/conditionals.scala | 4 +-- .../{RowOrdering.scala => ordering.scala} | 27 ++++++++++--------- .../sql/catalyst/expressions/predicates.scala | 8 +++--- .../spark/sql/catalyst/util/TypeUtils.scala | 4 +-- .../apache/spark/sql/types/StructType.scala | 4 +-- .../expressions/CodeGenerationSuite.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 5 +++- .../spark/sql/execution/SparkPlan.scala | 15 +++++++++-- .../spark/sql/execution/basicOperators.scala | 4 ++- .../sql/execution/joins/SortMergeJoin.scala | 5 ++-- .../UnsafeKVExternalSorterSuite.scala | 6 ++--- 12 files changed, 53 insertions(+), 35 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{RowOrdering.scala => ordering.scala} (85%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5808e3f66de3c..98464edf4d390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -320,7 +320,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -374,7 +374,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 961b1d8616801..d51f3d3cef588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -319,7 +319,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { @@ -374,7 +374,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala similarity index 85% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 873f5324c573e..6407c73bc97d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ /** * An interpreted row ordering comparator. */ -class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) @@ -49,9 +49,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => - s.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => - s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => throw new IllegalArgumentException(s"Type $other does not support ordered operations") } @@ -65,6 +65,18 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { } } +object InterpretedOrdering { + + /** + * Creates a [[InterpretedOrdering]] for the given schema, in natural ascending order. + */ + def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { + new InterpretedOrdering(dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + } +} + object RowOrdering { /** @@ -81,13 +93,4 @@ object RowOrdering { * Returns true iff outputs from the expressions can be ordered. */ def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) - - /** - * Creates a [[RowOrdering]] for the given schema, in natural ascending order. 
- */ - def forSchema(dataTypes: Seq[DataType]): RowOrdering = { - new RowOrdering(dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) - }) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68c832d7194d4..fe7dffb815987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -376,7 +376,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def symbol: String = "<" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -388,7 +388,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -400,7 +400,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar override def symbol: String = ">" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } @@ -412,7 +412,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0b41f92c6193c..bcf4d78fb9371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -54,10 +54,10 @@ object TypeUtils { def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] - def getOrdering(t: DataType): Ordering[Any] = { + def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case s: StructType => s.ordering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6928707f7bf6e..9cbc207538d4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} /** @@ -301,7 +301,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } - private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType)) + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cc82f7c3f5a73..e310aee221666 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -54,7 +54,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // GenerateOrdering agrees with RowOrdering. (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => test(s"GenerateOrdering with $dataType") { - val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) val genOrdering = GenerateOrdering.generate( BoundReference(0, dataType, nullable = true).asc :: BoundReference(1, dataType, nullable = true).asc :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 05b009d1935bb..6ea5eeedf1bbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -156,7 +156,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } - implicit val ordering = new RowOrdering(sortingExpressions, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be + // serialized and this ordering needs to be created on the driver in order to be passed into + // Spark core code. 
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index dbc0cefbe2e10..8351dafcc6d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.sql.types.DataType + import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Accumulator, Logging} @@ -309,12 +311,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ throw e } else { log.error("Failed to generate ordering, fallback to interpreted", e) - new RowOrdering(order, inputSchema) + new InterpretedOrdering(order, inputSchema) } } } else { - new RowOrdering(order, inputSchema) + new InterpretedOrdering(order, inputSchema) + } + } + /** + * Creates a row ordering for the given schema, in natural ascending order. + */ + protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = { + val order: Seq[SortOrder] = dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) } + newOrdering(order, Seq.empty) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 477170297c2ac..f4677b4ee86bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -212,7 +212,9 @@ case class TakeOrderedAndProject( override def outputPartitioning: Partitioning = SinglePartition - private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be serialized + // and this ordering needs to be created on the driver in order to be passed into Spark core code. + private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. 
@transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index eb595490fbf28..ce37bd5009df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -48,9 +48,6 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) override def requiredChildOrdering: Seq[Seq[SortOrder]] = @@ -68,6 +65,8 @@ case class SortMergeJoin( leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { + // An ordering that can be used to compare keys from both sides. + private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) // Mutable per row objects. private[this] val joinRow = new JoinedRow private[this] var leftElement: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 08156f0e39ce8..a9515a03acf2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -144,8 +144,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { } sorter.cleanupResources() - val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType)) - val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType)) + val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType)) val kvOrdering = new Ordering[(InternalRow, InternalRow)] { override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { keyOrdering.compare(x._1, y._1) match { From 34b8e0cb336b931303f16efc82d85e64283737e1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 13:47:15 -0700 Subject: [PATCH 02/56] Import ordering --- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8351dafcc6d30..2f29067f5646a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.sql.types.DataType - import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Accumulator, Logging} @@ -34,6 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.DataType object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() From e610655f1dc96865523baf1353a7636714d68764 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 14:15:19 -0700 Subject: [PATCH 03/56] Add comment RE: Ascending ordering --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ce37bd5009df6..4ae23c186cf7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,8 +56,10 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) + } protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) From df885484bb33888bf813a52097d97f152759c0ae Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 18 Jun 2015 21:47:30 -0700 Subject: [PATCH 04/56] Squash @adrian-wang's changes. --- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../sql/execution/joins/SortMergeJoin.scala | 246 +++++++++++++----- .../org/apache/spark/sql/JoinSuite.scala | 34 ++- 3 files changed, 213 insertions(+), 78 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 952ba7d45c13e..d14501bfc68e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -96,13 +96,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin. 
+ // And for outer join, we can not put conditions outside of the join + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - val mergeJoin = - joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + joins.SortMergeJoin( + leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index eb595490fbf28..8e634c65af325 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -35,49 +36,93 @@ import org.apache.spark.util.collection.CompactBuffer case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan, + condition: Option[Expression] = None) extends BinaryNode { - override protected[sql] val trackNumOfRowsEnabled = true + val (streamedPlan, bufferedPlan, streamedKeys, bufferedKeys) = joinType match { + case RightOuter => (right, left, rightKeys, leftKeys) + case _ => (left, right, leftKeys, rightKeys) + } - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType") + } - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = joinType match { + case FullOuter => + // when doing Full Outer join, NULL rows from both sides are not so partitioned. 
+ UnknownPartitioning(streamedPlan.outputPartitioning.numPartitions) + case _ => streamedPlan.outputPartitioning + } override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + private val keyOrdering: RowOrdering = RowOrdering.forSchema(streamedKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + override def outputOrdering: Seq[SortOrder] = joinType match { + case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered. + case _ => requiredOrders(streamedKeys) + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + @transient protected lazy val streamedKeyGenerator = + newProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val bufferedKeyGenerator = + newProjection(bufferedKeys, bufferedPlan.output) + + // checks if the joinedRow can meet condition requirements + @transient private[this] lazy val boundCondition = + condition.map(newPredicate(_, streamedPlan.output ++ bufferedPlan.output)).getOrElse( + (row: InternalRow) => true) private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = keys.map(SortOrder(_, Ascending)) protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val streamResults = streamedPlan.execute().map(_.copy()) + val bufferResults = bufferedPlan.execute().map(_.copy()) - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => { + // standard null rows + val streamedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) + val bufferedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) new Iterator[InternalRow] { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 + private[this] var streamedElement: InternalRow = _ + private[this] var bufferedElement: InternalRow = _ + private[this] var streamedKey: InternalRow = _ + private[this] var bufferedKey: InternalRow = _ + private[this] var bufferedMatches: CompactBuffer[InternalRow] = _ + private[this] var bufferedPosition: Int = -1 private[this] var stop: Boolean = false private[this] var matchKey: InternalRow = _ + // when we do merge algorithm and find some not matched join key, there must be a side + // that do not have a corresponding match. So we need to mark which side it is. True means + // streamed side not have match, and False means the buffered side. Only set when needed. 
+ private[this] var continueStreamed: Boolean = _ + // when we do full outer join and find all matched keys, we put a null stream row into + // this to tell next() that we need to combine null stream row with all rows that not match + // conditions. + private[this] var secondStreamedElement: InternalRow = _ + // Stores rows that match the join key but not match conditions. + // These rows will be useful when we are doing Full Outer Join. + private[this] var secondBufferedMatches: CompactBuffer[InternalRow] = _ // initialize iterator initialize() @@ -86,86 +131,169 @@ case class SortMergeJoin( override final def next(): InternalRow = { if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null + if (bufferedMatches == null || bufferedMatches.size == 0) { + // we just found a row with no join match and we are here to produce a row + // with this row and a standard null row from the other side. + if (continueStreamed) { + val joinedRow = smartJoinRow(streamedElement, bufferedNullRow) + fetchStreamed() + joinedRow + } else { + val joinedRow = smartJoinRow(streamedNullRow, bufferedElement) + fetchBuffered() + joinedRow + } + } else { + // we are using the buffered right rows and run down left iterator + val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)) + bufferedPosition += 1 + if (bufferedPosition >= bufferedMatches.size) { + bufferedPosition = 0 + if (joinType != FullOuter || secondStreamedElement == null) { + fetchStreamed() + if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0) { + stop = false + bufferedMatches = null + } + } else { + // in FullOuter join and the first time we finish the match buffer, + // we still want to generate all rows with streamed null row and buffered + // rows that match the join key but not the conditions. + streamedElement = secondStreamedElement + bufferedMatches = secondBufferedMatches + secondStreamedElement = null + secondBufferedMatches = null + } } + joinedRow } - joinedRow } else { // no more result throw new NoSuchElementException } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) + private def smartJoinRow(streamedRow: InternalRow, bufferedRow: InternalRow): InternalRow = + joinType match { + case RightOuter => joinRow(bufferedRow, streamedRow) + case _ => joinRow(streamedRow, bufferedRow) + } + + private def fetchStreamed(): Unit = { + if (streamedIter.hasNext) { + streamedElement = streamedIter.next() + streamedKey = streamedKeyGenerator(streamedElement) } else { - leftElement = null + streamedElement = null } } - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + private def fetchBuffered(): Unit = { + if (bufferedIter.hasNext) { + bufferedElement = bufferedIter.next() + bufferedKey = bufferedKeyGenerator(bufferedElement) } else { - rightElement = null + bufferedElement = null } } private def initialize() = { - fetchLeft() - fetchRight() + fetchStreamed() + fetchBuffered() } /** * Searches the right iterator for the next rows that have matches in left side, and store * them in a buffer. 
+ * When this is not a Inner join, we will also return true when we get a row with no match + * on the other side. This search will jump out every time from the same position until + * `next()` is called. + * Unless we call `next()`, this function can be called multiple times, with the same + * return value and result as running it once, since we have set guardians in it. * * @return true if the search is successful, and false if the right iterator runs out of * tuples. */ private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) + if (!stop && streamedElement != null) { + // step 1: run both side to get the first match pair + while (!stop && streamedElement != null && bufferedElement != null) { + val comparing = keyOrdering.compare(streamedKey, bufferedKey) // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() + stop = comparing == 0 && !streamedKey.anyNull + if (comparing > 0 || bufferedKey.anyNull) { + if (joinType == FullOuter) { + // the join type is full outer and the buffered side has a row with no + // join match, so we have a result row with streamed null with buffered + // side as this row. Then we fetch next buffered element and go back. + continueStreamed = false + return true + } else { + fetchBuffered() + } + } else if (comparing < 0 || streamedKey.anyNull) { + if (joinType == Inner) { + fetchStreamed() + } else { + // the join type is not inner and the streamed side has a row with no + // join match, so we have a result row with this streamed row with buffered + // null row. Then we fetch next streamed element and go back. + continueStreamed = true + return true + } } } - rightMatches = new CompactBuffer[InternalRow]() + // step 2: run down the buffered side to put all matched rows in a buffer + bufferedMatches = new CompactBuffer[InternalRow]() + secondBufferedMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 + while (!stop) { + if (boundCondition(joinRow(streamedElement, bufferedElement))) { + bufferedMatches += bufferedElement + } else if (joinType == FullOuter) { + bufferedMatches += bufferedNullRow + secondBufferedMatches += bufferedElement + } + fetchBuffered() + stop = + keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null + } + if (bufferedMatches.size == 0 && joinType != Inner) { + bufferedMatches += bufferedNullRow + } + if (bufferedMatches.size > 0) { + bufferedPosition = 0 + matchKey = streamedKey + // secondBufferedMatches.size cannot be larger than bufferedMatches + if (secondBufferedMatches.size > 0) { + secondStreamedElement = streamedNullRow + } + } + } + } + // `stop` is false iff left or right has finished iteration in step 1. + // if we get into step 2, `stop` cannot be false. 
+ if (!stop && (bufferedMatches == null || bufferedMatches.size == 0)) { + if (streamedElement == null && bufferedElement != null) { + // streamedElement == null but bufferedElement != null + if (joinType == FullOuter) { + continueStreamed = false + return true } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey + } else if (streamedElement != null && bufferedElement == null) { + // bufferedElement == null but streamedElement != null + if (joinType != Inner) { + continueStreamed = true + return true } } } - rightMatches != null && rightMatches.size > 0 + bufferedMatches != null && bufferedMatches.size > 0 } } - } + }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d8966031..2f0baa1213ab3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -83,13 +83,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -98,11 +97,20 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) @@ -154,14 +162,14 @@ class JoinSuite 
extends QueryTest with BeforeAndAfterEach { val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) + classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", From 58edb2e426ef33d9240a53cd6da72c28de53f0c5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 3 Aug 2015 14:06:41 -0700 Subject: [PATCH 05/56] Remove old TODO --- .../apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index eee8ad800f98e..1dcfd04bb8665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -55,7 +55,6 @@ case class ShuffledHashOuterJoin( protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => val hashed = HashedRelation(rightIter, buildKeyGenerator) From 9faa2eeb85ac6be58ac22f3336f6775654faa250 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 3 Aug 2015 17:58:30 -0700 Subject: [PATCH 06/56] Use withSQLConf in JoinSuite --- .../org/apache/spark/sql/JoinSuite.scala | 57 +++++++------------ 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2f0baa1213ab3..8ffa6157a933d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SQLTestUtils -class JoinSuite extends QueryTest with BeforeAndAfterEach { +class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Ensures tables are loaded. 
TestData + override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ import ctx.logicalPlanToSparkQuery @@ -66,7 +67,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -96,8 +96,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", @@ -112,20 +111,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } @@ -133,15 +126,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -149,8 +140,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -160,7 +149,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", @@ -168,8 +156,7 @@ class JoinSuite extends QueryTest with 
BeforeAndAfterEach { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", @@ -177,8 +164,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -465,25 +450,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } @@ -496,6 +480,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } From 8d83e152f06bddf4d50da661a614b1840769232b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 13:21:49 -0700 Subject: [PATCH 07/56] Use explicit toScala conversions in ShuffledHashOuterJoin. 
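scala.collection.JavaConversions performs Java-to-Scala wrapping through implicit conversions, which makes it hard to see where a Java collection is actually being converted; scala.collection.JavaConverters instead requires an explicit .asScala call at each conversion site, as used in the hunk below. A rough, self-contained sketch of the difference (the values are illustrative only and not taken from this patch):

import scala.collection.JavaConverters._

// A Java collection, analogous to the key set returned by a java.util.HashMap.
val javaList: java.util.List[String] = java.util.Arrays.asList("a", "b", "c")

// With JavaConverters the conversion is explicit at the call site,
// rather than happening silently through an implicit view.
val scalaSeq: Seq[String] = javaList.asScala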
--- .../joins/ShuffledHashOuterJoin.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 1dcfd04bb8665..60b0aa2ba1ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer /** * :: DeveloperApi :: @@ -78,11 +79,16 @@ case class ShuffledHashOuterJoin( // TODO(davies): use UnsafeRow val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow) + (leftHashTable.keySet.asScala ++ rightHashTable.keySet.asScala).iterator.flatMap { key => + val leftRows: CompactBuffer[InternalRow] = { + val rows = leftHashTable.get(key) + if (rows == null) EMPTY_LIST else rows + } + val rightRows: CompactBuffer[InternalRow] = { + val rows = rightHashTable.get(key) + if (rows == null) EMPTY_LIST else rows + } + fullOuterIterator(key, leftRows, rightRows, joinedRow) } case x => From a471a6e07b5affaa9a388f7524115797e3d6e555 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 14:42:02 -0700 Subject: [PATCH 08/56] Revert changes to SortMergeJoin; add new SortMergeOuterJoin operator --- .../spark/sql/execution/SparkStrategies.scala | 14 +- .../sql/execution/joins/HashOuterJoin.scala | 19 +- .../joins/ShuffledHashOuterJoin.scala | 8 - .../sql/execution/joins/SortMergeJoin.scala | 247 +++++------------- .../execution/joins/SortMergeOuterJoin.scala | 110 ++++++++ .../org/apache/spark/sql/JoinSuite.scala | 15 +- 6 files changed, 202 insertions(+), 211 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d14501bfc68e6..420b33efb07cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -96,12 +96,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin. 
- // And for outer join, we can not put conditions outside of the join + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - joins.SortMergeJoin( - leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil + joins.SortMergeOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index a323aea4ea2c4..2946224e93f6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -37,7 +38,7 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan - override def output: Seq[Attribute] = { + final override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -46,16 +47,26 @@ trait HashOuterJoin { case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") } } + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + protected[this] lazy val (buildPlan, streamedPlan) = joinType match { case RightOuter => (left, right) case LeftOuter => (right, left) case x => throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") + s"${getClass.getSimpleName} should not take $x as the JoinType") } protected[this] lazy val (buildKeys, streamedKeys) = joinType match { @@ -63,7 +74,7 @@ trait HashOuterJoin { case LeftOuter => (rightKeys, leftKeys) case x => throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") + s"${getClass.getSimpleName} should not take $x as the JoinType") } protected[this] def isUnsafeMode: Boolean = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 60b0aa2ba1ff8..ac9c11ab3951e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -45,14 +45,6 @@ case class ShuffledHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 8e634c65af325..3fe7167dcd701 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -23,7 +23,6 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -36,93 +35,48 @@ import org.apache.spark.util.collection.CompactBuffer case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], - joinType: JoinType, left: SparkPlan, - right: SparkPlan, - condition: Option[Expression] = None) extends BinaryNode { + right: SparkPlan) extends BinaryNode { - val (streamedPlan, bufferedPlan, streamedKeys, bufferedKeys) = joinType match { - case RightOuter => (right, left, rightKeys, leftKeys) - case _ => (left, right, leftKeys, rightKeys) - } - - override def output: Seq[Attribute] = joinType match { - case Inner => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType") - } + override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = joinType match { - case FullOuter => - // when doing Full Outer join, NULL rows from both sides are not so partitioned. 
- UnknownPartitioning(streamedPlan.outputPartitioning.numPartitions) - case _ => streamedPlan.outputPartitioning - } + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(streamedKeys.map(_.dataType)) + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = joinType match { - case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered. - case _ => requiredOrders(streamedKeys) - } + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val streamedKeyGenerator = - newProjection(streamedKeys, streamedPlan.output) - @transient protected lazy val bufferedKeyGenerator = - newProjection(bufferedKeys, bufferedPlan.output) - - // checks if the joinedRow can meet condition requirements - @transient private[this] lazy val boundCondition = - condition.map(newPredicate(_, streamedPlan.output ++ bufferedPlan.output)).getOrElse( - (row: InternalRow) => true) + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = keys.map(SortOrder(_, Ascending)) protected override def doExecute(): RDD[InternalRow] = { - val streamResults = streamedPlan.execute().map(_.copy()) - val bufferResults = bufferedPlan.execute().map(_.copy()) + // TODO(josh): why is this copying necessary? + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) - streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => { - // standard null rows - val streamedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) - val bufferedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] var streamedElement: InternalRow = _ - private[this] var bufferedElement: InternalRow = _ - private[this] var streamedKey: InternalRow = _ - private[this] var bufferedKey: InternalRow = _ - private[this] var bufferedMatches: CompactBuffer[InternalRow] = _ - private[this] var bufferedPosition: Int = -1 + private[this] var leftElement: InternalRow = _ + private[this] var rightElement: InternalRow = _ + private[this] var leftKey: InternalRow = _ + private[this] var rightKey: InternalRow = _ + private[this] var rightMatches: CompactBuffer[InternalRow] = _ + private[this] var rightPosition: Int = -1 private[this] var stop: Boolean = false private[this] var matchKey: InternalRow = _ - // when we do merge algorithm and find some not matched join key, there must be a side - // that do not have a corresponding match. So we need to mark which side it is. True means - // streamed side not have match, and False means the buffered side. Only set when needed. 
- private[this] var continueStreamed: Boolean = _ - // when we do full outer join and find all matched keys, we put a null stream row into - // this to tell next() that we need to combine null stream row with all rows that not match - // conditions. - private[this] var secondStreamedElement: InternalRow = _ - // Stores rows that match the join key but not match conditions. - // These rows will be useful when we are doing Full Outer Join. - private[this] var secondBufferedMatches: CompactBuffer[InternalRow] = _ // initialize iterator initialize() @@ -131,169 +85,86 @@ case class SortMergeJoin( override final def next(): InternalRow = { if (hasNext) { - if (bufferedMatches == null || bufferedMatches.size == 0) { - // we just found a row with no join match and we are here to produce a row - // with this row and a standard null row from the other side. - if (continueStreamed) { - val joinedRow = smartJoinRow(streamedElement, bufferedNullRow) - fetchStreamed() - joinedRow - } else { - val joinedRow = smartJoinRow(streamedNullRow, bufferedElement) - fetchBuffered() - joinedRow - } - } else { - // we are using the buffered right rows and run down left iterator - val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)) - bufferedPosition += 1 - if (bufferedPosition >= bufferedMatches.size) { - bufferedPosition = 0 - if (joinType != FullOuter || secondStreamedElement == null) { - fetchStreamed() - if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0) { - stop = false - bufferedMatches = null - } - } else { - // in FullOuter join and the first time we finish the match buffer, - // we still want to generate all rows with streamed null row and buffered - // rows that match the join key but not the conditions. - streamedElement = secondStreamedElement - bufferedMatches = secondBufferedMatches - secondStreamedElement = null - secondBufferedMatches = null - } + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null } - joinedRow } + joinedRow } else { // no more result throw new NoSuchElementException } } - private def smartJoinRow(streamedRow: InternalRow, bufferedRow: InternalRow): InternalRow = - joinType match { - case RightOuter => joinRow(bufferedRow, streamedRow) - case _ => joinRow(streamedRow, bufferedRow) - } - - private def fetchStreamed(): Unit = { - if (streamedIter.hasNext) { - streamedElement = streamedIter.next() - streamedKey = streamedKeyGenerator(streamedElement) + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) } else { - streamedElement = null + leftElement = null } } - private def fetchBuffered(): Unit = { - if (bufferedIter.hasNext) { - bufferedElement = bufferedIter.next() - bufferedKey = bufferedKeyGenerator(bufferedElement) + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) } else { - bufferedElement = null + rightElement = null } } private def initialize() = { - fetchStreamed() - fetchBuffered() + fetchLeft() + fetchRight() } /** * Searches the right iterator for the next rows that have matches in left side, and store * them in a buffer. 
- * When this is not a Inner join, we will also return true when we get a row with no match - * on the other side. This search will jump out every time from the same position until - * `next()` is called. - * Unless we call `next()`, this function can be called multiple times, with the same - * return value and result as running it once, since we have set guardians in it. * * @return true if the search is successful, and false if the right iterator runs out of * tuples. */ private def nextMatchingPair(): Boolean = { - if (!stop && streamedElement != null) { - // step 1: run both side to get the first match pair - while (!stop && streamedElement != null && bufferedElement != null) { - val comparing = keyOrdering.compare(streamedKey, bufferedKey) + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) // for inner join, we need to filter those null keys - stop = comparing == 0 && !streamedKey.anyNull - if (comparing > 0 || bufferedKey.anyNull) { - if (joinType == FullOuter) { - // the join type is full outer and the buffered side has a row with no - // join match, so we have a result row with streamed null with buffered - // side as this row. Then we fetch next buffered element and go back. - continueStreamed = false - return true - } else { - fetchBuffered() - } - } else if (comparing < 0 || streamedKey.anyNull) { - if (joinType == Inner) { - fetchStreamed() - } else { - // the join type is not inner and the streamed side has a row with no - // join match, so we have a result row with this streamed row with buffered - // null row. Then we fetch next streamed element and go back. - continueStreamed = true - return true - } + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() } } - // step 2: run down the buffered side to put all matched rows in a buffer - bufferedMatches = new CompactBuffer[InternalRow]() - secondBufferedMatches = new CompactBuffer[InternalRow]() + rightMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches // as the records should be ordered, exit when we meet the first that not match - while (!stop) { - if (boundCondition(joinRow(streamedElement, bufferedElement))) { - bufferedMatches += bufferedElement - } else if (joinType == FullOuter) { - bufferedMatches += bufferedNullRow - secondBufferedMatches += bufferedElement - } - fetchBuffered() - stop = - keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null - } - if (bufferedMatches.size == 0 && joinType != Inner) { - bufferedMatches += bufferedNullRow - } - if (bufferedMatches.size > 0) { - bufferedPosition = 0 - matchKey = streamedKey - // secondBufferedMatches.size cannot be larger than bufferedMatches - if (secondBufferedMatches.size > 0) { - secondStreamedElement = streamedNullRow - } - } - } - } - // `stop` is false iff left or right has finished iteration in step 1. - // if we get into step 2, `stop` cannot be false. 
- if (!stop && (bufferedMatches == null || bufferedMatches.size == 0)) { - if (streamedElement == null && bufferedElement != null) { - // streamedElement == null but bufferedElement != null - if (joinType == FullOuter) { - continueStreamed = false - return true + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 } - } else if (streamedElement != null && bufferedElement == null) { - // bufferedElement == null but streamedElement != null - if (joinType != Inner) { - continueStreamed = true - return true + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey } } } - bufferedMatches != null && bufferedMatches.size > 0 + rightMatches != null && rightMatches.size > 0 } } - }) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala new file mode 100644 index 0000000000000..efc194582b59c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge outer join of two child relations. + */ +@DeveloperApi +case class SortMergeOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = joinType match { + case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered. 
+ case _ => requiredOrders(leftKeys) + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + protected override def doExecute(): RDD[InternalRow] = { + // TODO(josh): why is this copying necessary? + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + val joinedRow = new JoinedRow() + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + joinType match { + case LeftOuter => + // TODO(josh): for SMJ we would buffer keys here: + val hashed = HashedRelation(rightIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator + leftIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) + }) + + case RightOuter => + // TODO(josh): for SMJ we would buffer keys here: + val hashed = HashedRelation(leftIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator + rightIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) + }) + + case FullOuter => + // TODO(davies): use UnsafeRow + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet.asScala ++ rightHashTable.keySet.asScala).iterator.flatMap { key => + val leftRows: CompactBuffer[InternalRow] = { + val rows = leftHashTable.get(key) + if (rows == null) EMPTY_LIST else rows + } + val rightRows: CompactBuffer[InternalRow] = { + val rows = rightHashTable.get(key) + if (rows == null) EMPTY_LIST else rows + } + fullOuterIterator(key, leftRows, rightRows, joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8ffa6157a933d..7d8fba047ac6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -56,6 +56,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -83,12 +84,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]), + classOf[SortMergeOuterJoin]), + ("SELECT * FROM 
testData full outer join testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -150,11 +151,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { ctx.sql("CACHE TABLE testData") Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeJoin]) + classOf[SortMergeOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( From cf8c0429c603c01ecdc5e26c233db2e85ff2e446 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 15:32:30 -0700 Subject: [PATCH 09/56] Fix join operator selection for outer join: Previously, the planner would always choose sort-merge-join for outer joins, even in cases where broadcast outer join could be used. --- .../spark/sql/execution/SparkStrategies.scala | 45 +++++++++------- .../org/apache/spark/sql/JoinSuite.scala | 51 +++++++++---------- 2 files changed, 53 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 420b33efb07cf..7ad1f3f504dac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,18 +62,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + // TODO(josh): this class's name is slightly misleading in the sense that it also plans non-hash + // joins, such as SortMergeJoin. Maybe we could just name this something like JoinSelection. /** - * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be - * evaluated by matching hash keys. + * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates + * can be evaluated by matching hash keys. * - * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an - * estimated physical size smaller than the user-settable threshold - * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the - * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be - * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. + * Join implementations are chosen with the following precedence: + * + * - Broadcast: if one side of the join has an estimated physical size that is smaller than the + * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * or if that side has an explicit broadcast hint (e.g. 
the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side + * of the join will be broadcasted and the other side will be streamed, with no shuffling + * performed. If both sides of the join are eligible to be broadcasted then the + * - Sort merge: if the matching join keys are sortable and + * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join + * will be used. + * - Hash: will be chosen if neither of the above optimizations apply to this join. */ object HashJoin extends Strategy with PredicateHelper { @@ -90,25 +96,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + + // --- Inner joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - joins.SortMergeOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { @@ -120,6 +122,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + // --- Outer joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( @@ -130,10 +134,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- + case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 
7d8fba047ac6f..13b0eea1e7ba2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -126,47 +126,46 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } } - ctx.sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } - ctx.sql("UNCACHE TABLE testData") } From a09d6e32f61518b215c475424fb6c0634d8ec712 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 15:34:05 -0700 Subject: [PATCH 10/56] Rename HashOuterJoin to OuterJoin. 
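The helper trait is no longer specific to hash-based joins: after the previous commits it is
also mixed into SortMergeOuterJoin, so the "Hash" prefix was misleading. What the trait actually
shares across BroadcastHashOuterJoin, ShuffledHashOuterJoin and SortMergeOuterJoin are the
join-type-dependent pieces: output nullability, output partitioning, and build/streamed side
selection. For illustration only, a simplified sketch of the shared nullability rule
(`outputFor` is an illustrative name invented here; the trait implements this in its `output`
method):

    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}

    // Sketch of the rule every outer join implementation shares: the side that can
    // produce missing matches has its output attributes marked nullable.
    def outputFor(
        joinType: JoinType,
        left: Seq[Attribute],
        right: Seq[Attribute]): Seq[Attribute] = joinType match {
      case LeftOuter  => left ++ right.map(_.withNullability(true))
      case RightOuter => left.map(_.withNullability(true)) ++ right
      case FullOuter  =>
        left.map(_.withNullability(true)) ++ right.map(_.withNullability(true))
      case other => throw new IllegalArgumentException(s"Unsupported join type: $other")
    }
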
--- .../spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 2 +- .../execution/joins/{HashOuterJoin.scala => OuterJoin.scala} | 2 +- .../spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 2 +- .../apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/joins/{HashOuterJoin.scala => OuterJoin.scala} (99%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index e342fd914d321..9fbdb5f0790ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -44,7 +44,7 @@ case class BroadcastHashOuterJoin( joinType: JoinType, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { + right: SparkPlan) extends BinaryNode with OuterJoin { val timeout = { val timeoutValue = sqlContext.conf.broadcastTimeout diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala index 2946224e93f6c..2937823d30d01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @DeveloperApi -trait HashOuterJoin { +trait OuterJoin { self: SparkPlan => val leftKeys: Seq[Expression] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ac9c11ab3951e..9ca5410c54dfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -40,7 +40,7 @@ case class ShuffledHashOuterJoin( joinType: JoinType, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { + right: SparkPlan) extends BinaryNode with OuterJoin { override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index efc194582b59c..90009dabb456a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -39,7 +39,7 @@ case class SortMergeOuterJoin( joinType: JoinType, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { + right: SparkPlan) extends BinaryNode with OuterJoin { override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: 
Nil From 58b2d1cbd1f16bb45110bc6d4854c4ef950ace55 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 16:34:53 -0700 Subject: [PATCH 11/56] Clean up non-obvious side-effect in JoinedRow.with[Left|Right] --- .../apache/spark/sql/catalyst/expressions/Projection.scala | 6 +++--- .../spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 6 ++---- .../spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 6 ++---- .../spark/sql/execution/joins/SortMergeOuterJoin.scala | 7 +++---- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4296b4b123fc0..7ac6dbacedbc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -185,20 +185,20 @@ class JoinedRow extends InternalRow { } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + def apply(r1: InternalRow, r2: InternalRow): JoinedRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { + def withLeft(newLeft: InternalRow): JoinedRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { + def withRight(newRight: InternalRow): JoinedRow = { row2 = newRight this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 9fbdb5f0790ca..c8b4f50d16dc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -102,15 +102,13 @@ case class BroadcastHashOuterJoin( case LeftOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) + leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashTable.get(rowKey)) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow.withRight(currentRow)) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 9ca5410c54dfb..3ee8811082075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -54,8 +54,7 @@ case class ShuffledHashOuterJoin( val keyGenerator = streamedKeyGenerator leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) + leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashed.get(rowKey)) }) case RightOuter => @@ -63,8 +62,7 @@ case class 
ShuffledHashOuterJoin( val keyGenerator = streamedKeyGenerator rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow.withRight(currentRow)) }) case FullOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 90009dabb456a..f70f7e70f9c88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -71,8 +71,7 @@ case class SortMergeOuterJoin( val keyGenerator = streamedKeyGenerator leftIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) + leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashed.get(rowKey)) }) case RightOuter => @@ -81,11 +80,11 @@ case class SortMergeOuterJoin( val keyGenerator = streamedKeyGenerator rightIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow.withRight(currentRow)) }) case FullOuter => + // TODO(josh): handle this case efficiently in SMJ // TODO(davies): use UnsafeRow val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) From 07ef4788b07ae2898054de9a8285ed689763def6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 16:39:59 -0700 Subject: [PATCH 12/56] Style cleanup in flatMap; use curly braces instead of parens. 
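When a closure body spans several statements, the convention followed here is to pass it as a
single brace block rather than wrapping a brace block in parentheses. A toy example of the
transformation this commit applies (the data and names below are made up for illustration, not
code from this patch):

    // Hypothetical input, for illustration only.
    def words = Iterator("a b", "c d")

    // Before: a multi-statement closure wrapped in parentheses plus braces.
    words.flatMap( line => {
      val parts = line.split(" ")
      parts.iterator
    }).toList                      // List("a", "b", "c", "d")

    // After: the closure is passed as a single brace block.
    words.flatMap { line =>
      val parts = line.split(" ")
      parts.iterator
    }.toList                       // List("a", "b", "c", "d")
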
--- .../sql/execution/joins/BroadcastHashOuterJoin.scala | 8 ++++---- .../spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 8 ++++---- .../spark/sql/execution/joins/SortMergeOuterJoin.scala | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c8b4f50d16dc6..7192b059069c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -100,16 +100,16 @@ case class BroadcastHashOuterJoin( joinType match { case LeftOuter => - streamedIter.flatMap(currentRow => { + streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashTable.get(rowKey)) - }) + } case RightOuter => - streamedIter.flatMap(currentRow => { + streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow.withRight(currentRow)) - }) + } case x => throw new IllegalArgumentException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 3ee8811082075..ce1a1a7187fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -52,18 +52,18 @@ case class ShuffledHashOuterJoin( case LeftOuter => val hashed = HashedRelation(rightIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator - leftIter.flatMap( currentRow => { + leftIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashed.get(rowKey)) - }) + } case RightOuter => val hashed = HashedRelation(leftIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator - rightIter.flatMap ( currentRow => { + rightIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow.withRight(currentRow)) - }) + } case FullOuter => // TODO(davies): use UnsafeRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index f70f7e70f9c88..a8b7f0e9c3d17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -69,19 +69,19 @@ case class SortMergeOuterJoin( // TODO(josh): for SMJ we would buffer keys here: val hashed = HashedRelation(rightIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator - leftIter.flatMap(currentRow => { + leftIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashed.get(rowKey)) - }) + } case RightOuter => // TODO(josh): for SMJ we would buffer keys here: val hashed = HashedRelation(leftIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator - rightIter.flatMap(currentRow => { + rightIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) rightOuterIterator(rowKey, hashed.get(rowKey), 
joinedRow.withRight(currentRow)) - }) + } case FullOuter => // TODO(josh): handle this case efficiently in SMJ From c3c7ed41ec38fc5f12b37441e2901b6fc96c4b31 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 16:57:22 -0700 Subject: [PATCH 13/56] Move initialize() definition closer to usage. --- .../spark/sql/execution/joins/SortMergeJoin.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 3fe7167dcd701..5818579efdc15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -81,6 +81,11 @@ case class SortMergeJoin( // initialize iterator initialize() + private def initialize(): Unit = { + fetchLeft() + fetchRight() + } + override final def hasNext: Boolean = nextMatchingPair() override final def next(): InternalRow = { @@ -103,7 +108,7 @@ case class SortMergeJoin( } } - private def fetchLeft() = { + private def fetchLeft(): Unit = { if (leftIter.hasNext) { leftElement = leftIter.next() leftKey = leftKeyGenerator(leftElement) @@ -112,7 +117,7 @@ case class SortMergeJoin( } } - private def fetchRight() = { + private def fetchRight(): Unit = { if (rightIter.hasNext) { rightElement = rightIter.next() rightKey = rightKeyGenerator(rightElement) @@ -121,11 +126,6 @@ case class SortMergeJoin( } } - private def initialize() = { - fetchLeft() - fetchRight() - } - /** * Searches the right iterator for the next rows that have matches in left side, and store * them in a buffer. From 78714dd6c9462b8aed7223d988a2a6b128270f0b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 23:06:06 -0700 Subject: [PATCH 14/56] Large refactoring of SMJ internals to improve clarity. --- .../sql/execution/joins/SortMergeJoin.scala | 220 +++++++++++------- 1 file changed, 132 insertions(+), 88 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 5818579efdc15..316c4eb9a89a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -67,104 +65,150 @@ case class SortMergeJoin( leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { - // Mutable per row objects. 
+ private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: CompactBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + leftIter, + rightIter + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - private def initialize(): Unit = { - fetchLeft() - fetchRight() - } - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow - } else { - // no more result - throw new NoSuchElementException - } - } + override final def hasNext: Boolean = + (currentMatchIdx != -1 && currentMatchIdx < currentRightMatches.length) || fetchNext() - private def fetchLeft(): Unit = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) + private[this] def fetchNext(): Boolean = { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getRightMatches + currentLeftRow = smjScanner.getLeftRow + currentMatchIdx = 0 + true } else { - leftElement = null + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + false } } - private def fetchRight(): Unit = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } else { - rightElement = null + override def next(): InternalRow = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + fetchNext() } + val joinedRow = joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + joinedRow } + } + } + } +} - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. 
- */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } - } - rightMatches != null && rightMatches.size > 0 +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + */ +private[joins] class SortMergeJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: RowOrdering, + leftIter: Iterator[InternalRow], + rightIter: Iterator[InternalRow]) { + private[this] var leftRow: InternalRow = _ + private[this] var leftJoinKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightJoinKey: InternalRow = _ + /** The join key for the rows buffered in `rightMatches`, or null if `rightMatches` is empty */ + private[this] var matchedJoinKey: InternalRow = _ + /** Buffered rows from the right side of the join. This is never null. */ + private[this] var rightMatches: CompactBuffer[InternalRow] = new CompactBuffer[InternalRow]() + + // Initialization (note: do _not_ want to advance left here). + advanceRight() + + // --- Public methods --------------------------------------------------------------------------- + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getLeftRow]] and [[getRightMatches]] can be called to produce the join results. + */ + final def findNextInnerJoinRows(): Boolean = { + advanceLeft() + if (leftRow == null) { + // We have consumed the entire left iterator, so there can be no more matches. + false + } else if (matchedJoinKey != null && keyOrdering.compare(leftJoinKey, matchedJoinKey) == 0) { + // The new left row has the same join key as the previous row, so return the same matches. + true + } else if (rightRow == null) { + // The left row's join key does not match the current batch of right rows and there are no + // more rows to read from the right iterator, so there can be no more matches. + false + } else { + // Advance both the left and right iterators to find the next pair of matching rows. + var comp = 0 + do { + if (leftJoinKey.anyNull) { + advanceLeft() + } else if (rightJoinKey.anyNull) { + advanceRight() + } else { + comp = keyOrdering.compare(leftJoinKey, rightJoinKey) + if (comp > 0) advanceRight() + else if (comp < 0) advanceLeft() } + } while (leftRow != null && rightRow != null && comp != 0) + if (leftRow == null || rightRow == null) { + // We have either hit the end of one of the iterators, so there can be no more matches. 
+ false + } else { + // The left and right rows have matching join keys, so scan through the right iterator to + // buffer all matching rows. + assert(comp == 0) + matchedJoinKey = leftJoinKey + rightMatches = new CompactBuffer[InternalRow]() + do { + // TODO(josh): if we move the row copying further down, we would do it here: + // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match + rightMatches += rightRow + advanceRight() + } while (rightRow != null && keyOrdering.compare(leftJoinKey, rightJoinKey) == 0) + true } } } + + def getRightMatches: CompactBuffer[InternalRow] = rightMatches + def getLeftRow: InternalRow = leftRow + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + */ + private def advanceLeft(): Unit = { + if (leftIter.hasNext) { + leftRow = leftIter.next() + leftJoinKey = leftKeyGenerator(leftRow) + } else { + leftRow = null + leftJoinKey = null + } + } + + /** + * Advance the right iterator and compute the new row's join key. + */ + private def advanceRight(): Unit = { + if (rightIter.hasNext) { + rightRow = rightIter.next() + rightJoinKey = rightKeyGenerator(rightRow) + } else { + rightRow = null + rightJoinKey = null + } + } } From 124f4ba2796f52998d588630b41be98debd44c2d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 4 Aug 2015 23:11:24 -0700 Subject: [PATCH 15/56] Remove unnecessary row copying. --- .../apache/spark/sql/execution/joins/SortMergeJoin.scala | 9 ++------- .../spark/sql/execution/joins/SortMergeOuterJoin.scala | 5 +---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 316c4eb9a89a2..5653b3caab1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -59,11 +59,7 @@ case class SortMergeJoin( keys.map(SortOrder(_, Ascending)) protected override def doExecute(): RDD[InternalRow] = { - // TODO(josh): why is this copying necessary? 
- val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new Iterator[InternalRow] { private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: CompactBuffer[InternalRow] = _ @@ -171,9 +167,8 @@ private[joins] class SortMergeJoinScanner( matchedJoinKey = leftJoinKey rightMatches = new CompactBuffer[InternalRow]() do { - // TODO(josh): if we move the row copying further down, we would do it here: // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match - rightMatches += rightRow + rightMatches += rightRow.copy() // need to copy mutable rows before buffering them advanceRight() } while (rightRow != null && keyOrdering.compare(leftJoinKey, rightJoinKey) == 0) true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index a8b7f0e9c3d17..279a121c498cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -59,11 +59,8 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) protected override def doExecute(): RDD[InternalRow] = { - // TODO(josh): why is this copying necessary? - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) val joinedRow = new JoinedRow() - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => joinType match { case LeftOuter => // TODO(josh): for SMJ we would buffer keys here: From 8c50c307803ef2b7b097d94c3b0a6cd699a8e459 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 01:51:16 -0700 Subject: [PATCH 16/56] Support SMJ for left outer join. 
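For LEFT OUTER JOIN the streamed (left) row must be emitted exactly once even when it has no
match, padded with nulls on the right. The scanner therefore gains findNextOuterJoinRows(),
which always advances the left side and exposes the (possibly empty) buffer of matching right
rows, and OuterJoin.leftOuterIterator no longer needs the join key to decide the null-row case:
callers pass null matches when the key contains a null. A rough sketch of the left outer
semantics on plain Scala collections (the function name and the groupBy shortcut are mine; the
patch streams two sorted iterators with SortMergeJoinScanner instead of materializing a map):

    // Illustrative only: groupBy stands in for the scanner's buffered right matches.
    def leftOuterMerge[K, L, R](
        left: Seq[(K, L)],
        right: Seq[(K, R)]): Seq[(L, Option[R])] = {
      val matchesByKey = right.groupBy(_._1)
      left.flatMap { case (k, l) =>
        matchesByKey.get(k) match {
          case Some(ms) => ms.map { case (_, r) => (l, Option(r)) }  // one row per match
          case None     => Seq((l, Option.empty[R]))                 // unmatched: null-padded
        }
      }
    }

    // leftOuterMerge(Seq(1 -> "a", 2 -> "b"), Seq(1 -> "x"))
    //   == Seq(("a", Some("x")), ("b", None))
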
--- .../joins/BroadcastHashOuterJoin.scala | 3 +- .../spark/sql/execution/joins/OuterJoin.scala | 21 ++--- .../joins/ShuffledHashOuterJoin.scala | 3 +- .../sql/execution/joins/SortMergeJoin.scala | 90 ++++++++++++++++--- .../execution/joins/SortMergeOuterJoin.scala | 20 +++-- 5 files changed, 102 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 7192b059069c8..16043396a28c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -102,7 +102,8 @@ case class BroadcastHashOuterJoin( case LeftOuter => streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) - leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashTable.get(rowKey)) + val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) + leftOuterIterator(joinedRow.withLeft(currentRow), matches) } case RightOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala index 2937823d30d01..8674bbb9e7c13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala @@ -122,25 +122,20 @@ trait OuterJoin { // iterator for performance purpose. protected[this] def leftOuterIterator( - key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() - } - } else { - List.empty - } - if (temp.isEmpty) { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil - } else { - temp + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() } } else { + List.empty + } + if (temp.isEmpty) { resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + } else { + temp } } ret.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ce1a1a7187fff..119318e879cf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -54,7 +54,8 @@ case class ShuffledHashOuterJoin( val keyGenerator = streamedKeyGenerator leftIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) - leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashed.get(rowKey)) + val matches = if (rowKey.anyNull) null else hashed.get(rowKey) + leftOuterIterator(joinedRow.withLeft(currentRow), matches) } case RightOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 5653b3caab1f9..985a2b05845a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -106,6 +106,8 @@ case class SortMergeJoin( /** * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. */ +// TODO(josh): rename to build and probe terminology, which should be easy now that the projection +// building has been moved out of here private[joins] class SortMergeJoinScanner( leftKeyGenerator: Projection, rightKeyGenerator: Projection, @@ -117,9 +119,9 @@ private[joins] class SortMergeJoinScanner( private[this] var rightRow: InternalRow = _ private[this] var rightJoinKey: InternalRow = _ /** The join key for the rows buffered in `rightMatches`, or null if `rightMatches` is empty */ - private[this] var matchedJoinKey: InternalRow = _ - /** Buffered rows from the right side of the join. This is never null. */ - private[this] var rightMatches: CompactBuffer[InternalRow] = new CompactBuffer[InternalRow]() + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the right side of the join. This is null if there are no matches */ + private[this] var rightMatches: CompactBuffer[InternalRow] = _ // Initialization (note: do _not_ want to advance left here). advanceRight() @@ -136,7 +138,7 @@ private[joins] class SortMergeJoinScanner( if (leftRow == null) { // We have consumed the entire left iterator, so there can be no more matches. false - } else if (matchedJoinKey != null && keyOrdering.compare(leftJoinKey, matchedJoinKey) == 0) { + } else if (matchJoinKey != null && keyOrdering.compare(leftJoinKey, matchJoinKey) == 0) { // The new left row has the same join key as the previous row, so return the same matches. true } else if (rightRow == null) { @@ -164,46 +166,106 @@ private[joins] class SortMergeJoinScanner( // The left and right rows have matching join keys, so scan through the right iterator to // buffer all matching rows. assert(comp == 0) - matchedJoinKey = leftJoinKey - rightMatches = new CompactBuffer[InternalRow]() - do { - // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match - rightMatches += rightRow.copy() // need to copy mutable rows before buffering them - advanceRight() - } while (rightRow != null && keyOrdering.compare(leftJoinKey, rightJoinKey) == 0) + bufferMatchingRightRows() true } } } - def getRightMatches: CompactBuffer[InternalRow] = rightMatches + /** + * Advances the left input iterator and buffers all rows from the right input with matching keys. + * @return true if the left iterator returned a row, false otherwise. If this returns true, then + * [[getLeftRow]] and [[getRightMatches]] can be called to produce the outer join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (advanceLeft()) { + if (leftJoinKey.anyNull) { + // Since at least one join column is null, the left row has no matches. + matchJoinKey = null + rightMatches = null + } else if (matchJoinKey != null && keyOrdering.compare(leftJoinKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The left row does not match the current group. + matchJoinKey = null + rightMatches = null + if (rightRow != null) { + // The right iterator could still contain matching rows, so we'll need to scan through it + // until we either find matches or pass where they would be found. 
+ var comp = if (rightJoinKey.anyNull) 1 else keyOrdering.compare(leftJoinKey, rightJoinKey) + while (comp > 0 && advanceRight()) { + comp = if (rightJoinKey.anyNull) 1 else keyOrdering.compare(leftJoinKey, rightJoinKey) + } + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingRightRows() + } else { + // We have overshot the position where the row would be found, hence no matches. + } + } + } + // If there is a left input, then we always return true since outer join always returns a row. + true + } else { + // End of left input, hence no more results. + false + } + } + def getLeftRow: InternalRow = leftRow + def getRightMatches: CompactBuffer[InternalRow] = rightMatches // --- Private methods -------------------------------------------------------------------------- /** * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. */ - private def advanceLeft(): Unit = { + private def advanceLeft(): Boolean = { if (leftIter.hasNext) { leftRow = leftIter.next() leftJoinKey = leftKeyGenerator(leftRow) + true } else { leftRow = null leftJoinKey = null + false } } /** * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. */ - private def advanceRight(): Unit = { + private def advanceRight(): Boolean = { if (rightIter.hasNext) { rightRow = rightIter.next() rightJoinKey = rightKeyGenerator(rightRow) + true } else { rightRow = null rightJoinKey = null + false } } + + /** Called when the left and right join keys match in order to buffer the matching right rows. */ + private def bufferMatchingRightRows(): Unit = { + assert(leftJoinKey != null) + assert(!leftJoinKey.anyNull) + assert(rightJoinKey != null) + assert(!rightJoinKey.anyNull) + assert(keyOrdering.compare(leftJoinKey, rightJoinKey) == 0) + matchJoinKey = leftJoinKey.copy() + rightMatches = new CompactBuffer[InternalRow] + do { + // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match + rightMatches += rightRow.copy() // need to copy mutable rows before buffering them + advanceRight() + } while ( + rightRow != null && + !rightJoinKey.anyNull && + keyOrdering.compare(leftJoinKey, rightJoinKey) == 0 + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 279a121c498cf..6962d351c9a69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -63,12 +63,20 @@ case class SortMergeOuterJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => joinType match { case LeftOuter => - // TODO(josh): for SMJ we would buffer keys here: - val hashed = HashedRelation(rightIter, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - leftIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - leftOuterIterator(rowKey, joinedRow.withLeft(currentRow), hashed.get(rowKey)) + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator, + buildKeyGenerator, + keyOrdering, + leftIter, + rightIter // TODO(josh): streamed vs. right/left terminology; may be more explicit to + // just call these arguments with name = value syntax and continue to use + // left and right terminology here. 
+ ) + // TODO(josh): this is a little terse and needs explanation: + Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => + leftOuterIterator( + joinedRow.withLeft(smjScanner.getLeftRow), + smjScanner.getRightMatches) } case RightOuter => From 8dade551a41493d79c784c54ffd50ec9f72ce809 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 02:05:13 -0700 Subject: [PATCH 17/56] Also enable for right outer join. --- .../joins/BroadcastHashOuterJoin.scala | 3 +- .../spark/sql/execution/joins/OuterJoin.scala | 25 +-- .../joins/ShuffledHashOuterJoin.scala | 3 +- .../sql/execution/joins/SortMergeJoin.scala | 189 +++++++++--------- .../execution/joins/SortMergeOuterJoin.scala | 28 +-- 5 files changed, 131 insertions(+), 117 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 16043396a28c4..52e7f81864f51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -109,7 +109,8 @@ case class BroadcastHashOuterJoin( case RightOuter => streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow.withRight(currentRow)) + val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) + rightOuterIterator(matches, joinedRow.withRight(currentRow)) } case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala index 8674bbb9e7c13..9fa63bc47732c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala @@ -78,6 +78,8 @@ trait OuterJoin { } protected[this] def isUnsafeMode: Boolean = { + // TODO(josh): there is an existing bug here: this should also check whether unsafe mode + // is enabled. also, the default for self.codegenEnabled looks inconsistent to me. 
(self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) @@ -142,26 +144,21 @@ trait OuterJoin { } protected[this] def rightOuterIterator( - key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - resultProjection(joinedRow).copy() - } - } else { - List.empty - } - if (temp.isEmpty) { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil - } else { - temp + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => + resultProjection(joinedRow).copy() } } else { + List.empty + } + if (temp.isEmpty) { resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + } else { + temp } } ret.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 119318e879cf6..341405c62bfda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -63,7 +63,8 @@ case class ShuffledHashOuterJoin( val keyGenerator = streamedKeyGenerator rightIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow.withRight(currentRow)) + val matches = if (rowKey.anyNull) null else hashed.get(rowKey) + rightOuterIterator(matches, joinedRow.withRight(currentRow)) } case FullOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 985a2b05845a3..ec40100860cb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -78,8 +78,8 @@ case class SortMergeJoin( private[this] def fetchNext(): Boolean = { if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getRightMatches - currentLeftRow = smjScanner.getLeftRow + currentRightMatches = smjScanner.getBuildMatches + currentLeftRow = smjScanner.getStreamedRow currentMatchIdx = 0 true } else { @@ -105,167 +105,178 @@ case class SortMergeJoin( /** * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * The streamed input is the left side of a left outer join or the right side of a right outer join. 
+ * + * // todo(josh): scaladoc + * @param streamedKeyGenerator + * @param buildKeyGenerator + * @param keyOrdering + * @param streamedIter + * @param buildIter */ -// TODO(josh): rename to build and probe terminology, which should be easy now that the projection -// building has been moved out of here private[joins] class SortMergeJoinScanner( - leftKeyGenerator: Projection, - rightKeyGenerator: Projection, + streamedKeyGenerator: Projection, + buildKeyGenerator: Projection, keyOrdering: RowOrdering, - leftIter: Iterator[InternalRow], - rightIter: Iterator[InternalRow]) { - private[this] var leftRow: InternalRow = _ - private[this] var leftJoinKey: InternalRow = _ - private[this] var rightRow: InternalRow = _ - private[this] var rightJoinKey: InternalRow = _ - /** The join key for the rows buffered in `rightMatches`, or null if `rightMatches` is empty */ + streamedIter: Iterator[InternalRow], + buildIter: Iterator[InternalRow]) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var buildRow: InternalRow = _ + private[this] var buildRowKey: InternalRow = _ + /** The join key for the rows buffered in `buildMatches`, or null if `buildMatches` is empty */ private[this] var matchJoinKey: InternalRow = _ - /** Buffered rows from the right side of the join. This is null if there are no matches */ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ + /** Buffered rows from the build side of the join. This is null if there are no matches */ + private[this] var buildMatches: CompactBuffer[InternalRow] = _ - // Initialization (note: do _not_ want to advance left here). - advanceRight() + // Initialization (note: do _not_ want to advance streamed here). + advanceBuild() // --- Public methods --------------------------------------------------------------------------- /** * Advances both input iterators, stopping when we have found rows with matching join keys. * @return true if matching rows have been found and false otherwise. If this returns true, then - * [[getLeftRow]] and [[getRightMatches]] can be called to produce the join results. + * [[getStreamedRow]] and [[getBuildMatches]] can be called to produce the join results. */ final def findNextInnerJoinRows(): Boolean = { - advanceLeft() - if (leftRow == null) { - // We have consumed the entire left iterator, so there can be no more matches. + advancedStreamed() + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. false - } else if (matchJoinKey != null && keyOrdering.compare(leftJoinKey, matchJoinKey) == 0) { - // The new left row has the same join key as the previous row, so return the same matches. + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. true - } else if (rightRow == null) { - // The left row's join key does not match the current batch of right rows and there are no - // more rows to read from the right iterator, so there can be no more matches. + } else if (buildRow == null) { + // The streamed row's join key does not match the current batch of build rows and there are no + // more rows to read from the build iterator, so there can be no more matches. false } else { - // Advance both the left and right iterators to find the next pair of matching rows. + // Advance both the streamed and build iterators to find the next pair of matching rows. 
var comp = 0 do { - if (leftJoinKey.anyNull) { - advanceLeft() - } else if (rightJoinKey.anyNull) { - advanceRight() + if (streamedRowKey.anyNull) { + advancedStreamed() + } else if (buildRowKey.anyNull) { + advanceBuild() } else { - comp = keyOrdering.compare(leftJoinKey, rightJoinKey) - if (comp > 0) advanceRight() - else if (comp < 0) advanceLeft() + comp = keyOrdering.compare(streamedRowKey, buildRowKey) + if (comp > 0) advanceBuild() + else if (comp < 0) advancedStreamed() } - } while (leftRow != null && rightRow != null && comp != 0) - if (leftRow == null || rightRow == null) { + } while (streamedRow != null && buildRow != null && comp != 0) + if (streamedRow == null || buildRow == null) { // We have either hit the end of one of the iterators, so there can be no more matches. false } else { - // The left and right rows have matching join keys, so scan through the right iterator to - // buffer all matching rows. + // The streamed and build rows have matching join keys, so walk through the build iterator + // to buffer all matching rows. assert(comp == 0) - bufferMatchingRightRows() + bufferMatchingBuildRows() true } } } /** - * Advances the left input iterator and buffers all rows from the right input with matching keys. - * @return true if the left iterator returned a row, false otherwise. If this returns true, then - * [[getLeftRow]] and [[getRightMatches]] can be called to produce the outer join results. + * Advances the streamed input iterator and buffers all rows from the build input with matching + * keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [getStreamedRow and [[getBuildMatches]] can be called to produce the outer + * join results. */ final def findNextOuterJoinRows(): Boolean = { - if (advanceLeft()) { - if (leftJoinKey.anyNull) { - // Since at least one join column is null, the left row has no matches. + if (advancedStreamed()) { + if (streamedRowKey.anyNull) { + // Since at least one join column is null, the streamed row has no matches. matchJoinKey = null - rightMatches = null - } else if (matchJoinKey != null && keyOrdering.compare(leftJoinKey, matchJoinKey) == 0) { + buildMatches = null + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { // Matches the current group, so do nothing. } else { - // The left row does not match the current group. + // The streamed row does not match the current group. matchJoinKey = null - rightMatches = null - if (rightRow != null) { - // The right iterator could still contain matching rows, so we'll need to scan through it + buildMatches = null + if (buildRow != null) { + // The build iterator could still contain matching rows, so we'll need to walk through it // until we either find matches or pass where they would be found. - var comp = if (rightJoinKey.anyNull) 1 else keyOrdering.compare(leftJoinKey, rightJoinKey) - while (comp > 0 && advanceRight()) { - comp = if (rightJoinKey.anyNull) 1 else keyOrdering.compare(leftJoinKey, rightJoinKey) + var comp = if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) + while (comp > 0 && advanceBuild()) { + comp = if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) } if (comp == 0) { // We have found matches, so buffer them (this updates matchJoinKey) - bufferMatchingRightRows() + bufferMatchingBuildRows() } else { // We have overshot the position where the row would be found, hence no matches. 
} } } - // If there is a left input, then we always return true since outer join always returns a row. + // If there is a streamed input, then we always return true since outer join always returns a row. true } else { - // End of left input, hence no more results. + // End of streamed input, hence no more results. false } } - def getLeftRow: InternalRow = leftRow - def getRightMatches: CompactBuffer[InternalRow] = rightMatches + def getStreamedRow: InternalRow = streamedRow + def getBuildMatches: CompactBuffer[InternalRow] = buildMatches // --- Private methods -------------------------------------------------------------------------- /** - * Advance the left iterator and compute the new row's join key. - * @return true if the left iterator returned a row and false otherwise. + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. */ - private def advanceLeft(): Boolean = { - if (leftIter.hasNext) { - leftRow = leftIter.next() - leftJoinKey = leftKeyGenerator(leftRow) + private def advancedStreamed(): Boolean = { + if (streamedIter.hasNext) { + streamedRow = streamedIter.next() + streamedRowKey = streamedKeyGenerator(streamedRow) true } else { - leftRow = null - leftJoinKey = null + streamedRow = null + streamedRowKey = null false } } /** - * Advance the right iterator and compute the new row's join key. - * @return true if the right iterator returned a row and false otherwise. + * Advance the build iterator and compute the new row's join key. + * @return true if the build iterator returned a row and false otherwise. */ - private def advanceRight(): Boolean = { - if (rightIter.hasNext) { - rightRow = rightIter.next() - rightJoinKey = rightKeyGenerator(rightRow) + private def advanceBuild(): Boolean = { + if (buildIter.hasNext) { + buildRow = buildIter.next() + buildRowKey = buildKeyGenerator(buildRow) true } else { - rightRow = null - rightJoinKey = null + buildRow = null + buildRowKey = null false } } - /** Called when the left and right join keys match in order to buffer the matching right rows. */ - private def bufferMatchingRightRows(): Unit = { - assert(leftJoinKey != null) - assert(!leftJoinKey.anyNull) - assert(rightJoinKey != null) - assert(!rightJoinKey.anyNull) - assert(keyOrdering.compare(leftJoinKey, rightJoinKey) == 0) - matchJoinKey = leftJoinKey.copy() - rightMatches = new CompactBuffer[InternalRow] + /** + * Called when the streamed and build join keys match in order to buffer the matching build rows. 
+ */ + private def bufferMatchingBuildRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(buildRowKey != null) + assert(!buildRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, buildRowKey) == 0) + matchJoinKey = streamedRowKey.copy() + buildMatches = new CompactBuffer[InternalRow] do { // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match - rightMatches += rightRow.copy() // need to copy mutable rows before buffering them - advanceRight() + buildMatches += buildRow.copy() // need to copy mutable rows before buffering them + advanceBuild() } while ( - rightRow != null && - !rightJoinKey.anyNull && - keyOrdering.compare(leftJoinKey, rightJoinKey) == 0 + buildRow != null && + !buildRowKey.anyNull && + keyOrdering.compare(streamedRowKey, buildRowKey) == 0 ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 6962d351c9a69..4c957759e252d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -67,25 +67,29 @@ case class SortMergeOuterJoin( streamedKeyGenerator, buildKeyGenerator, keyOrdering, - leftIter, - rightIter // TODO(josh): streamed vs. right/left terminology; may be more explicit to - // just call these arguments with name = value syntax and continue to use - // left and right terminology here. + streamedIter = leftIter, + buildIter = rightIter ) // TODO(josh): this is a little terse and needs explanation: Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => leftOuterIterator( - joinedRow.withLeft(smjScanner.getLeftRow), - smjScanner.getRightMatches) + joinedRow.withLeft(smjScanner.getStreamedRow), + smjScanner.getBuildMatches) } case RightOuter => - // TODO(josh): for SMJ we would buffer keys here: - val hashed = HashedRelation(leftIter, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - rightIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow.withRight(currentRow)) + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator, + buildKeyGenerator, + keyOrdering, + streamedIter = rightIter, + buildIter = leftIter + ) + // TODO(josh): this is a little terse and needs explanation: + Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => + rightOuterIterator( + smjScanner.getBuildMatches, + joinedRow.withRight(smjScanner.getStreamedRow)) } case FullOuter => From 8e496b25f1c7dfceb09e3d2187d7b30d847d9ec1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 02:58:30 -0700 Subject: [PATCH 18/56] Fix scalastyle --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ec40100860cb0..3d9008418587e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -201,7 +201,8 @@ private[joins] class SortMergeJoinScanner( if (buildRow != null) { // The build iterator could still contain matching rows, 
so we'll need to walk through it // until we either find matches or pass where they would be found. - var comp = if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) + var comp = + if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) while (comp > 0 && advanceBuild()) { comp = if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) } @@ -213,7 +214,7 @@ private[joins] class SortMergeJoinScanner( } } } - // If there is a streamed input, then we always return true since outer join always returns a row. + // If there is a streamed input, then we always return true true } else { // End of streamed input, hence no more results. From 6587ef2d9f5e21686fe74a2c2e35afb24dae44a0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 14:25:51 -0700 Subject: [PATCH 19/56] Rewrite OuterJoinSuite in preparation for adding more tests. --- .../sql/execution/joins/OuterJoinSuite.scala | 128 +++++++++++------- 1 file changed, 82 insertions(+), 46 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4f..53d30546ac96a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,14 +17,49 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} class OuterJoinSuite extends SparkPlanTest { + private def testOuterJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + expectedAnswer: Seq[Product]): Unit = { + // Precondition: leftRows and rightRows should be sorted according to the join keys. 
+ + test(s"$testName with ShuffledHashOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + + if (joinType != FullOuter) { + test(s"$testName with BroadcastHashOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + + test(s"$testName with SortMergeOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + SortMergeOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + val left = Seq( (1, 2.0), (2, 1.0), @@ -41,49 +76,50 @@ class OuterJoinSuite extends SparkPlanTest { val rightKeys: List[Expression] = 'c :: Nil val condition = Some(LessThan('b, 'd)) - test("shuffled hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) + testOuterJoin( + "basic left outer join", + left, + right, + leftKeys, + rightKeys, + LeftOuter, + condition, + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + ) + ) - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) + testOuterJoin( + "basic right outer join", + left, + right, + leftKeys, + rightKeys, + RightOuter, + condition, + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + ) + ) - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } - - test("broadcast hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } + testOuterJoin( + "basic full outer join", + left, + right, + leftKeys, + rightKeys, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + ) + ) } From 681e87980a88ec93c56095f53a69836751b51373 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 17:01:10 -0700 Subject: [PATCH 20/56] Add tests for outer joins with both inputs empty --- .../sql/execution/joins/OuterJoinSuite.scala | 43 +++++++++++++++++-- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 
53d30546ac96a..5d8e35baec48a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -36,7 +36,7 @@ class OuterJoinSuite extends SparkPlanTest { expectedAnswer: Seq[Product]): Unit = { // Precondition: leftRows and rightRows should be sorted according to the join keys. - test(s"$testName with ShuffledHashOuterJoin") { + test(s"$testName using ShuffledHashOuterJoin") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), expectedAnswer.map(Row.fromTuple), @@ -44,7 +44,7 @@ class OuterJoinSuite extends SparkPlanTest { } if (joinType != FullOuter) { - test(s"$testName with BroadcastHashOuterJoin") { + test(s"$testName using BroadcastHashOuterJoin") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), expectedAnswer.map(Row.fromTuple), @@ -52,7 +52,7 @@ class OuterJoinSuite extends SparkPlanTest { } } - test(s"$testName with SortMergeOuterJoin") { + test(s"$testName using SortMergeOuterJoin") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => SortMergeOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), expectedAnswer.map(Row.fromTuple), @@ -76,6 +76,8 @@ class OuterJoinSuite extends SparkPlanTest { val rightKeys: List[Expression] = 'c :: Nil val condition = Some(LessThan('b, 'd)) + // --- Basic outer joins ------------------------------------------------------------------------ + testOuterJoin( "basic left outer join", left, @@ -122,4 +124,39 @@ class OuterJoinSuite extends SparkPlanTest { (null, null, 4, 1.0) ) ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + leftKeys, + rightKeys, + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + leftKeys, + rightKeys, + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + leftKeys, + rightKeys, + FullOuter, + condition, + Seq.empty + ) } From 37725050e272fc47559a3c96b6cff72fb5ad1a8d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 18:09:22 -0700 Subject: [PATCH 21/56] Fix two minor bugs in SMJ (regression tests pending) --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../apache/spark/sql/execution/joins/SortMergeJoin.scala | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f836122b3e0e4..5ec1227a1afa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -478,7 +478,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def unsafeEnabled: Boolean = false private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index dff780fb98d96..8a107b6a242d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -142,7 +142,10 @@ private[joins] class SortMergeJoinScanner( * [[getStreamedRow]] and [[getBuildMatches]] can be called to produce the join results. */ final def findNextInnerJoinRows(): Boolean = { - advancedStreamed() + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } if (streamedRow == null) { // We have consumed the entire streamed iterator, so there can be no more matches. false @@ -155,7 +158,7 @@ private[joins] class SortMergeJoinScanner( false } else { // Advance both the streamed and build iterators to find the next pair of matching rows. - var comp = 0 + var comp = keyOrdering.compare(streamedRowKey, buildRowKey) do { if (streamedRowKey.anyNull) { advancedStreamed() From 82632c8953384325efaffa3b476177147ca5565f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 17:47:48 -0700 Subject: [PATCH 22/56] Allow UnsafeRows to be processed in SortMergeJoin --- .../apache/spark/sql/execution/joins/OuterJoin.scala | 1 + .../spark/sql/execution/joins/SortMergeJoin.scala | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala index 5fcac3e9fea68..b73d8a5e6100e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala @@ -169,6 +169,7 @@ trait OuterJoin { protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { + // TODO(josh): why doesn't this use resultProjection? if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 8a107b6a242d1..1143a99ba2d49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -52,6 +52,17 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + protected[this] def isUnsafeMode: Boolean = { + // TODO(josh): there is an existing bug here: this should also check whether unsafe mode + // is enabled. also, the default for self.codegenEnabled looks inconsistent to me. 
+ codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(schema) + } + + // TODO(josh): this will need to change once we use an Unsafe row joiner + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) From 6e18bc3d3901cc28317d638f5fd621be4b627b05 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 19:24:37 -0700 Subject: [PATCH 23/56] Rename HashJoin to EquiJoinSelection --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/execution/SparkStrategies.scala | 4 ++-- sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4 ++-- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6f8ffb54402a7..555223c53e224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext) HashAggregation :: Aggregation :: LeftSemiJoin :: - HashJoin :: + EquiJoinSelection :: InMemoryScans :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d69555111a7ec..80932118ddf7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,7 +66,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // joins, such as SortMergeJoin. Maybe we could just name this something like JoinSelection. /** * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates - * can be evaluated by matching hash keys. + * can be evaluated by matching join keys. * * Join implementations are chosen with the following precedence: * @@ -81,7 +81,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * will be used. * - Hash: will be chosen if neither of the above optimizations apply to this join. 
*/ - object HashJoin extends Strategy with PredicateHelper { + object EquiJoinSelection extends Strategy with PredicateHelper { private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 13b0eea1e7ba2..870fb599cbaa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -38,7 +38,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -173,7 +173,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 567d7fa12ff14..f17177a771c3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -531,7 +531,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HashAggregation, Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin From 289e91da1ad9db0fc77723b0fc4fa33b99371980 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 19:26:53 -0700 Subject: [PATCH 24/56] Remove unnecessary requiredChildDistribution from BroadcastHashOuterJoin --- .../spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 0a562f13a6214..cebdf4912da4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -54,9 +54,6 @@ case class BroadcastHashOuterJoin( } } - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value From e3f6d71ef9a9e76d70ee62356027cde483dde184 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 19:41:32 -0700 Subject: [PATCH 25/56] Use ArrayBuffer instead of CompactBuffer --- .../sql/execution/joins/SortMergeJoin.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 1143a99ba2d49..737ed395a1a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer /** * :: DeveloperApi :: @@ -74,7 +75,7 @@ case class SortMergeJoin( // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: CompactBuffer[InternalRow] = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ private[this] var currentMatchIdx: Int = -1 private[this] val smjScanner = new SortMergeJoinScanner( leftKeyGenerator, @@ -139,14 +140,18 @@ private[joins] class SortMergeJoinScanner( private[this] var buildRowKey: InternalRow = _ /** The join key for the rows buffered in `buildMatches`, or null if `buildMatches` is empty */ private[this] var matchJoinKey: InternalRow = _ - /** Buffered rows from the build side of the join. This is null if there are no matches */ - private[this] var buildMatches: CompactBuffer[InternalRow] = _ + /** Buffered rows from the build side of the join. This is empty if there are no matches. */ + private[this] val buildMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] // Initialization (note: do _not_ want to advance streamed here). advanceBuild() // --- Public methods --------------------------------------------------------------------------- + def getStreamedRow: InternalRow = streamedRow + + def getBuildMatches: ArrayBuffer[InternalRow] = buildMatches + /** * Advances both input iterators, stopping when we have found rows with matching join keys. * @return true if matching rows have been found and false otherwise. If this returns true, then @@ -206,13 +211,13 @@ private[joins] class SortMergeJoinScanner( if (streamedRowKey.anyNull) { // Since at least one join column is null, the streamed row has no matches. matchJoinKey = null - buildMatches = null + buildMatches.clear() } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { // Matches the current group, so do nothing. } else { // The streamed row does not match the current group. matchJoinKey = null - buildMatches = null + buildMatches.clear() if (buildRow != null) { // The build iterator could still contain matching rows, so we'll need to walk through it // until we either find matches or pass where they would be found. 
@@ -237,8 +242,6 @@ private[joins] class SortMergeJoinScanner( } } - def getStreamedRow: InternalRow = streamedRow - def getBuildMatches: CompactBuffer[InternalRow] = buildMatches // --- Private methods -------------------------------------------------------------------------- @@ -284,7 +287,7 @@ private[joins] class SortMergeJoinScanner( assert(!buildRowKey.anyNull) assert(keyOrdering.compare(streamedRowKey, buildRowKey) == 0) matchJoinKey = streamedRowKey.copy() - buildMatches = new CompactBuffer[InternalRow] + buildMatches.clear() do { // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match buildMatches += buildRow.copy() // need to copy mutable rows before buffering them From 075f3722c20b0aaab0e210b05f7ac3334a9dab73 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 19:47:13 -0700 Subject: [PATCH 26/56] Add missing row key null checks in BroadcastHashOuterJoin --- .../sql/execution/joins/BroadcastHashOuterJoin.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index cebdf4912da4a..5ce51c3add222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -99,15 +99,15 @@ case class BroadcastHashOuterJoin( case LeftOuter => streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(joinedRow, hashTable.get(rowKey), resultProj) + val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) + leftOuterIterator(joinedRow.withLeft(currentRow), matches, resultProj) } case RightOuter => streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(hashTable.get(rowKey), joinedRow, resultProj) + val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) + rightOuterIterator(matches,joinedRow.withRight(currentRow), resultProj) } case x => From df250c8e50088325c5faeebc6fdf0748f101dcc7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 19:55:43 -0700 Subject: [PATCH 27/56] Update to reflect deferral of full outer join to followup patch --- .../spark/sql/execution/SparkStrategies.scala | 9 ++++-- .../execution/joins/SortMergeOuterJoin.scala | 29 ++++--------------- .../org/apache/spark/sql/JoinSuite.scala | 3 +- .../sql/execution/joins/OuterJoinSuite.scala | 12 ++++---- 4 files changed, 21 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 80932118ddf7f..d6291beb4edc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -134,10 +134,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => 
joins.SortMergeOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index aa08043948025..5d5dadfea3059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.execution.joins -import scala.collection.JavaConverters._ - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer /** * :: DeveloperApi :: * Performs an sort merge outer join of two child relations. + * + * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. */ @DeveloperApi case class SortMergeOuterJoin( @@ -45,8 +44,9 @@ case class SortMergeOuterJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { - case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered. 
- case _ => requiredOrders(leftKeys) + case LeftOuter | RightOuter => requiredOrders(leftKeys) + case x => throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") } override def requiredChildOrdering: Seq[Seq[SortOrder]] = @@ -97,23 +97,6 @@ case class SortMergeOuterJoin( resultProj) } - case FullOuter => - // TODO(josh): handle this case efficiently in SMJ - // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet.asScala ++ rightHashTable.keySet.asScala).iterator.flatMap { key => - val leftRows: CompactBuffer[InternalRow] = { - val rows = leftHashTable.get(key) - if (rows == null) EMPTY_LIST else rows - } - val rightRows: CompactBuffer[InternalRow] = { - val rows = rightHashTable.get(key) - if (rows == null) EMPTY_LIST else rows - } - fullOuterIterator(key, leftRows, rightRows, joinedRow) - } - case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 870fb599cbaa3..ae07eaf91c872 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -89,7 +89,8 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 5d8e35baec48a..27b185f700f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -50,13 +50,13 @@ class OuterJoinSuite extends SparkPlanTest { expectedAnswer.map(Row.fromTuple), sortAnswers = false) } - } - test(s"$testName using SortMergeOuterJoin") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - SortMergeOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) + test(s"$testName using SortMergeOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + SortMergeOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } } } From bdf513c037eb9e5c746b848da884991938926dd9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 23:11:42 -0700 Subject: [PATCH 28/56] Rename build to buffered --- .../sql/execution/joins/SortMergeJoin.scala | 124 +++++++++--------- .../execution/joins/SortMergeOuterJoin.scala | 8 +- 2 files changed, 69 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index a788ac4556477..ade5190493ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -90,7 +90,7 @@ case class SortMergeJoin( private[this] def fetchNext(): Boolean = { if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBuildMatches + currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow currentMatchIdx = 0 true @@ -122,39 +122,41 @@ case class SortMergeJoin( * * // todo(josh): scaladoc * @param streamedKeyGenerator - * @param buildKeyGenerator + * @param bufferedKeyGenerator * @param keyOrdering * @param streamedIter - * @param buildIter + * @param bufferedIter */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, - buildKeyGenerator: Projection, + bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: Iterator[InternalRow], - buildIter: Iterator[InternalRow]) { + bufferedIter: Iterator[InternalRow]) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ - private[this] var buildRow: InternalRow = _ - private[this] var buildRowKey: InternalRow = _ - /** The join key for the rows buffered in `buildMatches`, or null if `buildMatches` is empty */ + private[this] var bufferedRow: InternalRow = _ + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ private[this] var matchJoinKey: InternalRow = _ - /** Buffered rows from the build side of the join. This is empty if there are no matches. */ - private[this] val buildMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] // Initialization (note: do _not_ want to advance streamed here). - advanceBuild() + advancedBuffered() // --- Public methods --------------------------------------------------------------------------- def getStreamedRow: InternalRow = streamedRow - def getBuildMatches: ArrayBuffer[InternalRow] = buildMatches + def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. * @return true if matching rows have been found and false otherwise. If this returns true, then - * [[getStreamedRow]] and [[getBuildMatches]] can be called to produce the join results. + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the join results. */ final def findNextInnerJoinRows(): Boolean = { while (advancedStreamed() && streamedRowKey.anyNull) { @@ -167,42 +169,42 @@ private[joins] class SortMergeJoinScanner( } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { // The new streamed row has the same join key as the previous row, so return the same matches. true - } else if (buildRow == null) { - // The streamed row's join key does not match the current batch of build rows and there are no - // more rows to read from the build iterator, so there can be no more matches. 
+ } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. false } else { - // Advance both the streamed and build iterators to find the next pair of matching rows. - var comp = keyOrdering.compare(streamedRowKey, buildRowKey) + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) do { if (streamedRowKey.anyNull) { advancedStreamed() - } else if (buildRowKey.anyNull) { - advanceBuild() + } else if (bufferedRowKey.anyNull) { + advancedBuffered() } else { - comp = keyOrdering.compare(streamedRowKey, buildRowKey) - if (comp > 0) advanceBuild() + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBuffered() else if (comp < 0) advancedStreamed() } - } while (streamedRow != null && buildRow != null && comp != 0) - if (streamedRow == null || buildRow == null) { + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { // We have either hit the end of one of the iterators, so there can be no more matches. false } else { - // The streamed and build rows have matching join keys, so walk through the build iterator - // to buffer all matching rows. + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. assert(comp == 0) - bufferMatchingBuildRows() + bufferMatchingRows() true } } } /** - * Advances the streamed input iterator and buffers all rows from the build input with matching - * keys. + * Advances the streamed input iterator and buffers all rows from the buffered input that + * have matching keys. * @return true if the streamed iterator returned a row, false otherwise. If this returns true, - * then [getStreamedRow and [[getBuildMatches]] can be called to produce the outer + * then [getStreamedRow and [[getBufferedMatches]] can be called to produce the outer * join results. */ final def findNextOuterJoinRows(): Boolean = { @@ -210,24 +212,28 @@ private[joins] class SortMergeJoinScanner( if (streamedRowKey.anyNull) { // Since at least one join column is null, the streamed row has no matches. matchJoinKey = null - buildMatches.clear() + bufferedMatches.clear() } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { // Matches the current group, so do nothing. } else { // The streamed row does not match the current group. matchJoinKey = null - buildMatches.clear() - if (buildRow != null) { - // The build iterator could still contain matching rows, so we'll need to walk through it - // until we either find matches or pass where they would be found. + bufferedMatches.clear() + if (bufferedRow != null) { + // The buffered iterator could still contain matching rows, so we'll need to walk through + // it until we either find matches or pass where they would be found. 
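// Illustrative sketch (not part of this patch): the two-pointer merge that findNextInnerJoinRows()
// performs above, reduced to plain sorted Int keys. The side whose current key compares smaller is
// advanced until both keys are equal, exactly as the keyOrdering.compare-driven loop does for
// InternalRow keys. All names here are hypothetical.
object MergeScanSketch {
  def matchingKeys(streamed: Iterator[Int], buffered: Iterator[Int]): Seq[Int] = {
    val out = scala.collection.mutable.ArrayBuffer.empty[Int]
    val s = streamed.buffered
    val b = buffered.buffered
    while (s.hasNext && b.hasNext) {
      val comp = s.head.compare(b.head)
      if (comp < 0) s.next()                // streamed key is smaller: advance the streamed side
      else if (comp > 0) b.next()           // buffered key is smaller: advance the buffered side
      else { out += s.next(); b.next() }    // keys match: record the key and advance both sides
    }
    out
  }
}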
var comp = - if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) - while (comp > 0 && advanceBuild()) { - comp = if (buildRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, buildRowKey) + if (bufferedRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, bufferedRowKey) + while (comp > 0 && advancedBuffered()) { + comp = if (bufferedRowKey.anyNull) { + 1 + } else { + keyOrdering.compare(streamedRowKey, bufferedRowKey) + } } if (comp == 0) { // We have found matches, so buffer them (this updates matchJoinKey) - bufferMatchingBuildRows() + bufferMatchingRows() } else { // We have overshot the position where the row would be found, hence no matches. } @@ -261,40 +267,40 @@ private[joins] class SortMergeJoinScanner( } /** - * Advance the build iterator and compute the new row's join key. - * @return true if the build iterator returned a row and false otherwise. + * Advance the buffered iterator and compute the new row's join key. + * @return true if the buffered iterator returned a row and false otherwise. */ - private def advanceBuild(): Boolean = { - if (buildIter.hasNext) { - buildRow = buildIter.next() - buildRowKey = buildKeyGenerator(buildRow) + private def advancedBuffered(): Boolean = { + if (bufferedIter.hasNext) { + bufferedRow = bufferedIter.next() + bufferedRowKey = bufferedKeyGenerator(bufferedRow) true } else { - buildRow = null - buildRowKey = null + bufferedRow = null + bufferedRowKey = null false } } /** - * Called when the streamed and build join keys match in order to buffer the matching build rows. + * Called when the streamed and buffered join keys match in order to buffer the matching rows. */ - private def bufferMatchingBuildRows(): Unit = { + private def bufferMatchingRows(): Unit = { assert(streamedRowKey != null) assert(!streamedRowKey.anyNull) - assert(buildRowKey != null) - assert(!buildRowKey.anyNull) - assert(keyOrdering.compare(streamedRowKey, buildRowKey) == 0) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: matchJoinKey = streamedRowKey.copy() - buildMatches.clear() + bufferedMatches.clear() do { - // TODO(josh): could maybe avoid a copy for case where all rows have exactly one match - buildMatches += buildRow.copy() // need to copy mutable rows before buffering them - advanceBuild() + bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + advancedBuffered() } while ( - buildRow != null && - !buildRowKey.anyNull && - keyOrdering.compare(streamedRowKey, buildRowKey) == 0 + bufferedRow != null && + !bufferedRowKey.anyNull && + keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 5d5dadfea3059..7fdf4907ef0d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -70,13 +70,13 @@ case class SortMergeOuterJoin( buildKeyGenerator, keyOrdering, streamedIter = leftIter, - buildIter = rightIter + bufferedIter = rightIter ) // TODO(josh): this is a little terse and needs explanation: Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => 
leftOuterIterator( joinedRow.withLeft(smjScanner.getStreamedRow), - smjScanner.getBuildMatches, + smjScanner.getBufferedMatches, resultProj) } @@ -87,12 +87,12 @@ case class SortMergeOuterJoin( buildKeyGenerator, keyOrdering, streamedIter = rightIter, - buildIter = leftIter + bufferedIter = leftIter ) // TODO(josh): this is a little terse and needs explanation: Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => rightOuterIterator( - smjScanner.getBuildMatches, + smjScanner.getBufferedMatches, joinedRow.withRight(smjScanner.getStreamedRow), resultProj) } From 1d8a48cb11ff101cf45078f1a1536f63e9225651 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 23:22:47 -0700 Subject: [PATCH 29/56] Update SortMergeJoin to output UnsafeRow in Unsafe mode --- .../spark/sql/execution/SparkStrategies.scala | 2 -- .../spark/sql/execution/joins/OuterJoin.scala | 1 - .../sql/execution/joins/SortMergeJoin.scala | 21 +++++++++++++------ 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d6291beb4edc8..4c3e6db3a4027 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,8 +62,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // TODO(josh): this class's name is slightly misleading in the sense that it also plans non-hash - // joins, such as SortMergeJoin. Maybe we could just name this something like JoinSelection. /** * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates * can be evaluated by matching join keys. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala index e19b57c8633c2..9869f7e36e3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala @@ -167,7 +167,6 @@ trait OuterJoin { protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - // TODO(josh): why doesn't this use resultProjection? if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ade5190493ac0..8a8c0227faf48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -54,15 +54,16 @@ case class SortMergeJoin( @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) protected[this] def isUnsafeMode: Boolean = { - codegenEnabled && unsafeEnabled && - UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(schema) + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(schema)) } - // TODO(josh): this will need to change once we use an Unsafe row joiner - override def outputsUnsafeRows: Boolean = false + override def outputsUnsafeRows: Boolean = isUnsafeMode override def canProcessUnsafeRows: Boolean = isUnsafeMode override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) @@ -84,6 +85,13 @@ case class SortMergeJoin( rightIter ) private[this] val joinRow = new JoinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } override final def hasNext: Boolean = (currentMatchIdx != -1 && currentMatchIdx < currentRightMatches.length) || fetchNext() @@ -108,7 +116,7 @@ case class SortMergeJoin( } val joinedRow = joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) currentMatchIdx += 1 - joinedRow + resultProjection(joinedRow) } } } @@ -156,7 +164,8 @@ private[joins] class SortMergeJoinScanner( /** * Advances both input iterators, stopping when we have found rows with matching join keys. * @return true if matching rows have been found and false otherwise. If this returns true, then - * [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the join results. + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. 
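// Illustrative sketch (not part of this patch) of the caller contract documented above: each
// successful findNextInnerJoinRows() call exposes one streamed row and its buffered matches, and
// both are mutable objects reused across calls, so results must be copied before the next call.
// The helper name and the copy-everything strategy are hypothetical.
def drainInnerJoin(scanner: SortMergeJoinScanner, joinRow: JoinedRow): Seq[InternalRow] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[InternalRow]
  while (scanner.findNextInnerJoinRows()) {
    val streamed = scanner.getStreamedRow
    scanner.getBufferedMatches.foreach { buffered =>
      out += joinRow(streamed, buffered).copy() // copy because the scanner reuses these rows
    }
  }
  out
}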
*/ final def findNextInnerJoinRows(): Boolean = { while (advancedStreamed() && streamedRowKey.anyNull) { From 82b7e452964d10284a0639e2027abe750dd966c0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 23:34:32 -0700 Subject: [PATCH 30/56] Try to clean up confusingly dense one-liner --- .../execution/joins/SortMergeOuterJoin.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 7fdf4907ef0d7..25c1f6081e99d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -72,13 +72,14 @@ case class SortMergeOuterJoin( streamedIter = leftIter, bufferedIter = rightIter ) - // TODO(josh): this is a little terse and needs explanation: - Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => - leftOuterIterator( + for ( + hasMoreStreamedRows <- Iterator.continually(smjScanner.findNextOuterJoinRows()) + if hasMoreStreamedRows; + result <- leftOuterIterator( joinedRow.withLeft(smjScanner.getStreamedRow), smjScanner.getBufferedMatches, resultProj) - } + ) yield result case RightOuter => val resultProj = createResultProjection() @@ -89,13 +90,14 @@ case class SortMergeOuterJoin( streamedIter = rightIter, bufferedIter = leftIter ) - // TODO(josh): this is a little terse and needs explanation: - Iterator.continually(0).takeWhile(_ => smjScanner.findNextOuterJoinRows()).flatMap { _ => - rightOuterIterator( + for ( + hasMoreStreamedRows <- Iterator.continually(smjScanner.findNextOuterJoinRows()) + if hasMoreStreamedRows; + result <- rightOuterIterator( smjScanner.getBufferedMatches, joinedRow.withRight(smjScanner.getStreamedRow), resultProj) - } + ) yield result case x => throw new IllegalArgumentException( From 4a4590f12671c5d7651592664c2989df0d1d12b0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 6 Aug 2015 23:55:01 -0700 Subject: [PATCH 31/56] Commment update --- .../joins/BroadcastHashOuterJoin.scala | 2 +- .../sql/execution/joins/SortMergeJoin.scala | 20 +++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 5ce51c3add222..a71c43f1bb747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -107,7 +107,7 @@ case class BroadcastHashOuterJoin( streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) - rightOuterIterator(matches,joinedRow.withRight(currentRow), resultProj) + rightOuterIterator(matches, joinedRow.withRight(currentRow), resultProj) } case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 8a8c0227faf48..715d2f22b68c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -128,12 +128,20 @@ case 
class SortMergeJoin( * * The streamed input is the left side of a left outer join or the right side of a right outer join. * - * // todo(josh): scaladoc - * @param streamedKeyGenerator - * @param bufferedKeyGenerator - * @param keyOrdering - * @param streamedIter - * @param bufferedIter + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence). For efficiency, both of these methods return mutable objects which are + * re-used across calls to the `findNext*JoinRows()` methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, From 93723e2726ebd2eaeada11c64ca84d10aa0e5e57 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 12:39:38 -0700 Subject: [PATCH 32/56] Experiment towards using efficient internal iterators. --- .../spark/sql/execution/joins/OuterJoin.scala | 6 +- .../sql/execution/joins/RowIterator.scala | 39 +++++++ .../execution/joins/SortMergeOuterJoin.scala | 108 +++++++++++++++--- 3 files changed, 133 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala index 9869f7e36e3a5..3e187d19f9cbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala @@ -113,9 +113,9 @@ trait OuterJoin { @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient private[this] lazy val boundCondition = + @transient protected[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient protected[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) + @transient protected[this] lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala new file mode 100644 index 0000000000000..25c46c11e442a --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow + +private[sql] abstract class RowIterator { + def advanceNext(): Boolean + def getNext: InternalRow + def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) +} + +private final class RowIteratorToScala(rowIter: RowIterator) extends Iterator[InternalRow] { + private [this] var _hasNext: Boolean = rowIter.advanceNext() + override def hasNext: Boolean = _hasNext + override def next(): InternalRow = { + if (!_hasNext) throw new NoSuchElementException + val row: InternalRow = rowIter.getNext.copy() + _hasNext = rowIter.advanceNext() + row + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 25c1f6081e99d..ff6bc2208c219 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -58,7 +58,6 @@ case class SortMergeOuterJoin( } protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // An ordering that can be used to compare keys from both sides. 
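// Illustrative sketch (not part of this patch): a trivial RowIterator over an in-memory sequence,
// showing the contract of the new abstraction above: advanceNext() moves to the next row and
// reports whether one is available, and getNext returns the current row without advancing.
// The class name is hypothetical.
class SeqRowIterator(rows: IndexedSeq[InternalRow]) extends RowIterator {
  private[this] var idx = -1
  override def advanceNext(): Boolean = {
    idx += 1
    idx < rows.length
  }
  override def getNext: InternalRow = rows(idx)
}
// It can then be consumed through the Scala adapter defined above:
//   new SeqRowIterator(someRows).toScala.foreach(row => ...)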
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -72,14 +71,7 @@ case class SortMergeOuterJoin( streamedIter = leftIter, bufferedIter = rightIter ) - for ( - hasMoreStreamedRows <- Iterator.continually(smjScanner.findNextOuterJoinRows()) - if hasMoreStreamedRows; - result <- leftOuterIterator( - joinedRow.withLeft(smjScanner.getStreamedRow), - smjScanner.getBufferedMatches, - resultProj) - ) yield result + new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala case RightOuter => val resultProj = createResultProjection() @@ -90,14 +82,7 @@ case class SortMergeOuterJoin( streamedIter = rightIter, bufferedIter = leftIter ) - for ( - hasMoreStreamedRows <- Iterator.continually(smjScanner.findNextOuterJoinRows()) - if hasMoreStreamedRows; - result <- rightOuterIterator( - smjScanner.getBufferedMatches, - joinedRow.withRight(smjScanner.getStreamedRow), - resultProj) - ) yield result + new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala case x => throw new IllegalArgumentException( @@ -106,3 +91,92 @@ case class SortMergeOuterJoin( } } } + + +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var rightIdx: Int = 0 + + private def advanceLeft(): Boolean = { + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withLeft(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching right rows, so return nulls for the right row + joinedRow.withRight(rightNullRow) + } else { + // Find the next row from the right input that satisfied the bound condition + if (!advanceRightUntilBoundConditionSatisfied()) { + joinedRow.withRight(rightNullRow) + } + } + true + } else { + // Left input has been exhausted + false + } + } + + private def advanceRightUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + if (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) + rightIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceRightUntilBoundConditionSatisfied() || advanceLeft() + } + + override def getNext: InternalRow = resultProj(joinedRow) +} + +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var rightIdx: Int = 0 + + private def advanceRight(): Boolean = { + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withRight(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching left rows, so return nulls for the left row + joinedRow.withLeft(leftNullRow) + } else { + // Find the next row from the left input that satisfied the bound condition + if (!advanceLeftUntilBoundConditionSatisfied()) { + joinedRow.withLeft(leftNullRow) + } + } + true + } else { + // Right input has been exhausted + false + } + } + + private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + if (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + foundMatch = 
boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(rightIdx))) + rightIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceLeftUntilBoundConditionSatisfied() || advanceRight() + } + + override def getNext: InternalRow = resultProj(joinedRow) +} From 441b89a1ec9cce2942d992769bba867e4ce89776 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 14:12:30 -0700 Subject: [PATCH 33/56] Back out now-unnecessary changes to other OuterJoin operators. --- .../joins/BroadcastHashOuterJoin.scala | 241 +++++++++--------- .../{OuterJoin.scala => HashOuterJoin.scala} | 72 +++--- .../joins/ShuffledHashOuterJoin.scala | 183 +++++++------ .../execution/joins/SortMergeOuterJoin.scala | 78 +++++- 4 files changed, 312 insertions(+), 262 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/joins/{OuterJoin.scala => HashOuterJoin.scala} (80%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index a71c43f1bb747..a3626de49aeab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -1,119 +1,122 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} -import org.apache.spark.{InternalAccumulator, TaskContext} - -/** - * :: DeveloperApi :: - * Performs a outer hash join for two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. 
- */ -@DeveloperApi -case class BroadcastHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with OuterJoin { - - val timeout = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture - } - - override def doExecute(): RDD[InternalRow] = { - val broadcastRelation = Await.result(broadcastFuture, timeout) - - streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - val keyGenerator = streamedKeyGenerator - - hashTable match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) - case _ => - } - - val resultProj = createResultProjection() - joinType match { - case LeftOuter => - streamedIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) - leftOuterIterator(joinedRow.withLeft(currentRow), matches, resultProj) - } - - case RightOuter => - streamedIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - val matches = if (rowKey.anyNull) null else hashTable.get(rowKey) - rightOuterIterator(matches, joinedRow.withRight(currentRow), resultProj) - } - - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.concurrent._ +import scala.concurrent.duration._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.{InternalAccumulator, TaskContext} + +/** + * :: DeveloperApi :: + * Performs a outer hash join for two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + val timeout = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + override def requiredChildDistribution: Seq[Distribution] = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value + // for the same query. + @transient + private lazy val broadcastFuture = { + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. 
+ SQLExecution.withExecutionId(sparkContext, executionId) { + // Note that we use .execute().collect() because we don't want to convert data to Scala + // types + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() + val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) + sparkContext.broadcast(hashed) + } + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) + } + + protected override def doPrepare(): Unit = { + broadcastFuture + } + + override def doExecute(): RDD[InternalRow] = { + val broadcastRelation = Await.result(broadcastFuture, timeout) + + streamedPlan.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + val keyGenerator = streamedKeyGenerator + + hashTable match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + + val resultProj = resultProjection + joinType match { + case LeftOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala similarity index 80% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 3e187d19f9cbf..701bd3cd86372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/OuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,12 +23,11 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @DeveloperApi -trait OuterJoin { +trait HashOuterJoin { self: SparkPlan => val leftKeys: Seq[Expression] @@ -38,7 +37,7 @@ trait OuterJoin { val left: SparkPlan val right: SparkPlan - final override def output: Seq[Attribute] = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -47,26 +46,16 @@ trait OuterJoin { case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => - throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw 
new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - protected[this] lazy val (buildPlan, streamedPlan) = joinType match { case RightOuter => (left, right) case LeftOuter => (right, left) case x => throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") + s"HashOuterJoin should not take $x as the JoinType") } protected[this] lazy val (buildKeys, streamedKeys) = joinType match { @@ -74,7 +63,7 @@ trait OuterJoin { case LeftOuter => (rightKeys, leftKeys) case x => throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") + s"HashOuterJoin should not take $x as the JoinType") } protected[this] def isUnsafeMode: Boolean = { @@ -102,7 +91,7 @@ trait OuterJoin { } } - protected[this] def createResultProjection(): InternalRow => InternalRow = { + protected[this] def resultProjection: InternalRow => InternalRow = { if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { @@ -113,52 +102,61 @@ trait OuterJoin { @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - @transient protected[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient protected[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient protected[this] lazy val boundCondition = + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) + @transient private[this] lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. 
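// Illustrative sketch (not part of this patch), in plain collections, of the semantics that the
// leftOuterIterator / rightOuterIterator helpers below encode: for each streamed row, emit one
// output per matching build-side row that passes the join condition, and fall back to a single
// null-padded row when there are none. The real helpers additionally short-circuit when the join
// key contains nulls. All names here are hypothetical.
def leftOuterSketch[L, R](
    streamed: Seq[L],
    matchesFor: L => Seq[R],
    condition: (L, R) => Boolean,
    nullRight: R): Seq[(L, R)] = {
  streamed.flatMap { l =>
    val matched = matchesFor(l).filter(r => condition(l, r)).map(r => (l, r))
    if (matched.isEmpty) Seq((l, nullRight)) else matched
  }
}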
protected[this] def leftOuterIterator( + key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow], resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + if (!key.anyNull) { + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + } + } else { + List.empty + } + if (temp.isEmpty) { + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } else { + temp } } else { - List.empty - } - if (temp.isEmpty) { resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp } } ret.iterator } protected[this] def rightOuterIterator( + key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow, resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - resultProjection(joinedRow).copy() + if (!key.anyNull) { + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() + } + } else { + List.empty + } + if (temp.isEmpty) { + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } else { + temp } } else { - List.empty - } - if (temp.isEmpty) { resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp } } ret.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 363194f0d0f74..df8d4ff227dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -1,94 +1,89 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.JavaConverters._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class ShuffledHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with OuterJoin { - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - joinType match { - case LeftOuter => - val hashed = HashedRelation(rightIter, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - val resultProj = createResultProjection() - leftIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - val matches = if (rowKey.anyNull) null else hashed.get(rowKey) - leftOuterIterator(joinedRow.withLeft(currentRow), matches, resultProj) - } - - case RightOuter => - val hashed = HashedRelation(leftIter, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - val resultProj = createResultProjection() - rightIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - val matches = if (rowKey.anyNull) null else hashed.get(rowKey) - rightOuterIterator(matches, joinedRow.withRight(currentRow), resultProj) - } - - case FullOuter => - // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet.asScala ++ rightHashTable.keySet.asScala).iterator.flatMap { key => - val leftRows: CompactBuffer[InternalRow] = { - val rows = leftHashTable.get(key) - if (rows == null) EMPTY_LIST else rows - } - val rightRows: CompactBuffer[InternalRow] = { - val rows = rightHashTable.get(key) - if (rows == null) EMPTY_LIST else rows - } - fullOuterIterator(key, leftRows, rightRows, joinedRow) - } - - case x => - throw new IllegalArgumentException( - s"ShuffledHashOuterJoin should not take $x as the JoinType") - } - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class ShuffledHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val joinedRow = new JoinedRow() + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + joinType match { + case LeftOuter => + val hashed = HashedRelation(rightIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator + val resultProj = resultProjection + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) + }) + + case RightOuter => + val hashed = HashedRelation(leftIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator + val resultProj = resultProjection + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) + }) + + case FullOuter => + // TODO(davies): use UnsafeRow + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), + joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"ShuffledHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index ff6bc2208c219..71b7be493c17c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -38,17 +38,38 @@ case class SortMergeOuterJoin( joinType: JoinType, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with OuterJoin { + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } override def outputOrdering: Seq[SortOrder] = joinType match { - case LeftOuter | RightOuter => requiredOrders(leftKeys) + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") } + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil @@ -57,31 +78,64 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - protected override def doExecute(): RDD[InternalRow] = { + private def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + private def createLeftKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + private def createRightKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + + override def doExecute(): RDD[InternalRow] = { left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // An ordering that can be used to compare keys from both sides. 
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + val resultProj: InternalRow => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + joinType match { case LeftOuter => - val resultProj = createResultProjection() val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator, - buildKeyGenerator, + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = leftIter, bufferedIter = rightIter ) + val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala case RightOuter => - val resultProj = createResultProjection() val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator, - buildKeyGenerator, + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = rightIter, bufferedIter = leftIter ) + val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala case x => From d16b60a662335695f6ae1a74ab63ff2427d5d1eb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 14:13:44 -0700 Subject: [PATCH 34/56] Revert another unnecessary change. --- .../spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index df8d4ff227dd2..6a8c35efca8f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -44,6 +44,14 @@ case class ShuffledHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") + } + protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => From 2c1253f61e3395b258d6b7aa17f020f7312859bc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 14:44:09 -0700 Subject: [PATCH 35/56] Add RowIterator.fromScala and use it to guarantee that copying is unnecessary. 
--- .../sql/execution/joins/RowIterator.scala | 30 +++++++++++++++++-- .../sql/execution/joins/SortMergeJoin.scala | 16 +++++----- .../execution/joins/SortMergeOuterJoin.scala | 14 ++++----- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala index 25c46c11e442a..8efc469bab0fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala @@ -23,17 +23,41 @@ import org.apache.spark.sql.catalyst.InternalRow private[sql] abstract class RowIterator { def advanceNext(): Boolean - def getNext: InternalRow + def getRow: InternalRow def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) } -private final class RowIteratorToScala(rowIter: RowIterator) extends Iterator[InternalRow] { +object RowIterator { + def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { + scalaIter match { + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case _ => new RowIteratorFromScala(scalaIter) + } + } +} + +private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { private [this] var _hasNext: Boolean = rowIter.advanceNext() override def hasNext: Boolean = _hasNext override def next(): InternalRow = { if (!_hasNext) throw new NoSuchElementException - val row: InternalRow = rowIter.getNext.copy() + val row: InternalRow = rowIter.getRow.copy() _hasNext = rowIter.advanceNext() row } } + +private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator { + private[this] var _next: InternalRow = null + override def advanceNext(): Boolean = { + if (scalaIter.hasNext) { + _next = scalaIter.next() + true + } else { + _next = null + false + } + } + override def getRow: InternalRow = _next + override def toScala: Iterator[InternalRow] = scalaIter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 715d2f22b68c5..a4cfac30a1945 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -81,8 +81,8 @@ case class SortMergeJoin( leftKeyGenerator, rightKeyGenerator, keyOrdering, - leftIter, - rightIter + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) ) private[this] val joinRow = new JoinedRow private[this] val resultProjection: (InternalRow) => InternalRow = { @@ -147,8 +147,8 @@ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], - streamedIter: Iterator[InternalRow], - bufferedIter: Iterator[InternalRow]) { + streamedIter: RowIterator, + bufferedIter: RowIterator) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -272,8 +272,8 @@ private[joins] class SortMergeJoinScanner( * @return true if the streamed iterator returned a row and false otherwise. 
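// Illustrative sketch (not part of this patch) of why fromScala pattern-matches on
// RowIteratorToScala above: converting a RowIterator to a Scala iterator and back should unwrap
// the adapter rather than stack a second one (which would also reintroduce per-row copies).
// The object name and the trivial empty iterator are hypothetical.
object RowIteratorRoundTripSketch {
  val base: RowIterator = new RowIterator {
    override def advanceNext(): Boolean = false
    override def getRow: InternalRow = null // never reached: advanceNext() is always false
  }
  // base.toScala wraps base in RowIteratorToScala; fromScala recognizes that wrapper and
  // returns the underlying RowIterator unchanged.
  val roundTripped: RowIterator = RowIterator.fromScala(base.toScala)
  assert(roundTripped eq base)
}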
*/ private def advancedStreamed(): Boolean = { - if (streamedIter.hasNext) { - streamedRow = streamedIter.next() + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) true } else { @@ -288,8 +288,8 @@ private[joins] class SortMergeJoinScanner( * @return true if the buffered iterator returned a row and false otherwise. */ private def advancedBuffered(): Boolean = { - if (bufferedIter.hasNext) { - bufferedRow = bufferedIter.next() + if (bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) true } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 71b7be493c17c..d0d075017c132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -121,8 +121,8 @@ case class SortMergeOuterJoin( streamedKeyGenerator = createLeftKeyGenerator(), bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, - streamedIter = leftIter, - bufferedIter = rightIter + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter) ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala @@ -132,8 +132,8 @@ case class SortMergeOuterJoin( streamedKeyGenerator = createRightKeyGenerator(), bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, - streamedIter = rightIter, - bufferedIter = leftIter + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter) ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala @@ -188,7 +188,7 @@ private class LeftOuterIterator( advanceRightUntilBoundConditionSatisfied() || advanceLeft() } - override def getNext: InternalRow = resultProj(joinedRow) + override def getRow: InternalRow = resultProj(joinedRow) } private class RightOuterIterator( @@ -232,5 +232,5 @@ private class RightOuterIterator( advanceLeftUntilBoundConditionSatisfied() || advanceRight() } - override def getNext: InternalRow = resultProj(joinedRow) + override def getRow: InternalRow = resultProj(joinedRow) } From 2b6845261c74836225e99d840c161dd2d431bbf7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 16:25:10 -0700 Subject: [PATCH 36/56] Override row format methods for SortMergeOuterJoin --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 1 - .../apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 
a4cfac30a1945..3ab58709b21b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -63,7 +63,6 @@ case class SortMergeJoin( override def canProcessUnsafeRows: Boolean = isUnsafeMode override def canProcessSafeRows: Boolean = !isUnsafeMode - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index d0d075017c132..94b9482a2044b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -85,6 +85,10 @@ case class SortMergeOuterJoin( && UnsafeProjection.canSupport(schema)) } + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def createLeftKeyGenerator(): Projection = { if (isUnsafeMode) { UnsafeProjection.create(leftKeys, left.output) From d41ac51efe8c6481293ec97bc913c290e91c1b01 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 18:31:57 -0700 Subject: [PATCH 37/56] Fix bug in advancing leftIdx/rightIdx. --- .../sql/execution/joins/SortMergeOuterJoin.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 94b9482a2044b..06c23f79b00c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -159,8 +159,10 @@ private class LeftOuterIterator( ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var rightIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) private def advanceLeft(): Boolean = { + rightIdx = 0 if (smjScanner.findNextOuterJoinRows()) { joinedRow.withLeft(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { @@ -202,9 +204,11 @@ private class RightOuterIterator( resultProj: InternalRow => InternalRow ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var rightIdx: Int = 0 + private[this] var leftIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) private def advanceRight(): Boolean = { + leftIdx = 0 if (smjScanner.findNextOuterJoinRows()) { joinedRow.withRight(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { @@ -225,9 +229,9 @@ private class RightOuterIterator( private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - if (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(rightIdx))) - rightIdx += 1 + if (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) + leftIdx += 1 } foundMatch } From 
9f48a5c532a7af8ce05e043bcf48ef2782ba02be Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 7 Aug 2015 18:32:23 -0700 Subject: [PATCH 38/56] Efficiency improvement in boundCondition. --- .../spark/sql/execution/joins/SortMergeOuterJoin.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 06c23f79b00c9..4e3445e6eaa7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -109,8 +109,13 @@ case class SortMergeOuterJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // An ordering that can be used to compare keys from both sides. val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } val resultProj: InternalRow => InternalRow = { if (isUnsafeMode) { UnsafeProjection.create(schema) From 1813a457839670015f2527f1f93438d64e294969 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 12:39:58 -0700 Subject: [PATCH 39/56] For left and right outer joins, streamed rows should not have null join keys. --- .../sql/execution/joins/SortMergeJoin.scala | 29 ++++++++++++------- .../execution/joins/SortMergeOuterJoin.scala | 2 ++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 3ab58709b21b5..5a42ebaf6e047 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -181,6 +181,8 @@ private[joins] class SortMergeJoinScanner( } if (streamedRow == null) { // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() false } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { // The new streamed row has the same join key as the previous row, so return the same matches. @@ -188,6 +190,8 @@ private[joins] class SortMergeJoinScanner( } else if (bufferedRow == null) { // The streamed row's join key does not match the current batch of buffered rows and there are // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() false } else { // Advance both the streamed and buffered iterators to find the next pair of matching rows. @@ -205,6 +209,8 @@ private[joins] class SortMergeJoinScanner( } while (streamedRow != null && bufferedRow != null && comp != 0) if (streamedRow == null || bufferedRow == null) { // We have either hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() false } else { // The streamed row's join key matches the current buffered row's join, so walk through the @@ -224,12 +230,17 @@ private[joins] class SortMergeJoinScanner( * join results. 
*/ final def findNextOuterJoinRows(): Boolean = { - if (advancedStreamed()) { - if (streamedRowKey.anyNull) { - // Since at least one join column is null, the streamed row has no matches. - matchJoinKey = null - bufferedMatches.clear() - } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { // Matches the current group, so do nothing. } else { // The streamed row does not match the current group. @@ -255,15 +266,11 @@ private[joins] class SortMergeJoinScanner( } } } - // If there is a streamed input, then we always return true + // If there is a streamed input with a non-null join key, then we always return true true - } else { - // End of streamed input, hence no more results. - false } } - // --- Private methods -------------------------------------------------------------------------- /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 4e3445e6eaa7c..fe5f835167410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -169,6 +169,7 @@ private class LeftOuterIterator( private def advanceLeft(): Boolean = { rightIdx = 0 if (smjScanner.findNextOuterJoinRows()) { + assert(!smjScanner.getStreamedRow.anyNull) joinedRow.withLeft(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { // There are no matching right rows, so return nulls for the right row @@ -215,6 +216,7 @@ private class RightOuterIterator( private def advanceRight(): Boolean = { leftIdx = 0 if (smjScanner.findNextOuterJoinRows()) { + assert(!smjScanner.getStreamedRow.anyNull) joinedRow.withRight(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { // There are no matching left rows, so return nulls for the left row From f183307f7bc7e5092979317346af531364a94a33 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 12:41:39 -0700 Subject: [PATCH 40/56] Two minor comments on output ordering --- .../apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index fe5f835167410..f7462cbd901cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -53,6 +53,7 @@ case class SortMergeOuterJoin( } override def outputPartitioning: Partitioning = joinType match { + // For left and right outer joins, the output is partitioned by the streamed input's join keys. 
case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case x => @@ -61,6 +62,7 @@ override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. case LeftOuter => requiredOrders(leftKeys) case RightOuter => requiredOrders(rightKeys) case x => throw new IllegalArgumentException( From f45608624bd566f08e2c5e7507d43c1ae16a61e2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 12:54:53 -0700 Subject: [PATCH 41/56] Add note RE: non-nullability of streamed side's join keys. --- .../apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index f7462cbd901cf..acfb04a3f044b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -43,8 +43,10 @@ case class SortMergeOuterJoin( override def output: Seq[Attribute] = { joinType match { case LeftOuter => + // Note: technically the left join keys will not be nullable here: left.output ++ right.output.map(_.withNullability(true)) case RightOuter => + // Note: technically the right join keys will not be nullable here: left.output.map(_.withNullability(true)) ++ right.output case x => throw new IllegalArgumentException( From 2e5eb2d6beb13227d9fe96134a26e82bbeb2fac8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 13:10:22 -0700 Subject: [PATCH 42/56] Fix loss of rows when removing RowIteratorToScala wrapper.
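The hazard being fixed here, shown with a self-contained toy adapter over a plain Scala Iterator (hypothetical code for illustration; EagerAdapter and LostRowDemo are not Spark classes): an adapter whose constructor eagerly advances its source consumes one element up front, so handing the source back out afterwards loses that element. The previous RowIteratorToScala advanced its RowIterator in its constructor in the same way, and fromScala could then unwrap it and return the already-advanced RowIterator.

    object LostRowDemo extends App {
      class EagerAdapter[T](source: Iterator[T]) extends Iterator[T] {
        // Eagerly advancing the source here is the problem:
        private[this] var buffered: Option[T] =
          if (source.hasNext) Some(source.next()) else None
        override def hasNext: Boolean = buffered.isDefined
        override def next(): T = {
          val result = buffered.get // throws NoSuchElementException when exhausted
          buffered = if (source.hasNext) Some(source.next()) else None
          result
        }
      }

      val source = Iterator(1, 2, 3)
      new EagerAdapter(source)            // constructing the adapter consumes element 1
      assert(source.toList == List(2, 3)) // anyone who bypasses the adapter has lost it
    }

The change below makes the wrapper advance lazily (on the first hasNext call) and only unwraps a RowIteratorToScala that has not been used yet.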
--- .../spark/sql/execution/joins/RowIterator.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala index 8efc469bab0fe..1e978da01e1de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala @@ -30,17 +30,24 @@ private[sql] abstract class RowIterator { object RowIterator { def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { scalaIter match { - case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case wrappedRowIter: RowIteratorToScala if !wrappedRowIter._wasUsed => wrappedRowIter.rowIter case _ => new RowIteratorFromScala(scalaIter) } } } private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { - private [this] var _hasNext: Boolean = rowIter.advanceNext() - override def hasNext: Boolean = _hasNext + var _wasUsed: Boolean = false + private [this] var _hasNext: Boolean = false + override def hasNext: Boolean = { + if (!_wasUsed) { + _hasNext = rowIter.advanceNext() + _wasUsed = true + } + _hasNext + } override def next(): InternalRow = { - if (!_hasNext) throw new NoSuchElementException + if (!hasNext) throw new NoSuchElementException val row: InternalRow = rowIter.getRow.copy() _hasNext = rowIter.advanceNext() row From a7a24f5b6cafb268b1d75dbf3d2f9a21d27f57f3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 13:13:48 -0700 Subject: [PATCH 43/56] Use RowIterator in SortMergeJoin as well --- .../sql/execution/joins/SortMergeJoin.scala | 39 +++++++++---------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 5a42ebaf6e047..5e990396ad0ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -70,7 +70,7 @@ case class SortMergeJoin( protected override def doExecute(): RDD[InternalRow] = { left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - new Iterator[InternalRow] { + new RowIterator { // An ordering that can be used to compare keys from both sides. 
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ @@ -92,32 +92,29 @@ case class SortMergeJoin( } } - override final def hasNext: Boolean = - (currentMatchIdx != -1 && currentMatchIdx < currentRightMatches.length) || fetchNext() - - private[this] def fetchNext(): Boolean = { - if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + override def advanceNext(): Boolean = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + } + } + if (currentLeftRow != null) { + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 true } else { - currentRightMatches = null - currentLeftRow = null - currentMatchIdx = -1 false } } - override def next(): InternalRow = { - if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { - fetchNext() - } - val joinedRow = joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 - resultProjection(joinedRow) - } - } + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala } } } From 7910e8384158ba40f1a7863d185ea9020b6fac76 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 14:57:15 -0700 Subject: [PATCH 44/56] Add giant comment to RowIterator --- .../sql/execution/joins/RowIterator.scala | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala index 1e978da01e1de..b239145b4ece9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala @@ -21,9 +21,49 @@ import java.util.NoSuchElementException import org.apache.spark.sql.catalyst.InternalRow +/** + * An internal iterator interface which presents a more restrictive API than + * [[scala.collection.Iterator]]. + * + * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()` + * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the + * iterator to consume the next row, whereas RowIterator combines these calls into a single + * [[advanceNext()]] method. + * + * In some cases, significant work may need to be performed in order to determine whether there is a + * next element (for example, a `filter`ed iterator may need to consume many elements of its parent + * iterator in order to determine whether there is a next row). As a result, many Scala iterators + * perform implicit internal buffering, which can cause problems with iterators that return the same + * mutable Row on every `next()` call. If we call `.filter()` on a Scala iterator of InternalRow, + * calling `hasNext()` may mutate the row that has already been returned from the iterator. This + * can cause problems unless the caller expects to immediately call `next()` after `hasNext()` + * returned true. 
+ * + * We can guard against this anomaly by automatically copying rows before returning them to a Scala + * iterator; RowIterator's [[toScala]] method returns a wrapper which automatically performs this + * defensive copying. These copies carry a performance penalty, though, so ideally we should avoid + * this. The `RowIterator.fromScala` method wraps a Scala iterator behind our more restrictive + * iterator interface. As an optimization, calling `RowIterator.fromScala` on a wrapped RowIterator + * will return the underlying RowIterator, avoiding the copying. Thus, by gradually re-writing + * operators to use our [[RowIterator]] wrappers we can safely remove this defensive row copying. + */ private[sql] abstract class RowIterator { + /** + * Advance this iterator by a single row. Returns `false` if this iterator has no more rows + * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling + * [[getRow]]. + */ def advanceNext(): Boolean + + /** + * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this + * method after [[advanceNext()]] has returned `false`. + */ def getRow: InternalRow + + /** + * Convert this RowIterator into a [[scala.collection.Iterator]]. + */ def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) } From fd439cb98bb496cebd0075c10c6ff0153f7dcfe8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 15:02:07 -0700 Subject: [PATCH 45/56] Move RowIterator to execution package --- .../apache/spark/sql/execution/{joins => }/RowIterator.scala | 2 +- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 2 +- .../apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{joins => }/RowIterator.scala (99%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala index b239145b4ece9..6cc83f9c5e4fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.joins +package org.apache.spark.sql.execution import java.util.NoSuchElementException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 5e990396ad0ac..ba8140708aab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index acfb04a3f044b..74582d09d1637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} /** * :: DeveloperApi :: From 51ee4b24d4cded56b0c5e5702d61123f4a1bed18 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 15:11:24 -0700 Subject: [PATCH 46/56] Remove incorrect assertions; the non-join-key columns can be null --- .../apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 74582d09d1637..e72b85ba56afe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -173,7 +173,6 @@ private class LeftOuterIterator( private def advanceLeft(): Boolean = { rightIdx = 0 if (smjScanner.findNextOuterJoinRows()) { - assert(!smjScanner.getStreamedRow.anyNull) joinedRow.withLeft(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { // There are no matching right rows, so return nulls for the right row @@ -220,7 +219,6 @@ private class RightOuterIterator( private def advanceRight(): Boolean = { leftIdx = 0 if (smjScanner.findNextOuterJoinRows()) { - assert(!smjScanner.getStreamedRow.anyNull) joinedRow.withRight(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { // There are no matching left rows, so return nulls for the left row From e23db3db155658c8d916a30d763fc77152ab92bc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 15:16:03 -0700 Subject: [PATCH 47/56] Experiment with removing copy --- .../apache/spark/sql/execution/RowIterator.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala index 6cc83f9c5e4fe..5ea1a7bd4fcaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -70,27 +70,28 @@ private[sql] abstract class RowIterator { object RowIterator { def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { scalaIter match { - case wrappedRowIter: RowIteratorToScala if !wrappedRowIter._wasUsed => wrappedRowIter.rowIter + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter case _ => new RowIteratorFromScala(scalaIter) } } } private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { - var _wasUsed: Boolean = false + private [this] var hasNextWasCalled: Boolean = false private [this] var _hasNext: Boolean = false override def hasNext: Boolean = { - if (!_wasUsed) { + // Idempotency: + if (!hasNextWasCalled) { _hasNext = rowIter.advanceNext() - _wasUsed = true + hasNextWasCalled = true } _hasNext } override def next(): InternalRow = { if (!hasNext) throw new NoSuchElementException - val row: InternalRow = rowIter.getRow.copy() - _hasNext = rowIter.advanceNext() - row + // TODO(josh): see whether we need to re-add the copy() here: + hasNextWasCalled = false + rowIter.getRow } } From f70165228ddb6041e17756092d487d61267e3f00 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Aug 2015 17:54:46 -0700 Subject: [PATCH 48/56] Fix incorrectly-placed null check. --- .../spark/sql/execution/joins/SortMergeJoin.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ba8140708aab1..5970e4546f2df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -227,11 +227,7 @@ private[joins] class SortMergeJoinScanner( * join results. */ final def findNextOuterJoinRows(): Boolean = { - while (advancedStreamed() && streamedRowKey.anyNull) { - // Advance the streamed side of the join until we find the next row whose join key contains - // no nulls or we hit the end of the streamed iterator. - } - if (streamedRow == null) { + if (!advancedStreamed()) { // We have consumed the entire streamed iterator, so there can be no more matches. matchJoinKey = null bufferedMatches.clear() @@ -243,7 +239,7 @@ private[joins] class SortMergeJoinScanner( // The streamed row does not match the current group. matchJoinKey = null bufferedMatches.clear() - if (bufferedRow != null) { + if (bufferedRow != null && !streamedRowKey.anyNull) { // The buffered iterator could still contain matching rows, so we'll need to walk through // it until we either find matches or pass where they would be found. var comp = @@ -263,7 +259,7 @@ private[joins] class SortMergeJoinScanner( } } } - // If there is a streamed input with a non-null join key, then we always return true + // If there is a streamed input then we always return true true } } From 7d3cc5d03e1f715997c99a81fcc8f5effa354120 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 9 Aug 2015 00:26:14 -0700 Subject: [PATCH 49/56] It turns out that the copy is unnecessary. 
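A brief sketch of what removing the defensive copy means for consumers (illustration only; CopyOnBuffer and collectRows are hypothetical, and this assumes the usual contract that an upstream operator may reuse a single mutable row object): rows handed out by the wrapper are only valid until the next call, so any code that accumulates them must copy explicitly, as the scanner's buffering code already does.

    import org.apache.spark.sql.catalyst.InternalRow

    // Hypothetical helper, for illustration only.
    object CopyOnBuffer {
      // Buffering rows across iterations is the caller's job to make safe:
      // copy() each row before holding on to it.
      def collectRows(rows: Iterator[InternalRow]): IndexedSeq[InternalRow] =
        rows.map(_.copy()).toIndexedSeq
    }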
--- .../spark/sql/execution/RowIterator.scala | 18 ------------------ .../execution/joins/SortMergeOuterJoin.scala | 2 -- 2 files changed, 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala index 5ea1a7bd4fcaa..7462dbc4eba3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -29,23 +29,6 @@ import org.apache.spark.sql.catalyst.InternalRow * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the * iterator to consume the next row, whereas RowIterator combines these calls into a single * [[advanceNext()]] method. - * - * In some cases, significant work may need to be performed in order to determine whether there is a - * next element (for example, a `filter`ed iterator may need to consume many elements of its parent - * iterator in order to determine whether there is a next row). As a result, many Scala iterators - * perform implicit internal buffering, which can cause problems with iterators that return the same - * mutable Row on every `next()` call. If we call `.filter()` on a Scala iterator of InternalRow, - * calling `hasNext()` may mutate the row that has already been returned from the iterator. This - * can cause problems unless the caller expects to immediately call `next()` after `hasNext()` - * returned true. - * - * We can guard against this anomaly by automatically copying rows before returning them to a Scala - * iterator; RowIterator's [[toScala]] method returns a wrapper which automatically performs this - * defensive copying. These copies carry a performance penalty, though, so ideally we should avoid - * this. The `RowIterator.fromScala` method wraps a Scala iterator behind our more restrictive - * iterator interface. As an optimization, calling `RowIterator.fromScala` on a wrapped RowIterator - * will return the underlying RowIterator, avoiding the copying. Thus, by gradually re-writing - * operators to use our [[RowIterator]] wrappers we can safely remove this defensive row copying. 
*/ private[sql] abstract class RowIterator { /** @@ -89,7 +72,6 @@ private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterato } override def next(): InternalRow = { if (!hasNext) throw new NoSuchElementException - // TODO(josh): see whether we need to re-add the copy() here: hasNextWasCalled = false rowIter.getRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index e72b85ba56afe..a984aaec72a4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -43,10 +43,8 @@ case class SortMergeOuterJoin( override def output: Seq[Attribute] = { joinType match { case LeftOuter => - // Note: technically the left join keys will not be nullable here: left.output ++ right.output.map(_.withNullability(true)) case RightOuter => - // Note: technically the left right keys will not be nullable here: left.output.map(_.withNullability(true)) ++ right.output case x => throw new IllegalArgumentException( From f83b412a08e7cbf4fe07f5bd1a266efaad78b0b8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 9 Aug 2015 00:44:23 -0700 Subject: [PATCH 50/56] Push null check into buffered iterator next(). --- .../sql/execution/joins/SortMergeJoin.scala | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 5970e4546f2df..ef5e26f0f25e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -148,6 +148,7 @@ private[joins] class SortMergeJoinScanner( private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: private[this] var bufferedRowKey: InternalRow = _ /** * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty @@ -157,7 +158,7 @@ private[joins] class SortMergeJoinScanner( private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] // Initialization (note: do _not_ want to advance streamed here). - advancedBuffered() + advancedBufferedToRowWithNullFreeJoinKey() // --- Public methods --------------------------------------------------------------------------- @@ -196,11 +197,10 @@ private[joins] class SortMergeJoinScanner( do { if (streamedRowKey.anyNull) { advancedStreamed() - } else if (bufferedRowKey.anyNull) { - advancedBuffered() } else { + assert(!bufferedRowKey.anyNull) comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - if (comp > 0) advancedBuffered() + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() else if (comp < 0) advancedStreamed() } } while (streamedRow != null && bufferedRow != null && comp != 0) @@ -242,15 +242,10 @@ private[joins] class SortMergeJoinScanner( if (bufferedRow != null && !streamedRowKey.anyNull) { // The buffered iterator could still contain matching rows, so we'll need to walk through // it until we either find matches or pass where they would be found. 
- var comp = - if (bufferedRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, bufferedRowKey) - while (comp > 0 && advancedBuffered()) { - comp = if (bufferedRowKey.anyNull) { - 1 - } else { - keyOrdering.compare(streamedRowKey, bufferedRowKey) - } - } + var comp = 1 + do { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) if (comp == 0) { // We have found matches, so buffer them (this updates matchJoinKey) bufferMatchingRows() @@ -283,18 +278,22 @@ private[joins] class SortMergeJoinScanner( } /** - * Advance the buffered iterator and compute the new row's join key. + * Advance the buffered iterator until we find a row with join key that does not contain nulls. * @return true if the buffered iterator returned a row and false otherwise. */ - private def advancedBuffered(): Boolean = { - if (bufferedIter.advanceNext()) { + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) - true - } else { + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { bufferedRow = null bufferedRowKey = null false + } else { + true } } @@ -312,11 +311,7 @@ private[joins] class SortMergeJoinScanner( bufferedMatches.clear() do { bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them - advancedBuffered() - } while ( - bufferedRow != null && - !bufferedRowKey.anyNull && - keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 - ) + advancedBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } } From 81956b0f97f87e6fb988b6b693c4af5a240a9de3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 18:55:14 -0700 Subject: [PATCH 51/56] Improve unit test coverage of join physical operators. 
--- .../joins/BroadcastNestedLoopJoin.scala | 5 +- .../sql/execution/joins/InnerJoinSuite.scala | 172 +++++++++ .../sql/execution/joins/OuterJoinSuite.scala | 340 +++++++++--------- .../sql/execution/joins/SemiJoinSuite.scala | 121 ++++--- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- 5 files changed, 429 insertions(+), 211 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 23aebf4b068b4..017a44b9ca863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 0000000000000..b9d5f0ef8fa46 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution._ + +class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testInnerJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + + def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + test(s"$testName using SortMergeJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeSortMergeJoin(left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + { + val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 
"a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + testInnerJoin( + "inner join, one match per row", + upperCaseData, + lowerCaseData, + (upperCaseData.col("N") === lowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + } + + private val testData2 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 27b185f700f81..97dbca3cb3051 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -1,162 +1,178 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} - -class OuterJoinSuite extends SparkPlanTest { - - private def testOuterJoin( - testName: String, - leftRows: DataFrame, - rightRows: DataFrame, - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - expectedAnswer: Seq[Product]): Unit = { - // Precondition: leftRows and rightRows should be sorted according to the join keys. 
- - test(s"$testName using ShuffledHashOuterJoin") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } - - if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } - - test(s"$testName using SortMergeOuterJoin") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - SortMergeOuterJoin(leftKeys, rightKeys, joinType, condition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } - } - } - - val left = Seq( - (1, 2.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") - - val right = Seq( - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") - - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - // --- Basic outer joins ------------------------------------------------------------------------ - - testOuterJoin( - "basic left outer join", - left, - right, - leftKeys, - rightKeys, - LeftOuter, - condition, - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ) - ) - - testOuterJoin( - "basic right outer join", - left, - right, - leftKeys, - rightKeys, - RightOuter, - condition, - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ) - ) - - testOuterJoin( - "basic full outer join", - left, - right, - leftKeys, - rightKeys, - FullOuter, - condition, - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ) - ) - - // --- Both inputs empty ------------------------------------------------------------------------ - - testOuterJoin( - "left outer join with both inputs empty", - left.filter("false"), - right.filter("false"), - leftKeys, - rightKeys, - LeftOuter, - condition, - Seq.empty - ) - - testOuterJoin( - "right outer join with both inputs empty", - left.filter("false"), - right.filter("false"), - leftKeys, - rightKeys, - RightOuter, - condition, - Seq.empty - ) - - testOuterJoin( - "full outer join with both inputs empty", - left.filter("false"), - right.filter("false"), - leftKeys, - rightKeys, - FullOuter, - condition, - Seq.empty - ) -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testOuterJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + joinType: JoinType, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using ShuffledHashOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + test(s"$testName using SortMergeOuterJoin") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=left)") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + + test(s"$testName using BroadcastNestedLoopJoin (build=right)") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (2, 
1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3dc..9a8a667d365a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,58 +17,87 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { -class SemiJoinSuite extends SparkPlanTest{ - val left = Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") + private def testLeftSemiJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using LeftSemiJoinHash") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } - val right = Seq( - (2, 3.0), - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") + test(s"$testName using BroadcastLeftSemiJoinHash") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - 
test("left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + test(s"$testName using LeftSemiJoinBNL") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } - test("left semi join BNL") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, condition), - Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) - } + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) - test("broadcast left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec0..1066695589778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext + protected def sqlContext: SQLContext protected def configuration = sqlContext.sparkContext.hadoopConfiguration From 899dce29400f2ae9913e3c02c98a28885c0b0a22 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 16:00:59 -0700 Subject: [PATCH 52/56] Expand test data to cover multiple buffered rows per group. 
--- .../spark/sql/execution/joins/OuterJoinSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 97dbca3cb3051..82fa3321fe4e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -82,12 +82,14 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( Row(1, 2.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, 1.0), Row(3, 3.0), Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, 3.0), Row(3, 2.0), Row(4, 1.0), @@ -112,6 +114,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { (null, null, null, null), (1, 2.0, null, null), (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), (3, 3.0, null, null) ) ) @@ -125,6 +130,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { Seq( (null, null, null, null), (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) ) @@ -139,6 +147,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { Seq( (1, 2.0, null, null), (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), (3, 3.0, null, null), (null, null, 3, 2.0), (null, null, 4, 1.0), From e79909ed347f45f76268b88917967147f42841f2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 17:18:37 -0700 Subject: [PATCH 53/56] Fix parallelism in join operator unit tests. 
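
The withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") wrapper previously surrounded the block
that registers the tests, but ScalaTest's test(...) only records its body to run later, so the
conf had already been restored by the time any join actually executed and the intended
single-partition setting never took effect. The hunks below move withSQLConf inside each test
body, where it is active while the join runs. A self-contained illustration of the
registration-versus-execution pitfall, using a plain flag in place of the SQL conf (the class
and helper names are made up for this example):

    import org.scalatest.FunSuite

    class RegistrationTimingExample extends FunSuite {
      private var active = false

      // Sets the flag only for the duration of body, like withSQLConf does for a conf key.
      private def withFlag(body: => Unit): Unit = {
        active = true
        try body finally { active = false }
      }

      // Wrapping the registration of a test does not affect its execution:
      // by the time ScalaTest runs the body, the flag has already been reset.
      withFlag {
        test("flag is no longer set when the body runs") {
          assert(!active)
        }
      }
    }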
--- .../sql/execution/joins/InnerJoinSuite.scala | 72 ++++++++++--------- .../sql/execution/joins/OuterJoinSuite.scala | 51 +++++++++---- .../sql/execution/joins/SemiJoinSuite.scala | 20 +++--- 3 files changed, 89 insertions(+), 54 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index b9d5f0ef8fa46..ddff7cebcc17d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -34,67 +34,75 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { rightRows: DataFrame, condition: Expression, expectedAnswer: Seq[Product]): Unit = { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - - def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val broadcastHashJoin = - execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) - boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) - } - - def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val shuffledHashJoin = - execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) - val filteredJoin = - boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - - def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { - val sortMergeJoin = - execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) - val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - - test(s"$testName using BroadcastHashJoin (build=left)") { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + + def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => makeBroadcastHashJoin(left, right, joins.BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using BroadcastHashJoin (build=right)") { + 
test(s"$testName using BroadcastHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => makeBroadcastHashJoin(left, right, joins.BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using ShuffledHashJoin (build=left)") { + test(s"$testName using ShuffledHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => makeShuffledHashJoin(left, right, joins.BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using ShuffledHashJoin (build=right)") { + test(s"$testName using ShuffledHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => makeShuffledHashJoin(left, right, joins.BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using SortMergeJoin") { + test(s"$testName using SortMergeJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => makeSortMergeJoin(left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 82fa3321fe4e0..d8c3327a0a315 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -35,43 +35,52 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { joinType: JoinType, condition: Expression, expectedAnswer: Seq[Product]): Unit = { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using ShuffledHashOuterJoin") { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using ShuffledHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using SortMergeOuterJoin") { + test(s"$testName using SortMergeOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + 
EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = false) } } - } + } + } - test(s"$testName using BroadcastNestedLoopJoin (build=left)") { + test(s"$testName using BroadcastNestedLoopJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using BroadcastNestedLoopJoin (build=right)") { + test(s"$testName using BroadcastNestedLoopJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), expectedAnswer.map(Row.fromTuple), @@ -85,14 +94,19 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, 1.0), Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, 3.0), Row(3, 2.0), Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), Row(null, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) @@ -117,7 +131,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), - (3, 3.0, null, null) + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) ) ) @@ -129,12 +145,15 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { condition, Seq( (null, null, null, null), + (null, null, 0, 0.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (null, null, 3, 2.0), - (null, null, 4, 1.0) + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) ) ) @@ -151,8 +170,12 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), (null, null, 3, 2.0), (null, null, 4, 1.0), + (null, null, 7, 7.0), (null, null, null, null), (null, null, null, null) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 9a8a667d365a8..4503ed251fcb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -34,27 +34,31 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { rightRows: DataFrame, condition: Expression, expectedAnswer: Seq[Product]): Unit = { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using LeftSemiJoinHash") { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case 
(joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using LeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext).apply( LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } + } - test(s"$testName using BroadcastLeftSemiJoinHash") { + test(s"$testName using BroadcastLeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - test(s"$testName using LeftSemiJoinBNL") { + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => LeftSemiJoinBNL(left, right, Some(condition)), expectedAnswer.map(Row.fromTuple), From 5c34f7571bed6711b94d610428e7e77e8655fe86 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 17:39:21 -0700 Subject: [PATCH 54/56] Add regression test exposing bug with missing while loop --- .../apache/spark/sql/execution/joins/OuterJoinSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index d8c3327a0a315..6144f43ca8c4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -102,6 +102,8 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), Row(2, 3.0), Row(3, 2.0), Row(4, 1.0), @@ -146,6 +148,8 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { Seq( (null, null, null, null), (null, null, 0, 0.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), @@ -165,6 +169,8 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { condition, Seq( (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), From c188a21913d180b1d831f564a05fa09c6a81886d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 17:47:00 -0700 Subject: [PATCH 55/56] Fix while loops while adding regression tests. 
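
In SortMergeOuterJoin's LeftOuterIterator and RightOuterIterator, the scan over the buffered
matches used an if where a while was intended, so each call examined at most one buffered row
instead of scanning the remaining matches for one that satisfies the bound condition. The
hunks below turn both checks into proper loops that keep advancing until a satisfying row is
found or the matches are exhausted. A minimal, runnable sketch of the difference, using plain
values in place of InternalRows (the ordering of the buffered rows is chosen only to show why
a single check is not enough):

    object WhileVsIf extends App {
      // Buffered d-values for one join key; only the last one satisfies the condition.
      val bufferedMatches = IndexedSeq(-1.0, -1.0, 3.0)
      def boundCondition(d: Double): Boolean = 1.0 < d   // mirrors b < d for a left row with b = 1.0

      var idx = 0
      var foundMatch = false
      // With the old if-statement only index 0 would be examined and the match at index 2 missed.
      while (!foundMatch && idx < bufferedMatches.length) {
        foundMatch = boundCondition(bufferedMatches(idx))
        idx += 1
      }
      assert(foundMatch)
    }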
--- .../spark/sql/execution/joins/SortMergeOuterJoin.scala | 4 ++-- .../apache/spark/sql/execution/joins/OuterJoinSuite.scala | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index a984aaec72a4f..5326966b07a66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -190,7 +190,7 @@ private class LeftOuterIterator( private def advanceRightUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - if (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) rightIdx += 1 } @@ -236,7 +236,7 @@ private class RightOuterIterator( private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - if (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { + while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) leftIdx += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 6144f43ca8c4a..e16f5e39aa2f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -91,6 +91,7 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( Row(1, 2.0), + Row(2, 100.0), Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, 1.0), Row(3, 3.0), @@ -129,6 +130,7 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { Seq( (null, null, null, null), (1, 2.0, null, null), + (2, 100.0, null, null), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), @@ -148,10 +150,10 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { Seq( (null, null, null, null), (null, null, 0, 0.0), - (null, null, 2, -1.0), - (null, null, 2, -1.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (null, null, 3, 2.0), @@ -171,6 +173,7 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { (1, 2.0, null, null), (null, null, 2, -1.0), (null, null, 2, -1.0), + (2, 100.0, null, null), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), (2, 1.0, 2, 3.0), From eabacca9864e609a1b085a2acbe10907929700a4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 18:54:15 -0700 Subject: [PATCH 56/56] comment updates --- .../apache/spark/sql/execution/joins/SortMergeJoin.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ef5e26f0f25e3..6d656ea2849a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,6 +56,7 @@ case class SortMergeJoin( protected[this] def isUnsafeMode: Boolean = { (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(schema)) } @@ -122,15 +123,14 @@ case class SortMergeJoin( /** * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. * - * The streamed input is the left side of a left outer join or the right side of a right outer join. - * * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return * the matching row from the streamed input and may call [[getBufferedMatches]] to return the * sequence of matching rows from the buffered input (in the case of an outer join, this will return - * an empty sequence). For efficiency, both of these methods return mutable objects which are - * re-used across calls to the `findNext*JoinRows()` methods. + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. * * @param streamedKeyGenerator a projection that produces join keys from the streamed input. * @param bufferedKeyGenerator a projection that produces join keys from the buffered input.