From abc65974905c872f00e6ab9cf51218ddfadff0cc Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Mon, 3 Aug 2020 14:54:28 -0700 Subject: [PATCH 01/11] Full outer shuffled hash join --- .../spark/sql/catalyst/optimizer/joins.scala | 27 ++- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../spark/sql/execution/joins/HashJoin.scala | 4 +- .../sql/execution/joins/HashedRelation.scala | 100 +++++++++-- .../joins/ShuffledHashJoinExec.scala | 159 +++++++++++++++++- .../sql/execution/joins/ShuffledJoin.scala | 23 ++- .../execution/joins/SortMergeJoinExec.scala | 20 --- .../org/apache/spark/sql/JoinSuite.scala | 38 +++++ 8 files changed, 323 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 85c6600685bd1..57c3f3dbd050d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -235,8 +235,8 @@ trait JoinSelectionHelper { canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) } getBuildSide( - canBuildLeft(joinType) && buildLeft, - canBuildRight(joinType) && buildRight, + canBuildBroadcastLeft(joinType) && buildLeft, + canBuildBroadcastRight(joinType) && buildRight, left, right ) @@ -260,8 +260,8 @@ trait JoinSelectionHelper { canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) } getBuildSide( - canBuildLeft(joinType) && buildLeft, - canBuildRight(joinType) && buildRight, + canBuildShuffledHashJoinLeft(joinType) && buildLeft, + canBuildShuffledHashJoinRight(joinType) && buildRight, left, right ) @@ -278,20 +278,35 @@ trait JoinSelectionHelper { plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold } - def canBuildLeft(joinType: JoinType): Boolean = { + def canBuildBroadcastLeft(joinType: JoinType): Boolean = { joinType match { case _: InnerLike | RightOuter => true case _ => false } } - def canBuildRight(joinType: JoinType): Boolean = { + def canBuildBroadcastRight(joinType: JoinType): Boolean = { joinType match { case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } } + def canBuildShuffledHashJoinLeft(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | RightOuter | FullOuter => true + case _ => false + } + } + + def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | LeftOuter | FullOuter | + LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } + } + def hintToBroadcastLeft(hint: JoinHint): Boolean = { hint.leftHint.exists(_.strategy.contains(BROADCAST)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eb32bfcecae7b..391e2524e4794 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -116,7 +116,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * * - Shuffle hash join: * Only supported for equi-joins, while the join keys do not need to be sortable. - * Supported for all join types except full outer joins. + * Supported for all join types. + * Building hash map from table is a memory-intensive operation and it could cause OOM + * when the build side is big. 
* * - Shuffle sort merge join (SMJ): * Only supported for equi-joins and the join keys have to be sortable. @@ -260,7 +262,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // it's a right join, and broadcast right side if it's a left join. // TODO: revisit it. If left side is much smaller than the right side, it may be better // to broadcast the left side even if it's a left join. - if (canBuildLeft(joinType)) BuildLeft else BuildRight + if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight } def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 2154e370a1596..1a7554c905c6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -114,7 +114,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport { } } - @transient private lazy val (buildOutput, streamedOutput) = { + @transient protected lazy val (buildOutput, streamedOutput) = { buildSide match { case BuildLeft => (left.output, right.output) case BuildRight => (right.output, left.output) @@ -133,7 +133,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport { protected def streamSideKeyGenerator(): UnsafeProjection = UnsafeProjection.create(streamedBoundKeys) - @transient private[this] lazy val boundCondition = if (condition.isDefined) { + @transient protected[this] lazy val boundCondition = if (condition.isDefined) { Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _ } else { (r: InternalRow) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0d40520ae71a0..ee8926371b01a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -76,6 +76,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { */ def keys(): Iterator[InternalRow] + /** + * Returns an iterator for values of InternalRow type. + */ + def values(): Iterator[InternalRow] + /** * Returns a read-only copy of this, to be safely used in current thread. 
*/ @@ -97,7 +102,9 @@ private[execution] object HashedRelation { key: Seq[Expression], sizeEstimate: Int = 64, taskMemoryManager: TaskMemoryManager = null, - isNullAware: Boolean = false): HashedRelation = { + isNullAware: Boolean = false, + isLookupAware: Boolean = false, + value: Option[Seq[Expression]] = None): HashedRelation = { val mm = Option(taskMemoryManager).getOrElse { new TaskMemoryManager( new UnifiedMemoryManager( @@ -110,10 +117,10 @@ private[execution] object HashedRelation { if (!input.hasNext) { EmptyHashedRelation - } else if (key.length == 1 && key.head.dataType == LongType) { + } else if (key.length == 1 && key.head.dataType == LongType && !isLookupAware) { LongHashedRelation(input, key, sizeEstimate, mm, isNullAware) } else { - UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware) + UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, isLookupAware, value) } } } @@ -128,15 +135,18 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numKeys: Int, private var numFields: Int, - private var binaryMap: BytesToBytesMap) + private var binaryMap: BytesToBytesMap, + private val isLookupAware: Boolean = false) extends HashedRelation with Externalizable with KryoSerializable { - private[joins] def this() = this(0, 0, null) // Needed for serialization + private[joins] def this() = this(0, 0, null, false) // Needed for serialization - override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() + override def keyIsUnique: Boolean = { + binaryMap.numKeys() == binaryMap.numValues() + } override def asReadOnlyCopy(): UnsafeHashedRelation = { - new UnsafeHashedRelation(numKeys, numFields, binaryMap) + new UnsafeHashedRelation(numKeys, numFields, binaryMap, isLookupAware) } override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption @@ -305,6 +315,27 @@ private[joins] class UnsafeHashedRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { read(() => in.readInt(), () => in.readLong(), in.readBytes) } + + override def values(): Iterator[InternalRow] = { + if (isLookupAware) { + val iter = binaryMap.iterator() + + new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("End of the iterator") + } + val loc = iter.next() + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + resultRow + } + } + } else { + throw new UnsupportedOperationException + } + } } private[joins] object UnsafeHashedRelation { @@ -314,7 +345,9 @@ private[joins] object UnsafeHashedRelation { key: Seq[Expression], sizeEstimate: Int, taskMemoryManager: TaskMemoryManager, - isNullAware: Boolean = false): HashedRelation = { + isNullAware: Boolean = false, + isLookupAware: Boolean = false, + value: Option[Seq[Expression]] = None): HashedRelation = { val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024)) @@ -327,27 +360,52 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows val keyGenerator = UnsafeProjection.create(key) var numFields = 0 - while (input.hasNext) { - val row = input.next().asInstanceOf[UnsafeRow] - numFields = row.numFields() - val key = keyGenerator(row) - if (!key.anyNull) { + + if (isLookupAware) { + // Add one extra boolean value at the end as part of the row, + // to track the information that whether the 
corresponding key + // has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for example of usage. + val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false)) + + while (input.hasNext) { + val row = input.next().asInstanceOf[UnsafeRow] + numFields = row.numFields() + 1 + val key = keyGenerator(row) + val value = valueGenerator(row) val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) val success = loc.append( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) if (!success) { binaryMap.free() // scalastyle:off throwerror throw new SparkOutOfMemoryError("There is not enough memory to build hash map") // scalastyle:on throwerror } - } else if (isNullAware) { - return EmptyHashedRelationWithAllNullKeys + } + } else { + while (input.hasNext) { + val row = input.next().asInstanceOf[UnsafeRow] + numFields = row.numFields() + val key = keyGenerator(row) + if (!key.anyNull) { + val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + val success = loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) + if (!success) { + binaryMap.free() + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("There is not enough memory to build hash map") + // scalastyle:on throwerror + } + } else if (isNullAware) { + return EmptyHashedRelationWithAllNullKeys + } } } - new UnsafeHashedRelation(key.size, numFields, binaryMap) + new UnsafeHashedRelation(key.size, numFields, binaryMap, isLookupAware) } } @@ -885,6 +943,10 @@ class LongHashedRelation( * Returns an iterator for keys of InternalRow type. 
*/ override def keys(): Iterator[InternalRow] = map.keys() + + override def values(): Iterator[InternalRow] = { + throw new UnsupportedOperationException + } } /** @@ -939,6 +1001,10 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable { throw new UnsupportedOperationException } + override def values(): Iterator[InternalRow] = { + throw new UnsupportedOperationException + } + override def close(): Unit = {} override def writeExternal(out: ObjectOutput): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 41cefd03dd931..151708a57a491 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -22,13 +22,13 @@ import java.util.concurrent.TimeUnit._ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} /** * Performs a hash join of two child relations by first shuffling the data using the join keys. @@ -48,8 +48,15 @@ case class ShuffledHashJoinExec( "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) + override def output: Seq[Attribute] = super[ShuffledJoin].output + override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning + override def outputOrdering: Seq[SortOrder] = joinType match { + case FullOuter => Nil + case _ => super.outputOrdering + } + /** * This is called by generated Java class, should be public. */ @@ -58,8 +65,19 @@ case class ShuffledHashJoinExec( val buildTime = longMetric("buildTime") val start = System.nanoTime() val context = TaskContext.get() + + val (isLookupAware, value) = + if (joinType == FullOuter) { + (true, Some(BindReferences.bindReferences(buildOutput, buildOutput))) + } else { + (false, None) + } val relation = HashedRelation( - iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager()) + iter, + buildBoundKeys, + taskMemoryManager = context.taskMemoryManager(), + isLookupAware = isLookupAware, + value = value) buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. 
@@ -71,8 +89,137 @@ case class ShuffledHashJoinExec( val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) - join(streamIter, hashed, numOutputRows) + joinType match { + case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows) + case _ => join(streamIter, hashed, numOutputRows) + } + } + } + + /** + * Full outer shuffled hash join has three steps: + * 1. Construct hash relation from build side, + * with extra boolean value at the end of row to track look up information + * (done in `buildHashedRelation`). + * 2. Process rows from stream side by looking up hash relation, + * and mark the matched rows from build side be looked up. + * 3. Process rows from build side by iterating hash relation, + * and filter out rows from build side being looked up already. + */ + private def fullOuterJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + numOutputRows: SQLMetric): Iterator[InternalRow] = { + abstract class HashJoinedRow extends JoinedRow { + /** Updates this JoinedRow by updating its stream side row. Returns itself. */ + def withStream(newStream: InternalRow): JoinedRow + + /** Updates this JoinedRow by updating its build side row. Returns itself. */ + def withBuild(newBuild: InternalRow): JoinedRow } + val joinRow: HashJoinedRow = buildSide match { + case BuildLeft => + new HashJoinedRow { + override def withStream(newStream: InternalRow): JoinedRow = withRight(newStream) + override def withBuild(newBuild: InternalRow): JoinedRow = withLeft(newBuild) + } + case BuildRight => + new HashJoinedRow { + override def withStream(newStream: InternalRow): JoinedRow = withLeft(newStream) + override def withBuild(newBuild: InternalRow): JoinedRow = withRight(newBuild) + } + } + val joinKeys = streamSideKeyGenerator() + val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput) + val buildNullRow = new GenericInternalRow(buildOutput.length) + val streamNullRow = new GenericInternalRow(streamedOutput.length) + + def markRowLookedUp(row: UnsafeRow): Unit = { + if (!row.getBoolean(row.numFields() - 1)) { + row.setBoolean(row.numFields() - 1, true) + } + } + + // Process stream side with looking up hash relation + val streamResultIter = + if (hashedRelation.keyIsUnique) { + streamIter.map { srow => + joinRow.withStream(srow) + val keys = joinKeys(srow) + if (keys.anyNull) { + joinRow.withBuild(buildNullRow) + } else { + val matched = hashedRelation.getValue(keys) + if (matched != null) { + val buildRow = buildRowGenerator(matched) + if (boundCondition(joinRow.withBuild(buildRow))) { + markRowLookedUp(matched.asInstanceOf[UnsafeRow]) + joinRow + } else { + joinRow.withBuild(buildNullRow) + } + } else { + joinRow.withBuild(buildNullRow) + } + } + } + } else { + streamIter.flatMap { srow => + joinRow.withStream(srow) + val keys = joinKeys(srow) + if (keys.anyNull) { + Iterator.single(joinRow.withBuild(buildNullRow)) + } else { + val buildIter = hashedRelation.get(keys) + new RowIterator { + private var found = false + override def advanceNext(): Boolean = { + while (buildIter != null && buildIter.hasNext) { + val matched = buildIter.next() + val buildRow = buildRowGenerator(matched) + if (boundCondition(joinRow.withBuild(buildRow))) { + markRowLookedUp(matched.asInstanceOf[UnsafeRow]) + found = true + return true + } + } + if (!found) { + joinRow.withBuild(buildNullRow) + found = true + return true + } + false + } + override def getRow: 
InternalRow = joinRow + }.toScala + } + } + } + + // Process build side with filtering out rows looked up already + val buildResultIter = hashedRelation.values().flatMap { brow => + val unsafebrow = brow.asInstanceOf[UnsafeRow] + val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1) + if (!isLookup) { + val buildRow = buildRowGenerator(unsafebrow) + joinRow.withBuild(buildRow) + joinRow.withStream(streamNullRow) + Some(joinRow) + } else { + None + } + } + + val resultProj = UnsafeProjection.create(output, output) + (streamResultIter ++ buildResultIter).map { r => + numOutputRows += 1 + resultProj(r) + } + } + + // TODO: support full outer shuffled hash join code-gen + override def supportCodegen: Boolean = { + joinType != FullOuter } override def inputRDDs(): Seq[RDD[InternalRow]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 7035ddc35be9c..0318baf27f9fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, Partitioning, PartitioningCollection, UnknownPartitioning} /** @@ -40,4 +41,24 @@ trait ShuffledJoin extends BaseJoinExec { throw new IllegalArgumentException( s"ShuffledJoin should not take $x as the JoinType") } + + override def output: Seq[Attribute] = { + joinType match { + case _: InnerLike => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException( + s"ShuffledJoin not take $x as the JoinType") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b9f6684447dd8..6e7bcb8825488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -52,26 +52,6 @@ case class SortMergeJoinExec( override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator - override def output: Seq[Attribute] = { - joinType match { - case _: InnerLike => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - (left.output ++ right.output).map(_.withNullability(true)) - case j: ExistenceJoin => - left.output :+ j.exists - case LeftExistence(_) => - left.output - case x => - throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - } - override def 
requiredChildDistribution: Seq[Distribution] = { if (isSkewJoin) { // We re-arrange the shuffle partitions to deal with skew join, and the new children diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index bedfbffc789ac..02f69a8d651fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1188,4 +1188,42 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan classOf[BroadcastNestedLoopJoinExec])) } } + + test("Full outer shuffled hash join") { + val inputDFs = Seq( + // Test unique join key + (spark.range(10).selectExpr("id as k1"), + spark.range(30).selectExpr("id as k2")), + // Test non-unique join key + (spark.range(10).selectExpr("id % 5 as k1"), + spark.range(30).selectExpr("id % 5 as k2")), + // Test string join key + (spark.range(10).selectExpr("cast(id * 3 as string) as k1"), + spark.range(30).selectExpr("cast(id as string) as k2")), + // Test build side at right + (spark.range(30).selectExpr("cast(id / 3 as string) as k1"), + spark.range(10).selectExpr("cast(id as string) as k2")), + // Test NULL join key + (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"), + spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2")) + ) + inputDFs.foreach { case (df1, df2) => + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + val smjDF = df1.join(df2, $"k1" === $"k2", "full") + assert(smjDF.queryExecution.executedPlan.collect { + case _: SortMergeJoinExec => true }.size === 1) + val smjResult = smjDF.collect() + + withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + val shjDF = df1.join(df2, $"k1" === $"k2", "full") + assert(shjDF.queryExecution.executedPlan.collect { + case _: ShuffledHashJoinExec => true }.size === 1) + // Same result between shuffled hash join and sort merge join + checkAnswer(shjDF, smjResult) + } + } + } + } } From 3b109f611717e67871b67ed4115fb9b55737e6f8 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 4 Aug 2020 14:02:47 -0700 Subject: [PATCH 02/11] Address all comments and add unit test --- .../sql/execution/joins/HashedRelation.scala | 39 ++++++++----------- .../joins/ShuffledHashJoinExec.scala | 7 +--- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../execution/joins/HashedRelationSuite.scala | 24 ++++++++++++ 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ee8926371b01a..2205f0c2eece6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -135,18 +135,15 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numKeys: Int, private var numFields: Int, - private var binaryMap: BytesToBytesMap, - private val isLookupAware: Boolean = false) + private var binaryMap: BytesToBytesMap) extends HashedRelation with Externalizable with KryoSerializable { - private[joins] def this() = this(0, 0, null, false) // Needed for serialization + private[joins] def this() = this(0, 0, null) // Needed for serialization - override def keyIsUnique: Boolean = { - binaryMap.numKeys() == binaryMap.numValues() - 
} + override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() override def asReadOnlyCopy(): UnsafeHashedRelation = { - new UnsafeHashedRelation(numKeys, numFields, binaryMap, isLookupAware) + new UnsafeHashedRelation(numKeys, numFields, binaryMap) } override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption @@ -317,23 +314,19 @@ private[joins] class UnsafeHashedRelation( } override def values(): Iterator[InternalRow] = { - if (isLookupAware) { - val iter = binaryMap.iterator() + val iter = binaryMap.iterator() - new Iterator[InternalRow] { - override def hasNext: Boolean = iter.hasNext + new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext - override def next(): InternalRow = { - if (!hasNext) { - throw new NoSuchElementException("End of the iterator") - } - val loc = iter.next() - resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - resultRow + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("End of the iterator") } + val loc = iter.next() + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + resultRow } - } else { - throw new UnsupportedOperationException } } } @@ -405,7 +398,7 @@ private[joins] object UnsafeHashedRelation { } } - new UnsafeHashedRelation(key.size, numFields, binaryMap, isLookupAware) + new UnsafeHashedRelation(key.size, numFields, binaryMap) } } @@ -945,7 +938,9 @@ class LongHashedRelation( override def keys(): Iterator[InternalRow] = map.keys() override def values(): Iterator[InternalRow] = { - throw new UnsupportedOperationException + keys().flatMap { key => + get(key.getLong(0)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 151708a57a491..53bf85b0011c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -134,11 +134,8 @@ case class ShuffledHashJoinExec( val buildNullRow = new GenericInternalRow(buildOutput.length) val streamNullRow = new GenericInternalRow(streamedOutput.length) - def markRowLookedUp(row: UnsafeRow): Unit = { - if (!row.getBoolean(row.numFields() - 1)) { - row.setBoolean(row.numFields() - 1, true) - } - } + def markRowLookedUp(row: UnsafeRow): Unit = + row.setBoolean(row.numFields() - 1, true) // Process stream side with looking up hash relation val streamResultIter = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 02f69a8d651fc..0c44205f26eb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1189,7 +1189,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } } - test("Full outer shuffled hash join") { + test("SPARK-32399: Full outer shuffled hash join") { val inputDFs = Seq( // Test unique join key (spark.range(10).selectExpr("id as k1"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b270bd5a2636..5969e7037a3c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -592,4 +592,28 @@ class HashedRelationSuite extends SharedSparkSession { assert(hashed.getValue(0L) == null) assert(hashed.getValue(key) == null) } + + test("SPARK-32399: test values() method for HashedRelation") { + val key = Seq(BoundReference(0, LongType, false)) + val value = Seq(BoundReference(0, IntegerType, true)) + val unsafeProj = UnsafeProjection.create(value) + val rows = (0 until 100).map(i => unsafeProj(InternalRow(i + 1)).copy()) + + // test LongHashedRelation + val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) + var values = longRelation.values() + assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1)) + + // test UnsafeHashedRelation + val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm) + values = unsafeRelation.values() + assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1)) + + // test lookup-aware UnsafeHashedRelation + val lookupAwareUnsafeRelation = UnsafeHashedRelation( + rows.iterator, key, 10, mm, isLookupAware = true, value = Some(value)) + values = lookupAwareUnsafeRelation.values() + assert(values.map(v => (v.getInt(0), v.getBoolean(1))).toArray.sortWith(_._1 < _._1) + === (0 until 100).map(i => (i + 1, false))) + } } From 663471b95e5c28d70bb183452de681fe74486891 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 7 Aug 2020 10:37:46 -0700 Subject: [PATCH 03/11] Address all comments --- .../sql/execution/joins/HashedRelation.scala | 83 ++++++++++--------- .../joins/ShuffledHashJoinExec.scala | 59 +++++-------- .../sql/execution/joins/ShuffledJoin.scala | 2 +- 3 files changed, 66 insertions(+), 78 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2205f0c2eece6..3bab5bcd12be3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -96,6 +96,9 @@ private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. 
+ * + * @param isLookupAware reserve one extra boolean in value to track if value being looked up + * @param value the expressions for value inserted into HashedRelation */ def apply( input: Iterator[InternalRow], @@ -118,6 +121,8 @@ private[execution] object HashedRelation { if (!input.hasNext) { EmptyHashedRelation } else if (key.length == 1 && key.head.dataType == LongType && !isLookupAware) { + // NOTE: LongHashedRelation cannot support isLookupAware as it cannot + // handle NULL key LongHashedRelation(input, key, sizeEstimate, mm, isNullAware) } else { UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, isLookupAware, value) @@ -148,7 +153,7 @@ private[joins] class UnsafeHashedRelation( override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption - // re-used in get()/getValue() + // re-used in get()/getValue()/values() var resultRow = new UnsafeRow(numFields) override def get(key: InternalRow): Iterator[InternalRow] = { @@ -186,6 +191,23 @@ private[joins] class UnsafeHashedRelation( } } + override def values(): Iterator[InternalRow] = { + val iter = binaryMap.iterator() + + new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("End of the iterator") + } + val loc = iter.next() + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + resultRow + } + } + } + override def keys(): Iterator[InternalRow] = { val iter = binaryMap.iterator() @@ -312,23 +334,6 @@ private[joins] class UnsafeHashedRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { read(() => in.readInt(), () => in.readLong(), in.readBytes) } - - override def values(): Iterator[InternalRow] = { - val iter = binaryMap.iterator() - - new Iterator[InternalRow] { - override def hasNext: Boolean = iter.hasNext - - override def next(): InternalRow = { - if (!hasNext) { - throw new NoSuchElementException("End of the iterator") - } - val loc = iter.next() - resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - resultRow - } - } - } } private[joins] object UnsafeHashedRelation { @@ -341,6 +346,10 @@ private[joins] object UnsafeHashedRelation { isNullAware: Boolean = false, isLookupAware: Boolean = false, value: Option[Seq[Expression]] = None): HashedRelation = { + if (isNullAware && isLookupAware) { + throw new SparkException( + "isLookupAware and isNullAware cannot be enabled at same time for UnsafeHashedRelation") + } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024)) @@ -354,27 +363,28 @@ private[joins] object UnsafeHashedRelation { val keyGenerator = UnsafeProjection.create(key) var numFields = 0 + val append = (key: UnsafeRow, value: UnsafeRow) => { + val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + val success = loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + if (!success) { + binaryMap.free() + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("There is not enough memory to build hash map") + // scalastyle:on throwerror + } + } + if (isLookupAware) { // Add one extra boolean value at the end as part of the row, // to track the information that whether the corresponding key // has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for example of usage. 
val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false)) - while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] numFields = row.numFields() + 1 - val key = keyGenerator(row) - val value = valueGenerator(row) - val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) - val success = loc.append( - key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) - if (!success) { - binaryMap.free() - // scalastyle:off throwerror - throw new SparkOutOfMemoryError("There is not enough memory to build hash map") - // scalastyle:on throwerror - } + append(keyGenerator(row), valueGenerator(row)) } } else { while (input.hasNext) { @@ -382,16 +392,7 @@ private[joins] object UnsafeHashedRelation { numFields = row.numFields() val key = keyGenerator(row) if (!key.anyNull) { - val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) - val success = loc.append( - key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) - if (!success) { - binaryMap.free() - // scalastyle:off throwerror - throw new SparkOutOfMemoryError("There is not enough memory to build hash map") - // scalastyle:on throwerror - } + append(key, row) } else if (isNullAware) { return EmptyHashedRelationWithAllNullKeys } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 53bf85b0011c3..6c6f5cd8cdbc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -66,12 +66,11 @@ case class ShuffledHashJoinExec( val start = System.nanoTime() val context = TaskContext.get() - val (isLookupAware, value) = - if (joinType == FullOuter) { - (true, Some(BindReferences.bindReferences(buildOutput, buildOutput))) - } else { - (false, None) - } + val (isLookupAware, value) = if (joinType == FullOuter) { + (true, Some(BindReferences.bindReferences(buildOutput, buildOutput))) + } else { + (false, None) + } val relation = HashedRelation( iter, buildBoundKeys, @@ -110,24 +109,12 @@ case class ShuffledHashJoinExec( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, numOutputRows: SQLMetric): Iterator[InternalRow] = { - abstract class HashJoinedRow extends JoinedRow { - /** Updates this JoinedRow by updating its stream side row. Returns itself. */ - def withStream(newStream: InternalRow): JoinedRow - - /** Updates this JoinedRow by updating its build side row. Returns itself. 
*/ - def withBuild(newBuild: InternalRow): JoinedRow - } - val joinRow: HashJoinedRow = buildSide match { - case BuildLeft => - new HashJoinedRow { - override def withStream(newStream: InternalRow): JoinedRow = withRight(newStream) - override def withBuild(newBuild: InternalRow): JoinedRow = withLeft(newBuild) - } - case BuildRight => - new HashJoinedRow { - override def withStream(newStream: InternalRow): JoinedRow = withLeft(newStream) - override def withBuild(newBuild: InternalRow): JoinedRow = withRight(newBuild) - } + val joinRow = new JoinedRow + val (joinRowWithStream, joinRowWithBuild) = { + buildSide match { + case BuildLeft => (joinRow.withRight _, joinRow.withLeft _) + case BuildRight => (joinRow.withLeft _, joinRow.withRight _) + } } val joinKeys = streamSideKeyGenerator() val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput) @@ -141,31 +128,31 @@ case class ShuffledHashJoinExec( val streamResultIter = if (hashedRelation.keyIsUnique) { streamIter.map { srow => - joinRow.withStream(srow) + joinRowWithStream(srow) val keys = joinKeys(srow) if (keys.anyNull) { - joinRow.withBuild(buildNullRow) + joinRowWithBuild(buildNullRow) } else { val matched = hashedRelation.getValue(keys) if (matched != null) { val buildRow = buildRowGenerator(matched) - if (boundCondition(joinRow.withBuild(buildRow))) { + if (boundCondition(joinRowWithBuild(buildRow))) { markRowLookedUp(matched.asInstanceOf[UnsafeRow]) joinRow } else { - joinRow.withBuild(buildNullRow) + joinRowWithBuild(buildNullRow) } } else { - joinRow.withBuild(buildNullRow) + joinRowWithBuild(buildNullRow) } } } } else { streamIter.flatMap { srow => - joinRow.withStream(srow) + joinRowWithStream(srow) val keys = joinKeys(srow) if (keys.anyNull) { - Iterator.single(joinRow.withBuild(buildNullRow)) + Iterator.single(joinRowWithBuild(buildNullRow)) } else { val buildIter = hashedRelation.get(keys) new RowIterator { @@ -174,14 +161,14 @@ case class ShuffledHashJoinExec( while (buildIter != null && buildIter.hasNext) { val matched = buildIter.next() val buildRow = buildRowGenerator(matched) - if (boundCondition(joinRow.withBuild(buildRow))) { + if (boundCondition(joinRowWithBuild(buildRow))) { markRowLookedUp(matched.asInstanceOf[UnsafeRow]) found = true return true } } if (!found) { - joinRow.withBuild(buildNullRow) + joinRowWithBuild(buildNullRow) found = true return true } @@ -199,8 +186,8 @@ case class ShuffledHashJoinExec( val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1) if (!isLookup) { val buildRow = buildRowGenerator(unsafebrow) - joinRow.withBuild(buildRow) - joinRow.withStream(streamNullRow) + joinRowWithBuild(buildRow) + joinRowWithStream(streamNullRow) Some(joinRow) } else { None @@ -214,7 +201,7 @@ case class ShuffledHashJoinExec( } } - // TODO: support full outer shuffled hash join code-gen + // TODO(SPARK-32567): support full outer shuffled hash join code-gen override def supportCodegen: Boolean = { joinType != FullOuter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 0318baf27f9fd..92bfc1ed4861d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -58,7 +58,7 @@ trait ShuffledJoin extends BaseJoinExec { left.output case x => throw new IllegalArgumentException( - s"ShuffledJoin not take $x as the JoinType") + 
s"${getClass.getSimpleName} not take $x as the JoinType") } } } From e3322766d4ea6d039f819a46e12dc8641ca59c63 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Sun, 9 Aug 2020 19:31:09 -0700 Subject: [PATCH 04/11] Address all new comments --- .../sql/execution/joins/HashedRelation.scala | 30 +++++++++---------- .../joins/ShuffledHashJoinExec.scala | 27 ++++++++++++----- .../org/apache/spark/sql/JoinSuite.scala | 27 ++++++++++++----- .../execution/joins/HashedRelationSuite.scala | 15 +++++----- 4 files changed, 61 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3bab5bcd12be3..20ad602ee3c51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -97,8 +97,9 @@ private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. * - * @param isLookupAware reserve one extra boolean in value to track if value being looked up - * @param value the expressions for value inserted into HashedRelation + * @param canMarkRowLookedUp Reserve one extra boolean in value to track if value being looked up. + * This is only used for full outer shuffled hash join. + * @param valueExprs The expressions for value inserted into HashedRelation. */ def apply( input: Iterator[InternalRow], @@ -106,8 +107,8 @@ private[execution] object HashedRelation { sizeEstimate: Int = 64, taskMemoryManager: TaskMemoryManager = null, isNullAware: Boolean = false, - isLookupAware: Boolean = false, - value: Option[Seq[Expression]] = None): HashedRelation = { + canMarkRowLookedUp: Boolean = false, + valueExprs: Option[Seq[Expression]] = None): HashedRelation = { val mm = Option(taskMemoryManager).getOrElse { new TaskMemoryManager( new UnifiedMemoryManager( @@ -120,12 +121,13 @@ private[execution] object HashedRelation { if (!input.hasNext) { EmptyHashedRelation - } else if (key.length == 1 && key.head.dataType == LongType && !isLookupAware) { - // NOTE: LongHashedRelation cannot support isLookupAware as it cannot + } else if (key.length == 1 && key.head.dataType == LongType && !canMarkRowLookedUp) { + // NOTE: LongHashedRelation cannot support canMarkRowLookedUp as it cannot // handle NULL key LongHashedRelation(input, key, sizeEstimate, mm, isNullAware) } else { - UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, isLookupAware, value) + UnsafeHashedRelation( + input, key, sizeEstimate, mm, isNullAware, canMarkRowLookedUp, valueExprs) } } } @@ -344,12 +346,10 @@ private[joins] object UnsafeHashedRelation { sizeEstimate: Int, taskMemoryManager: TaskMemoryManager, isNullAware: Boolean = false, - isLookupAware: Boolean = false, - value: Option[Seq[Expression]] = None): HashedRelation = { - if (isNullAware && isLookupAware) { - throw new SparkException( - "isLookupAware and isNullAware cannot be enabled at same time for UnsafeHashedRelation") - } + canMarkRowLookedUp: Boolean = false, + valueExprs: Option[Seq[Expression]] = None): HashedRelation = { + require(!(isNullAware && canMarkRowLookedUp), + "isNullAware and canMarkRowLookedUp cannot be enabled at same time") val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024)) @@ -376,11 +376,11 @@ private[joins] object UnsafeHashedRelation { } } - if 
(isLookupAware) { + if (canMarkRowLookedUp) { // Add one extra boolean value at the end as part of the row, // to track the information that whether the corresponding key // has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for example of usage. - val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false)) + val valueGenerator = UnsafeProjection.create(valueExprs.get :+ Literal(false)) while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] numFields = row.numFields() + 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 6c6f5cd8cdbc0..7af6f81b5eded 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -66,7 +66,7 @@ case class ShuffledHashJoinExec( val start = System.nanoTime() val context = TaskContext.get() - val (isLookupAware, value) = if (joinType == FullOuter) { + val (canMarkRowLookedUp, valueExprs) = if (joinType == FullOuter) { (true, Some(BindReferences.bindReferences(buildOutput, buildOutput))) } else { (false, None) @@ -75,8 +75,8 @@ case class ShuffledHashJoinExec( iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager(), - isLookupAware = isLookupAware, - value = value) + canMarkRowLookedUp = canMarkRowLookedUp, + valueExprs = valueExprs) buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. @@ -103,7 +103,7 @@ case class ShuffledHashJoinExec( * 2. Process rows from stream side by looking up hash relation, * and mark the matched rows from build side be looked up. * 3. Process rows from build side by iterating hash relation, - * and filter out rows from build side being looked up already. + * and filter out rows from build side being matched already. 
*/ private def fullOuterJoin( streamIter: Iterator[InternalRow], @@ -180,15 +180,26 @@ case class ShuffledHashJoinExec( } } - // Process build side with filtering out rows looked up already + // Process build side with filtering out rows looked up and + // passed join condition already + val streamNullJoinRow = new JoinedRow + val streamNullJoinRowWithBuild = { + buildSide match { + case BuildLeft => + streamNullJoinRow.withRight(streamNullRow) + streamNullJoinRow.withLeft _ + case BuildRight => + streamNullJoinRow.withLeft(streamNullRow) + streamNullJoinRow.withRight _ + } + } val buildResultIter = hashedRelation.values().flatMap { brow => val unsafebrow = brow.asInstanceOf[UnsafeRow] val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1) if (!isLookup) { val buildRow = buildRowGenerator(unsafebrow) - joinRowWithBuild(buildRow) - joinRowWithStream(streamNullRow) - Some(joinRow) + streamNullJoinRowWithBuild(buildRow) + Some(streamNullJoinRow) } else { None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0c44205f26eb0..ce239864f1f87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1193,31 +1193,42 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val inputDFs = Seq( // Test unique join key (spark.range(10).selectExpr("id as k1"), - spark.range(30).selectExpr("id as k2")), + spark.range(30).selectExpr("id as k2"), + $"k1" === $"k2"), // Test non-unique join key (spark.range(10).selectExpr("id % 5 as k1"), - spark.range(30).selectExpr("id % 5 as k2")), + spark.range(30).selectExpr("id % 5 as k2"), + $"k1" === $"k2"), // Test string join key (spark.range(10).selectExpr("cast(id * 3 as string) as k1"), - spark.range(30).selectExpr("cast(id as string) as k2")), + spark.range(30).selectExpr("cast(id as string) as k2"), + $"k1" === $"k2"), // Test build side at right (spark.range(30).selectExpr("cast(id / 3 as string) as k1"), - spark.range(10).selectExpr("cast(id as string) as k2")), + spark.range(10).selectExpr("cast(id as string) as k2"), + $"k1" === $"k2"), // Test NULL join key (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"), - spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2")) + spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"), + $"k1" === $"k2"), + // Test multiple join keys + (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr( + "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"), + spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr( + "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"), + $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6") ) - inputDFs.foreach { case (df1, df2) => + inputDFs.foreach { case (df1, df2, joinExprs) => withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", SQLConf.SHUFFLE_PARTITIONS.key -> "2") { - val smjDF = df1.join(df2, $"k1" === $"k2", "full") + val smjDF = df1.join(df2, joinExprs, "full") assert(smjDF.queryExecution.executedPlan.collect { case _: SortMergeJoinExec => true }.size === 1) val smjResult = smjDF.collect() withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { - val shjDF = df1.join(df2, $"k1" === $"k2", "full") + val shjDF = df1.join(df2, joinExprs, "full") assert(shjDF.queryExecution.executedPlan.collect { case _: 
ShuffledHashJoinExec => true }.size === 1) // Same result between shuffled hash join and sort merge join diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 5969e7037a3c8..d7b330368f694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -598,22 +598,23 @@ class HashedRelationSuite extends SharedSparkSession { val value = Seq(BoundReference(0, IntegerType, true)) val unsafeProj = UnsafeProjection.create(value) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i + 1)).copy()) + val expectedValues = (0 until 100).map(i => i + 1) // test LongHashedRelation val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) var values = longRelation.values() - assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1)) + assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === expectedValues) // test UnsafeHashedRelation val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm) values = unsafeRelation.values() - assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1)) + assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === expectedValues) - // test lookup-aware UnsafeHashedRelation - val lookupAwareUnsafeRelation = UnsafeHashedRelation( - rows.iterator, key, 10, mm, isLookupAware = true, value = Some(value)) - values = lookupAwareUnsafeRelation.values() + // test UnsafeHashedRelation which can mark row looked up + val markRowUnsafeRelation = UnsafeHashedRelation( + rows.iterator, key, 10, mm, canMarkRowLookedUp = true, valueExprs = Some(value)) + values = markRowUnsafeRelation.values() assert(values.map(v => (v.getInt(0), v.getBoolean(1))).toArray.sortWith(_._1 < _._1) - === (0 until 100).map(i => (i + 1, false))) + === expectedValues.map(i => (i, false))) } } From bd7261e7e6c9380ec0708b4c3d92c591f2ba0ebf Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 11 Aug 2020 23:28:05 -0700 Subject: [PATCH 05/11] Change approach to keep row matching info separately --- .../spark/unsafe/map/BytesToBytesMap.java | 63 +++++ .../sql/execution/joins/HashedRelation.scala | 158 ++++++----- .../joins/ShuffledHashJoinExec.scala | 261 ++++++++++++------ .../execution/joins/HashedRelationSuite.scala | 25 -- 4 files changed, 325 insertions(+), 182 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6e028886f2318..8ac0dc005440e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -428,6 +428,61 @@ public MapIterator destructiveIterator() { return new MapIterator(numValues, new Location(), true); } + /** + * Iterator for the entries of this map. This is to first iterate over key index array + * `longArray` then accessing values in `dataPages`. NOTE: this is different from `MapIterator` + * in the sense that key index is preserved here (See `UnsafeHashedRelation` for example of usage). 
+ */ + public final class MapIteratorWithKeyIndex implements Iterator { + + private int keyIndex = 0; + private int numRecords; + private final Location loc; + + private MapIteratorWithKeyIndex(int numRecords, Location loc) { + this.numRecords = numRecords; + this.loc = loc; + } + + @Override + public boolean hasNext() { + return numRecords > 0; + } + + @Override + public Location next() { + if (!loc.isDefined() || !loc.nextValue()) { + while (longArray.get(keyIndex * 2) == 0) { + keyIndex++; + } + loc.with(keyIndex, (int) longArray.get(keyIndex * 2 + 1), true); + keyIndex++; + } + numRecords--; + return loc; + } + } + + /** + * Returns an iterator for iterating over the entries of this map, + * by first iterating over the key index inside hash map's `longArray`. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * The returned iterator is NOT thread-safe. If the map is modified while iterating over it, + * the behavior of the returned iterator is undefined. + */ + public MapIteratorWithKeyIndex iteratorWithKeyIndex() { + return new MapIteratorWithKeyIndex(numValues, new Location()); + } + + /** + * Number of allowed keys index. + */ + public int numKeysIndex() { + return (int) (longArray.size() / 2); + } + /** * Looks up a key, and return a {@link Location} handle that can be used to test existence * and read/write values. @@ -601,6 +656,14 @@ public boolean isDefined() { return isDefined; } + /** + * Returns index for key. + */ + public int getKeyIndex() { + assert (isDefined); + return pos; + } + /** * Returns the base object for key. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 20ad602ee3c51..226d92fa87769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -66,6 +66,23 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { throw new UnsupportedOperationException } + /** + * Returns key index and matched rows. + * + * Returns null if there is no matched rows. + */ + def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) + + /** + * Returns an iterator for keys index and rows of InternalRow type. + */ + def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] + + /** + * Returns number of keys index. + */ + def numKeysIndex: Int + /** * Returns true iff all the keys are unique. */ @@ -76,11 +93,6 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { */ def keys(): Iterator[InternalRow] - /** - * Returns an iterator for values of InternalRow type. - */ - def values(): Iterator[InternalRow] - /** * Returns a read-only copy of this, to be safely used in current thread. */ @@ -97,9 +109,8 @@ private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. * - * @param canMarkRowLookedUp Reserve one extra boolean in value to track if value being looked up. - * This is only used for full outer shuffled hash join. - * @param valueExprs The expressions for value inserted into HashedRelation. + * @param allowsNullKey Allow NULL keys in HashedRelation. + * This is used for full outer join in `ShuffledHashJoinExec` only. 
*/ def apply( input: Iterator[InternalRow], @@ -107,8 +118,7 @@ private[execution] object HashedRelation { sizeEstimate: Int = 64, taskMemoryManager: TaskMemoryManager = null, isNullAware: Boolean = false, - canMarkRowLookedUp: Boolean = false, - valueExprs: Option[Seq[Expression]] = None): HashedRelation = { + allowsNullKey: Boolean = false): HashedRelation = { val mm = Option(taskMemoryManager).getOrElse { new TaskMemoryManager( new UnifiedMemoryManager( @@ -121,13 +131,11 @@ private[execution] object HashedRelation { if (!input.hasNext) { EmptyHashedRelation - } else if (key.length == 1 && key.head.dataType == LongType && !canMarkRowLookedUp) { - // NOTE: LongHashedRelation cannot support canMarkRowLookedUp as it cannot - // handle NULL key + } else if (key.length == 1 && key.head.dataType == LongType && !allowsNullKey) { + // NOTE: LongHashedRelation does not support NULL keys. LongHashedRelation(input, key, sizeEstimate, mm, isNullAware) } else { - UnsafeHashedRelation( - input, key, sizeEstimate, mm, isNullAware, canMarkRowLookedUp, valueExprs) + UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, allowsNullKey) } } } @@ -155,7 +163,7 @@ private[joins] class UnsafeHashedRelation( override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption - // re-used in get()/getValue()/values() + // re-used in get()/getValue()/getWithKeyIndex()/valuesWithKeyIndex() var resultRow = new UnsafeRow(numFields) override def get(key: InternalRow): Iterator[InternalRow] = { @@ -193,23 +201,49 @@ private[joins] class UnsafeHashedRelation( } } - override def values(): Iterator[InternalRow] = { - val iter = binaryMap.iterator() + override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + (loc.getKeyIndex, + new Iterator[UnsafeRow] { + private var _hasNext = true + override def hasNext: Boolean = _hasNext + override def next(): UnsafeRow = { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + _hasNext = loc.nextValue() + resultRow + } + }) + } else { + null + } + } - new Iterator[InternalRow] { + override def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] = { + val iter = binaryMap.iteratorWithKeyIndex() + + new Iterator[(Int, InternalRow)] { override def hasNext: Boolean = iter.hasNext - override def next(): InternalRow = { + override def next(): (Int, InternalRow) = { if (!hasNext) { throw new NoSuchElementException("End of the iterator") } val loc = iter.next() resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - resultRow + (loc.getKeyIndex, resultRow) } } } + override def numKeysIndex: Int = { + binaryMap.numKeysIndex + } + override def keys(): Iterator[InternalRow] = { val iter = binaryMap.iterator() @@ -346,10 +380,7 @@ private[joins] object UnsafeHashedRelation { sizeEstimate: Int, taskMemoryManager: TaskMemoryManager, isNullAware: Boolean = false, - canMarkRowLookedUp: Boolean = false, - valueExprs: Option[Seq[Expression]] = None): HashedRelation = { - require(!(isNullAware && canMarkRowLookedUp), - "isNullAware and canMarkRowLookedUp cannot be enabled at same time") + allowsNullKey: Boolean = false): HashedRelation = { val pageSizeBytes = 
Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024)) @@ -362,40 +393,23 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows val keyGenerator = UnsafeProjection.create(key) var numFields = 0 - - val append = (key: UnsafeRow, value: UnsafeRow) => { - val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) - val success = loc.append( - key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) - if (!success) { - binaryMap.free() - // scalastyle:off throwerror - throw new SparkOutOfMemoryError("There is not enough memory to build hash map") - // scalastyle:on throwerror - } - } - - if (canMarkRowLookedUp) { - // Add one extra boolean value at the end as part of the row, - // to track the information that whether the corresponding key - // has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for example of usage. - val valueGenerator = UnsafeProjection.create(valueExprs.get :+ Literal(false)) - while (input.hasNext) { - val row = input.next().asInstanceOf[UnsafeRow] - numFields = row.numFields() + 1 - append(keyGenerator(row), valueGenerator(row)) - } - } else { - while (input.hasNext) { - val row = input.next().asInstanceOf[UnsafeRow] - numFields = row.numFields() - val key = keyGenerator(row) - if (!key.anyNull) { - append(key, row) - } else if (isNullAware) { - return EmptyHashedRelationWithAllNullKeys + while (input.hasNext) { + val row = input.next().asInstanceOf[UnsafeRow] + numFields = row.numFields() + val key = keyGenerator(row) + if (!key.anyNull || allowsNullKey) { + val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + val success = loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) + if (!success) { + binaryMap.free() + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("There is not enough memory to build hash map") + // scalastyle:on throwerror } + } else if (isNullAware) { + return EmptyHashedRelationWithAllNullKeys } } @@ -938,10 +952,16 @@ class LongHashedRelation( */ override def keys(): Iterator[InternalRow] = map.keys() - override def values(): Iterator[InternalRow] = { - keys().flatMap { key => - get(key.getLong(0)) - } + override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) = { + throw new UnsupportedOperationException + } + + override def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] = { + throw new UnsupportedOperationException + } + + override def numKeysIndex: Int = { + throw new UnsupportedOperationException } } @@ -991,13 +1011,21 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable { throw new UnsupportedOperationException } - override def keyIsUnique: Boolean = true + override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) = { + throw new UnsupportedOperationException + } - override def keys(): Iterator[InternalRow] = { + override def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] = { + throw new UnsupportedOperationException + } + + override def numKeysIndex: Int = { throw new UnsupportedOperationException } - override def values(): Iterator[InternalRow] = { + override def keyIsUnique: Boolean = true + + override def keys(): Iterator[InternalRow] = { throw new UnsupportedOperationException } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 7af6f81b5eded..31db638b486fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins import java.util.concurrent.TimeUnit._ +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -29,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.util.collection.BitSet /** * Performs a hash join of two child relations by first shuffling the data using the join keys. @@ -65,18 +68,12 @@ case class ShuffledHashJoinExec( val buildTime = longMetric("buildTime") val start = System.nanoTime() val context = TaskContext.get() - - val (canMarkRowLookedUp, valueExprs) = if (joinType == FullOuter) { - (true, Some(BindReferences.bindReferences(buildOutput, buildOutput))) - } else { - (false, None) - } val relation = HashedRelation( iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager(), - canMarkRowLookedUp = canMarkRowLookedUp, - valueExprs = valueExprs) + // Full outer join needs support for NULL key in HashedRelation. + allowsNullKey = joinType == FullOuter) buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. @@ -95,20 +92,11 @@ case class ShuffledHashJoinExec( } } - /** - * Full outer shuffled hash join has three steps: - * 1. Construct hash relation from build side, - * with extra boolean value at the end of row to track look up information - * (done in `buildHashedRelation`). - * 2. Process rows from stream side by looking up hash relation, - * and mark the matched rows from build side be looked up. - * 3. Process rows from build side by iterating hash relation, - * and filter out rows from build side being matched already. 
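A build-side row whose join key is NULL can never match any stream-side key, yet a full outer join still has to emit it once, padded with NULLs on the stream side; that is the reason the hashed relation for this code path must accept NULL keys at all. A tiny collection-based model of that semantic (a hypothetical helper, not Spark code; None stands in for a NULL join key):

    import scala.collection.mutable

    // Full outer join over plain Scala collections; NULL join keys never match,
    // but the rows carrying them are still emitted once.
    def toyFullOuterJoin[K, L, R](
        left: Seq[(Option[K], L)],
        right: Seq[(Option[K], R)]): Seq[(Option[L], Option[R])] = {
      val matchedRight = mutable.Set.empty[Int]
      val fromLeft: Seq[(Option[L], Option[R])] = left.flatMap { case (lk, lv) =>
        val matches: Seq[(Option[L], Option[R])] = right.zipWithIndex.collect {
          case ((rk, rv), i) if lk.isDefined && lk == rk =>
            matchedRight += i
            (Some(lv), Some(rv))
        }
        if (matches.nonEmpty) matches else Seq((Some(lv), None))
      }
      val fromRight: Seq[(Option[L], Option[R])] = right.zipWithIndex.collect {
        case ((_, rv), i) if !matchedRight.contains(i) => (None, Some(rv))
      }
      fromLeft ++ fromRight
    }

For example, toyFullOuterJoin(Seq((Some(1), "a"), (None, "b")), Seq((Some(1), "x"), (None, "y"))) keeps both NULL-keyed rows, each padded with None on the opposite side.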
- */ private def fullOuterJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, numOutputRows: SQLMetric): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() val joinRow = new JoinedRow val (joinRowWithStream, joinRowWithBuild) = { buildSide match { @@ -116,100 +104,189 @@ case class ShuffledHashJoinExec( case BuildRight => (joinRow.withLeft _, joinRow.withRight _) } } - val joinKeys = streamSideKeyGenerator() - val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput) val buildNullRow = new GenericInternalRow(buildOutput.length) val streamNullRow = new GenericInternalRow(streamedOutput.length) + val streamNullJoinRow = new JoinedRow + val streamNullJoinRowWithBuild = { + buildSide match { + case BuildLeft => + streamNullJoinRow.withRight(streamNullRow) + streamNullJoinRow.withLeft _ + case BuildRight => + streamNullJoinRow.withLeft(streamNullRow) + streamNullJoinRow.withRight _ + } + } - def markRowLookedUp(row: UnsafeRow): Unit = - row.setBoolean(row.numFields() - 1, true) + val iter = if (hashedRelation.keyIsUnique) { + fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, joinRow, streamNullJoinRow, + joinRowWithStream, joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, + streamNullRow) + } else { + fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, joinRow, + streamNullJoinRow, joinRowWithStream, joinRowWithBuild, streamNullJoinRowWithBuild, + buildNullRow, streamNullRow) + } + + val resultProj = UnsafeProjection.create(output, output) + iter.map { r => + numOutputRows += 1 + resultProj(r) + } + } + + /** + * Full outer shuffled hash join with unique join keys: + * 1. Process rows from stream side by looking up hash relation. + * Mark the matched rows from build side be looked up. + * A `BitSet` is used to track matched rows with key index. + * 2. Process rows from build side by iterating hash relation. + * Filter out rows from build side being matched already, + * by checking key index from `BitSet`. 
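A stripped-down model of this unique-key scheme, using ordinary Scala collections instead of HashedRelation and ignoring NULL join keys and the extra join condition (all names below are illustrative):

    import scala.collection.mutable

    // Phase 1: stream rows probe the build side and set the bit of every key index they match.
    // Phase 2: build rows whose bit is still clear are emitted with a NULL stream side.
    def toyFullOuterUniqueKey[K](
        stream: Seq[(K, String)],
        build: Seq[(K, String)]): Seq[(Option[String], Option[String])] = {
      val keyIndex: Map[K, Int] = build.map(_._1).zipWithIndex.toMap   // build keys are unique
      val buildByKey: Map[K, String] = build.toMap
      val matched = new mutable.BitSet(keyIndex.size)

      val streamSide: Seq[(Option[String], Option[String])] = stream.map { case (k, s) =>
        buildByKey.get(k) match {
          case Some(b) =>
            matched += keyIndex(k)
            (Some(s), Some(b))
          case None => (Some(s), None)
        }
      }
      val buildSide: Seq[(Option[String], Option[String])] = build.collect {
        case (k, b) if !matched.contains(keyIndex(k)) => (None, Some(b))
      }
      streamSide ++ buildSide
    }

The operator below does the same bookkeeping with Spark's BitSet sized by the relation's maximum number of key indices, and a key is only marked after the bound join condition has passed.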
+ */ + private def fullOuterJoinWithUniqueKey( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + joinKeys: UnsafeProjection, + joinRow: JoinedRow, + streamNullJoinRow: JoinedRow, + joinRowWithStream: InternalRow => JoinedRow, + joinRowWithBuild: InternalRow => JoinedRow, + streamNullJoinRowWithBuild: InternalRow => JoinedRow, + buildNullRow: GenericInternalRow, + streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + val matchedKeys = new BitSet(hashedRelation.numKeysIndex) // Process stream side with looking up hash relation - val streamResultIter = - if (hashedRelation.keyIsUnique) { - streamIter.map { srow => - joinRowWithStream(srow) - val keys = joinKeys(srow) - if (keys.anyNull) { - joinRowWithBuild(buildNullRow) + val streamResultIter = streamIter.map { srow => + joinRowWithStream(srow) + val keys = joinKeys(srow) + if (keys.anyNull) { + joinRowWithBuild(buildNullRow) + } else { + val matched = hashedRelation.getWithKeyIndex(keys) + if (matched != null) { + val (keyIndex, buildIter) = (matched._1, matched._2) + val buildRow = buildIter.next + if (boundCondition(joinRowWithBuild(buildRow))) { + matchedKeys.set(keyIndex) + joinRow } else { - val matched = hashedRelation.getValue(keys) - if (matched != null) { - val buildRow = buildRowGenerator(matched) - if (boundCondition(joinRowWithBuild(buildRow))) { - markRowLookedUp(matched.asInstanceOf[UnsafeRow]) - joinRow - } else { - joinRowWithBuild(buildNullRow) - } - } else { - joinRowWithBuild(buildNullRow) - } + joinRowWithBuild(buildNullRow) } + } else { + joinRowWithBuild(buildNullRow) } + } + } + + // Process build side with filtering out rows looked up and + // passed join condition already + val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { + case (keyIndex, brow) => + val isMatched = matchedKeys.get(keyIndex) + if (!isMatched) { + streamNullJoinRowWithBuild(brow) + Some(streamNullJoinRow) + } else { + None + } + } + + streamResultIter ++ buildResultIter + } + + /** + * Full outer shuffled hash join with unique join keys: + * 1. Process rows from stream side by looking up hash relation. + * Mark the matched rows from build side be looked up. + * A `HashSet[Long]` is used to track matched rows with + * key index (Int) and value index (Int) together. + * 2. Process rows from build side by iterating hash relation. + * Filter out rows from build side being matched already, + * by checking key index and value index from `HashSet`. 
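For non-unique keys one bit per key is not enough, since different rows stored under the same key can match different stream rows, so the operator remembers (key index, value index) pairs instead. Two non-negative 32-bit indices pack losslessly into one Long, which is why a HashSet[Long] suffices; the sketch below (hypothetical pack/unpack helpers) only makes that bit layout explicit:

    import scala.collection.mutable

    // Key index in the high 32 bits, value index in the low 32 bits. Both indices are
    // non-negative, so every (keyIndex, valueIndex) pair maps to a distinct Long.
    def pack(keyIndex: Int, valueIndex: Int): Long = (keyIndex.toLong << 32) | valueIndex
    def unpack(packed: Long): (Int, Int) = ((packed >>> 32).toInt, packed.toInt)

    val matchedRows = mutable.HashSet.empty[Long]
    matchedRows.add(pack(3, 1))                  // the second row under key index 3 matched
    assert(matchedRows.contains(pack(3, 1)))
    assert(!matchedRows.contains(pack(3, 0)))    // other rows of the same key stay unmatched
    assert(unpack(pack(3, 1)) == (3, 1))

The join itself never needs unpack; it is shown here only to demonstrate that the encoding is reversible.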
+ */ + private def fullOuterJoinWithNonUniqueKey( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + joinKeys: UnsafeProjection, + joinRow: JoinedRow, + streamNullJoinRow: JoinedRow, + joinRowWithStream: InternalRow => JoinedRow, + joinRowWithBuild: InternalRow => JoinedRow, + streamNullJoinRowWithBuild: InternalRow => JoinedRow, + buildNullRow: GenericInternalRow, + streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + val matchedRows = new mutable.HashSet[Long] + + def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = { + val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex + matchedRows.add(rowIndex) + } + + def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = { + val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex + matchedRows.contains(rowIndex) + } + + // Process stream side with looking up hash relation + val streamResultIter = streamIter.flatMap { srow => + joinRowWithStream(srow) + val keys = joinKeys(srow) + if (keys.anyNull) { + Iterator.single(joinRowWithBuild(buildNullRow)) } else { - streamIter.flatMap { srow => - joinRowWithStream(srow) - val keys = joinKeys(srow) - if (keys.anyNull) { - Iterator.single(joinRowWithBuild(buildNullRow)) - } else { - val buildIter = hashedRelation.get(keys) - new RowIterator { - private var found = false - override def advanceNext(): Boolean = { - while (buildIter != null && buildIter.hasNext) { - val matched = buildIter.next() - val buildRow = buildRowGenerator(matched) - if (boundCondition(joinRowWithBuild(buildRow))) { - markRowLookedUp(matched.asInstanceOf[UnsafeRow]) - found = true - return true - } - } - if (!found) { - joinRowWithBuild(buildNullRow) + val matched = hashedRelation.getWithKeyIndex(keys) + if (matched != null) { + val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex) + + new RowIterator { + private var found = false + override def advanceNext(): Boolean = { + while (buildIter.hasNext) { + val (buildRow, valueIndex) = buildIter.next() + if (boundCondition(joinRowWithBuild(buildRow))) { + markRowMatched(keyIndex, valueIndex) found = true return true } - false } - override def getRow: InternalRow = joinRow - }.toScala - } + if (!found) { + joinRowWithBuild(buildNullRow) + found = true + return true + } + false + } + override def getRow: InternalRow = joinRow + }.toScala + } else { + Iterator.single(joinRowWithBuild(buildNullRow)) } } + } // Process build side with filtering out rows looked up and // passed join condition already - val streamNullJoinRow = new JoinedRow - val streamNullJoinRowWithBuild = { - buildSide match { - case BuildLeft => - streamNullJoinRow.withRight(streamNullRow) - streamNullJoinRow.withLeft _ - case BuildRight => - streamNullJoinRow.withLeft(streamNullRow) - streamNullJoinRow.withRight _ - } - } - val buildResultIter = hashedRelation.values().flatMap { brow => - val unsafebrow = brow.asInstanceOf[UnsafeRow] - val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1) - if (!isLookup) { - val buildRow = buildRowGenerator(unsafebrow) - streamNullJoinRowWithBuild(buildRow) - Some(streamNullJoinRow) - } else { - None - } + var prevKeyIndex = -1 + var valueIndex = -1 + val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { + case (keyIndex, brow) => + if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) { + prevKeyIndex = keyIndex + valueIndex = -1 + } + valueIndex += 1 + val isMatched = isRowMatched(keyIndex, valueIndex) + if (!isMatched) { + streamNullJoinRowWithBuild(brow) + Some(streamNullJoinRow) + } else { + None + } 
} - val resultProj = UnsafeProjection.create(output, output) - (streamResultIter ++ buildResultIter).map { r => - numOutputRows += 1 - resultProj(r) - } + streamResultIter ++ buildResultIter } // TODO(SPARK-32567): support full outer shuffled hash join code-gen diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index d7b330368f694..8b270bd5a2636 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -592,29 +592,4 @@ class HashedRelationSuite extends SharedSparkSession { assert(hashed.getValue(0L) == null) assert(hashed.getValue(key) == null) } - - test("SPARK-32399: test values() method for HashedRelation") { - val key = Seq(BoundReference(0, LongType, false)) - val value = Seq(BoundReference(0, IntegerType, true)) - val unsafeProj = UnsafeProjection.create(value) - val rows = (0 until 100).map(i => unsafeProj(InternalRow(i + 1)).copy()) - val expectedValues = (0 until 100).map(i => i + 1) - - // test LongHashedRelation - val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) - var values = longRelation.values() - assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === expectedValues) - - // test UnsafeHashedRelation - val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm) - values = unsafeRelation.values() - assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === expectedValues) - - // test UnsafeHashedRelation which can mark row looked up - val markRowUnsafeRelation = UnsafeHashedRelation( - rows.iterator, key, 10, mm, canMarkRowLookedUp = true, valueExprs = Some(value)) - values = markRowUnsafeRelation.values() - assert(values.map(v => (v.getInt(0), v.getBoolean(1))).toArray.sortWith(_._1 < _._1) - === expectedValues.map(i => (i, false))) - } } From f4f0b0f201469a8bb1d2bd29cf2cdaafba351787 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 12 Aug 2020 21:11:16 -0700 Subject: [PATCH 06/11] Address all comments and add unit test for HashedRelation and BytesToBytesMap --- .../spark/unsafe/map/BytesToBytesMap.java | 7 +- .../map/AbstractBytesToBytesMapSuite.java | 27 +++++- .../sql/execution/joins/HashedRelation.scala | 84 +++++++++++++++---- .../joins/ShuffledHashJoinExec.scala | 41 +++++---- .../org/apache/spark/sql/JoinSuite.scala | 4 +- .../execution/joins/HashedRelationSuite.scala | 78 +++++++++++++++++ 6 files changed, 198 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 8ac0dc005440e..07ec2db56a370 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -431,7 +431,8 @@ public MapIterator destructiveIterator() { /** * Iterator for the entries of this map. This is to first iterate over key index array * `longArray` then accessing values in `dataPages`. NOTE: this is different from `MapIterator` - * in the sense that key index is preserved here (See `UnsafeHashedRelation` for example of usage). + * in the sense that key index is preserved here + * (See `UnsafeHashedRelation` for example of usage). 
*/ public final class MapIteratorWithKeyIndex implements Iterator { @@ -477,9 +478,9 @@ public MapIteratorWithKeyIndex iteratorWithKeyIndex() { } /** - * Number of allowed keys index. + * The maximum number of allowed keys index. */ - public int numKeysIndex() { + public int maxNumKeysIndex() { return (int) (longArray.size() / 2); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 6e995a3929a75..f4e952f465e54 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -172,6 +172,7 @@ public void emptyMap() { final byte[] key = getRandomByteArray(keyLengthInWords); Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); + Assert.assertFalse(map.iteratorWithKeyIndex().hasNext()); } finally { map.free(); } @@ -233,9 +234,10 @@ public void setAndRetrieveAKey() { } } - private void iteratorTestBase(boolean destructive) throws Exception { + private void iteratorTestBase(boolean destructive, boolean isWithKeyIndex) throws Exception { final int size = 4096; BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES); + Assert.assertEquals(size / 2, map.maxNumKeysIndex()); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -267,6 +269,8 @@ private void iteratorTestBase(boolean destructive) throws Exception { final Iterator iter; if (destructive) { iter = map.destructiveIterator(); + } else if (isWithKeyIndex) { + iter = map.iteratorWithKeyIndex(); } else { iter = map.iterator(); } @@ -291,6 +295,12 @@ private void iteratorTestBase(boolean destructive) throws Exception { countFreedPages++; } } + if (keyLength != 0 && isWithKeyIndex) { + final BytesToBytesMap.Location expectedLoc = map.lookup( + loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()); + Assert.assertTrue(expectedLoc.isDefined() && + expectedLoc.getKeyIndex() == loc.getKeyIndex()); + } } if (destructive) { // Latest page is not freed by iterator but by map itself @@ -304,12 +314,17 @@ private void iteratorTestBase(boolean destructive) throws Exception { @Test public void iteratorTest() throws Exception { - iteratorTestBase(false); + iteratorTestBase(false, false); } @Test public void destructiveIteratorTest() throws Exception { - iteratorTestBase(true); + iteratorTestBase(true, false); + } + + @Test + public void iteratorWithKeyIndexTest() throws Exception { + iteratorTestBase(false, true); } @Test @@ -603,6 +618,12 @@ public void multipleValuesForSameKey() { final BytesToBytesMap.Location loc = iter.next(); assert loc.isDefined(); } + BytesToBytesMap.MapIteratorWithKeyIndex iterWithKeyIndex = map.iteratorWithKeyIndex(); + for (i = 0; i < 2048; i++) { + assert iterWithKeyIndex.hasNext(); + final BytesToBytesMap.Location loc = iterWithKeyIndex.next(); + assert loc.isDefined() && loc.getKeyIndex() >= 0; + } } finally { map.free(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 226d92fa87769..8baee32e2494a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -73,15 +73,22 
@@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { */ def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) + /** + * Returns key index and matched single row. + * + * Returns null if there is no matched rows. + */ + def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex + /** * Returns an iterator for keys index and rows of InternalRow type. */ - def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] + def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] /** - * Returns number of keys index. + * Returns the maximum number of allowed keys index. */ - def numKeysIndex: Int + def maxNumKeysIndex: Int /** * Returns true iff all the keys are unique. @@ -140,6 +147,30 @@ private[execution] object HashedRelation { } } +/** + * A wrapper for key index and value in InternalRow type. + * Designed to be instantiated once per thread and reused. + */ +private[execution] class ValueRowWithKeyIndex { + private var keyIndex: Int = _ + private var value: InternalRow = _ + + /** Updates this ValueRowWithKeyIndex. Returns itself. */ + def updates(newKeyIndex: Int, newValue: InternalRow): ValueRowWithKeyIndex = { + keyIndex = newKeyIndex + value = newValue + this + } + + def getKeyIndex: Int = { + keyIndex + } + + def getValue: InternalRow = { + value + } +} + /** * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap. * @@ -163,9 +194,12 @@ private[joins] class UnsafeHashedRelation( override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption - // re-used in get()/getValue()/getWithKeyIndex()/valuesWithKeyIndex() + // re-used in get()/getValue()/getWithKeyIndex()/getValueWithKeyIndex()/valuesWithKeyIndex() var resultRow = new UnsafeRow(numFields) + // re-used in getValueWithKeyIndex()/valuesWithKeyIndex() + var valueRowWithKeyIndex = new ValueRowWithKeyIndex + override def get(key: InternalRow): Iterator[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] val map = binaryMap // avoid the compiler error @@ -223,25 +257,39 @@ private[joins] class UnsafeHashedRelation( } } - override def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] = { + override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + valueRowWithKeyIndex.updates(loc.getKeyIndex, resultRow) + } else { + null + } + } + + override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = { val iter = binaryMap.iteratorWithKeyIndex() - new Iterator[(Int, InternalRow)] { + new Iterator[ValueRowWithKeyIndex] { override def hasNext: Boolean = iter.hasNext - override def next(): (Int, InternalRow) = { + override def next(): ValueRowWithKeyIndex = { if (!hasNext) { throw new NoSuchElementException("End of the iterator") } val loc = iter.next() resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - (loc.getKeyIndex, resultRow) + valueRowWithKeyIndex.updates(loc.getKeyIndex, resultRow) } } } - override def numKeysIndex: Int = { - binaryMap.numKeysIndex + override def maxNumKeysIndex: Int = { + binaryMap.maxNumKeysIndex } override def keys(): Iterator[InternalRow] = { @@ -956,11 +1004,15 @@ class LongHashedRelation( throw new 
UnsupportedOperationException } - override def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] = { + override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = { + throw new UnsupportedOperationException + } + + override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = { throw new UnsupportedOperationException } - override def numKeysIndex: Int = { + override def maxNumKeysIndex: Int = { throw new UnsupportedOperationException } } @@ -1015,11 +1067,15 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable { throw new UnsupportedOperationException } - override def valuesWithKeyIndex(): Iterator[(Int, InternalRow)] = { + override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = { + throw new UnsupportedOperationException + } + + override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = { throw new UnsupportedOperationException } - override def numKeysIndex: Int = { + override def maxNumKeysIndex: Int = { throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 31db638b486fb..2f2db90b5d865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -119,13 +119,11 @@ case class ShuffledHashJoinExec( } val iter = if (hashedRelation.keyIsUnique) { - fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, joinRow, streamNullJoinRow, - joinRowWithStream, joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, - streamNullRow) + fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, joinRowWithStream, + joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, streamNullRow) } else { - fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, joinRow, - streamNullJoinRow, joinRowWithStream, joinRowWithBuild, streamNullJoinRowWithBuild, - buildNullRow, streamNullRow) + fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, joinRowWithStream, + joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, streamNullRow) } val resultProj = UnsafeProjection.create(output, output) @@ -148,14 +146,12 @@ case class ShuffledHashJoinExec( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, joinKeys: UnsafeProjection, - joinRow: JoinedRow, - streamNullJoinRow: JoinedRow, joinRowWithStream: InternalRow => JoinedRow, joinRowWithBuild: InternalRow => JoinedRow, streamNullJoinRowWithBuild: InternalRow => JoinedRow, buildNullRow: GenericInternalRow, streamNullRow: GenericInternalRow): Iterator[InternalRow] = { - val matchedKeys = new BitSet(hashedRelation.numKeysIndex) + val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex) // Process stream side with looking up hash relation val streamResultIter = streamIter.map { srow => @@ -164,11 +160,12 @@ case class ShuffledHashJoinExec( if (keys.anyNull) { joinRowWithBuild(buildNullRow) } else { - val matched = hashedRelation.getWithKeyIndex(keys) + val matched = hashedRelation.getValueWithKeyIndex(keys) if (matched != null) { - val (keyIndex, buildIter) = (matched._1, matched._2) - val buildRow = buildIter.next - if (boundCondition(joinRowWithBuild(buildRow))) { + val keyIndex = matched.getKeyIndex + val buildRow = matched.getValue + val joinRow = joinRowWithBuild(buildRow) + if (boundCondition(joinRow)) { 
matchedKeys.set(keyIndex) joinRow } else { @@ -183,11 +180,12 @@ case class ShuffledHashJoinExec( // Process build side with filtering out rows looked up and // passed join condition already val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { - case (keyIndex, brow) => + valueRowWithKeyIndex => + val keyIndex = valueRowWithKeyIndex.getKeyIndex val isMatched = matchedKeys.get(keyIndex) if (!isMatched) { - streamNullJoinRowWithBuild(brow) - Some(streamNullJoinRow) + val buildRow = valueRowWithKeyIndex.getValue + Some(streamNullJoinRowWithBuild(buildRow)) } else { None } @@ -210,8 +208,6 @@ case class ShuffledHashJoinExec( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, joinKeys: UnsafeProjection, - joinRow: JoinedRow, - streamNullJoinRow: JoinedRow, joinRowWithStream: InternalRow => JoinedRow, joinRowWithBuild: InternalRow => JoinedRow, streamNullJoinRowWithBuild: InternalRow => JoinedRow, @@ -231,7 +227,7 @@ case class ShuffledHashJoinExec( // Process stream side with looking up hash relation val streamResultIter = streamIter.flatMap { srow => - joinRowWithStream(srow) + val joinRow = joinRowWithStream(srow) val keys = joinKeys(srow) if (keys.anyNull) { Iterator.single(joinRowWithBuild(buildNullRow)) @@ -271,7 +267,8 @@ case class ShuffledHashJoinExec( var prevKeyIndex = -1 var valueIndex = -1 val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { - case (keyIndex, brow) => + valueRowWithKeyIndex => + val keyIndex = valueRowWithKeyIndex.getKeyIndex if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) { prevKeyIndex = keyIndex valueIndex = -1 @@ -279,8 +276,8 @@ case class ShuffledHashJoinExec( valueIndex += 1 val isMatched = isRowMatched(keyIndex, valueIndex) if (!isMatched) { - streamNullJoinRowWithBuild(brow) - Some(streamNullJoinRow) + val buildRow = valueRowWithKeyIndex.getValue + Some(streamNullJoinRowWithBuild(buildRow)) } else { None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ce239864f1f87..a3168c2ce0f39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1214,12 +1214,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan // Test multiple join keys (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr( "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"), - spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr( + spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr( "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"), $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6") ) inputDFs.foreach { case (df1, df2, joinExprs) => withSQLConf( + // Set broadcast join threshold and number of shuffle partitions, + // as shuffled hash join depends on these two configs. 
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", SQLConf.SHUFFLE_PARTITIONS.key -> "2") { val smjDF = df1.join(df2, joinExprs, "full") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b270bd5a2636..497bf1b1a6f69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -592,4 +593,81 @@ class HashedRelationSuite extends SharedSparkSession { assert(hashed.getValue(0L) == null) assert(hashed.getValue(key) == null) } + + test("SPARK-32399: test methods related to key index") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val key = Seq(BoundReference(0, IntegerType, true)) + val row = Seq(BoundReference(0, IntegerType, true), BoundReference(1, IntegerType, true)) + val unsafeProj = UnsafeProjection.create(row) + var rows = (0 until 100).map(i => { + val k = if (i % 10 == 0) null else i % 10 + unsafeProj(InternalRow(k, i)).copy() + }) + rows = unsafeProj(InternalRow(-1, -1)).copy() +: rows + val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm, allowsNullKey = true) + val keyIndexToKeyMap = new mutable.HashMap[Int, String] + val keyIndexToValueMap = new mutable.HashMap[Int, Seq[Int]] + + // test getWithKeyIndex() + (0 until 10).foreach(i => { + val key = if (i == 0) InternalRow(null) else InternalRow(i) + val valuesWithKeyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key)) + val keyIndex = valuesWithKeyIndex._1 + val actualValues = valuesWithKeyIndex._2.map(v => v.getInt(1)).toSeq + val expectedValues = (0 until 10).map(j => j * 10 + i) + if (i == 0) { + keyIndexToKeyMap(keyIndex) = "null" + } else { + keyIndexToKeyMap(keyIndex) = i.toString + } + keyIndexToValueMap(keyIndex) = actualValues + // key index is non-negative + assert(keyIndex >= 0) + // values are expected + assert(actualValues.sortWith(_ < _) === expectedValues) + }) + // key index is unique per key + val numUniqueKeyIndex = (0 until 10).map(i => { + val key = if (i == 0) InternalRow(null) else InternalRow(i) + val keyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key))._1 + keyIndex + }).distinct.size + assert(numUniqueKeyIndex == 10) + // NULL for non-existing key + assert(unsafeRelation.getWithKeyIndex(toUnsafe(InternalRow(100))) == null) + + // test getValueWithKeyIndex() + val valuesWithKeyIndex = unsafeRelation.getValueWithKeyIndex(toUnsafe(InternalRow(-1))) + val keyIndex = valuesWithKeyIndex.getKeyIndex + keyIndexToKeyMap(keyIndex) = "-1" + keyIndexToValueMap(keyIndex) = Seq(-1) + // key index is non-negative + assert(valuesWithKeyIndex.getKeyIndex >= 0) + // value is expected + assert(valuesWithKeyIndex.getValue.getInt(1) == -1) + // NULL for non-existing key + assert(unsafeRelation.getValueWithKeyIndex(toUnsafe(InternalRow(100))) == null) + + // test valuesWithKeyIndex() + val keyIndexToRowMap = unsafeRelation.valuesWithKeyIndex().map( + v => (v.getKeyIndex, v.getValue.copy())).toSeq.groupBy(_._1) + assert(keyIndexToRowMap.size == 11) + keyIndexToRowMap.foreach { + case (keyIndex, row) 
=> + val expectedKey = keyIndexToKeyMap(keyIndex) + val expectedValues = keyIndexToValueMap(keyIndex) + // key index returned from valuesWithKeyIndex() + // should be the same as returned from getWithKeyIndex() + if (expectedKey == "null") { + assert(row.head._2.isNullAt(0)) + } else { + assert(row.head._2.getInt(0).toString == expectedKey) + } + // values returned from valuesWithKeyIndex() + // should have same value and order as returned from getWithKeyIndex() + val actualValues = row.map(_._2.getInt(1)) + assert(actualValues === expectedValues) + } + } } From 35651dd2ad4fb34d48f39bab116cd5e0da99115e Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 13 Aug 2020 12:52:59 -0700 Subject: [PATCH 07/11] Address all new comments --- .../spark/unsafe/map/BytesToBytesMap.java | 4 +- .../sql/execution/joins/HashedRelation.scala | 50 ++++++++----- .../joins/ShuffledHashJoinExec.scala | 74 +++++++++---------- .../execution/joins/HashedRelationSuite.scala | 11 +-- 4 files changed, 77 insertions(+), 62 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 07ec2db56a370..ccd622862711e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -456,7 +456,7 @@ public Location next() { while (longArray.get(keyIndex * 2) == 0) { keyIndex++; } - loc.with(keyIndex, (int) longArray.get(keyIndex * 2 + 1), true); + loc.with(keyIndex, 0, true); keyIndex++; } numRecords--; @@ -479,6 +479,8 @@ public MapIteratorWithKeyIndex iteratorWithKeyIndex() { /** * The maximum number of allowed keys index. + * + * The value of allowed keys index is in the range of [0, maxNumKeysIndex - 1]. */ public int maxNumKeysIndex() { return (int) (longArray.size() / 2); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8baee32e2494a..6162b5e3e1dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -67,11 +67,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { } /** - * Returns key index and matched rows. + * Returns an iterator for key index and matched rows. * * Returns null if there is no matched rows. */ - def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) + def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] /** * Returns key index and matched single row. @@ -155,8 +155,20 @@ private[execution] class ValueRowWithKeyIndex { private var keyIndex: Int = _ private var value: InternalRow = _ + /** Updates this ValueRowWithKeyIndex by updating its key index. Returns itself. */ + def withNewKeyIndex(newKeyIndex: Int): ValueRowWithKeyIndex = { + keyIndex = newKeyIndex + this + } + + /** Updates this ValueRowWithKeyIndex by updating its value. Returns itself. */ + def withNewValue(newValue: InternalRow): ValueRowWithKeyIndex = { + value = newValue + this + } + /** Updates this ValueRowWithKeyIndex. Returns itself. 
*/ - def updates(newKeyIndex: Int, newValue: InternalRow): ValueRowWithKeyIndex = { + def update(newKeyIndex: Int, newValue: InternalRow): ValueRowWithKeyIndex = { keyIndex = newKeyIndex value = newValue this @@ -197,7 +209,7 @@ private[joins] class UnsafeHashedRelation( // re-used in get()/getValue()/getWithKeyIndex()/getValueWithKeyIndex()/valuesWithKeyIndex() var resultRow = new UnsafeRow(numFields) - // re-used in getValueWithKeyIndex()/valuesWithKeyIndex() + // re-used in getWithKeyIndex()/getValueWithKeyIndex()/valuesWithKeyIndex() var valueRowWithKeyIndex = new ValueRowWithKeyIndex override def get(key: InternalRow): Iterator[InternalRow] = { @@ -235,23 +247,23 @@ private[joins] class UnsafeHashedRelation( } } - override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) = { + override def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] = { val unsafeKey = key.asInstanceOf[UnsafeRow] val map = binaryMap // avoid the compiler error val loc = new map.Location // this could be allocated in stack binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) if (loc.isDefined) { - (loc.getKeyIndex, - new Iterator[UnsafeRow] { - private var _hasNext = true - override def hasNext: Boolean = _hasNext - override def next(): UnsafeRow = { - resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - _hasNext = loc.nextValue() - resultRow - } - }) + valueRowWithKeyIndex.withNewKeyIndex(loc.getKeyIndex) + new Iterator[ValueRowWithKeyIndex] { + private var _hasNext = true + override def hasNext: Boolean = _hasNext + override def next(): ValueRowWithKeyIndex = { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + _hasNext = loc.nextValue() + valueRowWithKeyIndex.withNewValue(resultRow) + } + } } else { null } @@ -265,7 +277,7 @@ private[joins] class UnsafeHashedRelation( unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) if (loc.isDefined) { resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - valueRowWithKeyIndex.updates(loc.getKeyIndex, resultRow) + valueRowWithKeyIndex.update(loc.getKeyIndex, resultRow) } else { null } @@ -283,7 +295,7 @@ private[joins] class UnsafeHashedRelation( } val loc = iter.next() resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - valueRowWithKeyIndex.updates(loc.getKeyIndex, resultRow) + valueRowWithKeyIndex.update(loc.getKeyIndex, resultRow) } } } @@ -1000,7 +1012,7 @@ class LongHashedRelation( */ override def keys(): Iterator[InternalRow] = map.keys() - override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) = { + override def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] = { throw new UnsupportedOperationException } @@ -1063,7 +1075,7 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable { throw new UnsupportedOperationException } - override def getWithKeyIndex(key: InternalRow): (Int, Iterator[InternalRow]) = { + override def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] = { throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 2f2db90b5d865..5889e11930571 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -106,15 +106,14 @@ case class ShuffledHashJoinExec( } val buildNullRow = new GenericInternalRow(buildOutput.length) val streamNullRow = new GenericInternalRow(streamedOutput.length) - val streamNullJoinRow = new JoinedRow - val streamNullJoinRowWithBuild = { + lazy val streamNullJoinRowWithBuild = { buildSide match { case BuildLeft => - streamNullJoinRow.withRight(streamNullRow) - streamNullJoinRow.withLeft _ + joinRow.withRight(streamNullRow) + joinRow.withLeft _ case BuildRight => - streamNullJoinRow.withLeft(streamNullRow) - streamNullJoinRow.withRight _ + joinRow.withLeft(streamNullRow) + joinRow.withRight _ } } @@ -148,7 +147,7 @@ case class ShuffledHashJoinExec( joinKeys: UnsafeProjection, joinRowWithStream: InternalRow => JoinedRow, joinRowWithBuild: InternalRow => JoinedRow, - streamNullJoinRowWithBuild: InternalRow => JoinedRow, + streamNullJoinRowWithBuild: => InternalRow => JoinedRow, buildNullRow: GenericInternalRow, streamNullRow: GenericInternalRow): Iterator[InternalRow] = { val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex) @@ -177,8 +176,7 @@ case class ShuffledHashJoinExec( } } - // Process build side with filtering out rows looked up and - // passed join condition already + // Process build side with filtering out the matched rows val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { valueRowWithKeyIndex => val keyIndex = valueRowWithKeyIndex.getKeyIndex @@ -210,7 +208,7 @@ case class ShuffledHashJoinExec( joinKeys: UnsafeProjection, joinRowWithStream: InternalRow => JoinedRow, joinRowWithBuild: InternalRow => JoinedRow, - streamNullJoinRowWithBuild: InternalRow => JoinedRow, + streamNullJoinRowWithBuild: => InternalRow => JoinedRow, buildNullRow: GenericInternalRow, streamNullRow: GenericInternalRow): Iterator[InternalRow] = { val matchedRows = new mutable.HashSet[Long] @@ -232,48 +230,50 @@ case class ShuffledHashJoinExec( if (keys.anyNull) { Iterator.single(joinRowWithBuild(buildNullRow)) } else { - val matched = hashedRelation.getWithKeyIndex(keys) - if (matched != null) { - val (keyIndex, buildIter) = (matched._1, matched._2.zipWithIndex) - - new RowIterator { - private var found = false - override def advanceNext(): Boolean = { - while (buildIter.hasNext) { - val (buildRow, valueIndex) = buildIter.next() - if (boundCondition(joinRowWithBuild(buildRow))) { - markRowMatched(keyIndex, valueIndex) - found = true - return true - } - } - if (!found) { - joinRowWithBuild(buildNullRow) + val buildIter = hashedRelation.getWithKeyIndex(keys) + new RowIterator { + private var found = false + private var valueIndex = -1 + override def advanceNext(): Boolean = { + while (buildIter != null && buildIter.hasNext) { + val buildRowWithKeyIndex = buildIter.next() + val keyIndex = buildRowWithKeyIndex.getKeyIndex + val buildRow = buildRowWithKeyIndex.getValue + valueIndex += 1 + if (boundCondition(joinRowWithBuild(buildRow))) { + markRowMatched(keyIndex, valueIndex) found = true return true } - false } - override def getRow: InternalRow = joinRow - }.toScala - } else { - Iterator.single(joinRowWithBuild(buildNullRow)) - } + // When we reach here, it means no match is found for this key. + // So we need to return one row with build side NULL row, + // to match the full outer join semantic. 
+ if (!found) { + joinRowWithBuild(buildNullRow) + found = true + return true + } + false + } + override def getRow: InternalRow = joinRow + }.toScala } } - // Process build side with filtering out rows looked up and - // passed join condition already + // Process build side with filtering out the matched rows var prevKeyIndex = -1 var valueIndex = -1 val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { valueRowWithKeyIndex => val keyIndex = valueRowWithKeyIndex.getKeyIndex if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) { + valueIndex = 0 prevKeyIndex = keyIndex - valueIndex = -1 + } else { + valueIndex += 1 } - valueIndex += 1 + val isMatched = isRowMatched(keyIndex, valueIndex) if (!isMatched) { val buildRow = valueRowWithKeyIndex.getValue diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 497bf1b1a6f69..72e921deab933 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -612,9 +612,10 @@ class HashedRelationSuite extends SharedSparkSession { // test getWithKeyIndex() (0 until 10).foreach(i => { val key = if (i == 0) InternalRow(null) else InternalRow(i) - val valuesWithKeyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key)) - val keyIndex = valuesWithKeyIndex._1 - val actualValues = valuesWithKeyIndex._2.map(v => v.getInt(1)).toSeq + val valuesWithKeyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key)).map( + v => (v.getKeyIndex, v.getValue.getInt(1))).toArray + val keyIndex = valuesWithKeyIndex.head._1 + val actualValues = valuesWithKeyIndex.map(_._2) val expectedValues = (0 until 10).map(j => j * 10 + i) if (i == 0) { keyIndexToKeyMap(keyIndex) = "null" @@ -628,9 +629,9 @@ class HashedRelationSuite extends SharedSparkSession { assert(actualValues.sortWith(_ < _) === expectedValues) }) // key index is unique per key - val numUniqueKeyIndex = (0 until 10).map(i => { + val numUniqueKeyIndex = (0 until 10).flatMap(i => { val key = if (i == 0) InternalRow(null) else InternalRow(i) - val keyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key))._1 + val keyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key)).map(_.getKeyIndex).toSeq keyIndex }).distinct.size assert(numUniqueKeyIndex == 10) From 82ed0b412e9a0bdb565fda07921805b92b631f0f Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 13 Aug 2020 15:01:30 -0700 Subject: [PATCH 08/11] Address new comments around definition of key and value index --- .../org/apache/spark/unsafe/map/BytesToBytesMap.java | 12 ++++++++---- .../spark/sql/execution/joins/HashedRelation.scala | 2 ++ .../sql/execution/joins/ShuffledHashJoinExec.scala | 5 +++++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index ccd622862711e..56193ccc0a628 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -436,13 +436,17 @@ public MapIterator destructiveIterator() { */ public final class MapIteratorWithKeyIndex implements Iterator { + /** + * The index in `longArray` where the key is stored. 
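For orientation, `longArray` keeps two slots per bucket: the record address at position 2*i (zero when the bucket is empty) and the key's hash at 2*i+1, so the iterator can enumerate occupied buckets while skipping holes, and the bucket position doubles as the key index. A toy scan over a plain Array[Long] (illustrative only, not the full BytesToBytesMap layout):

    // Even slots: record address (0 means the bucket is empty); odd slots: key hash.
    // Yields the key index of every occupied bucket, in bucket order.
    def occupiedKeyIndices(longArray: Array[Long]): Iterator[Int] =
      Iterator.range(0, longArray.length / 2).filter(i => longArray(i * 2) != 0L)

    val toyArray = Array[Long](0L, 0L, 0x1000L, 77L, 0L, 0L, 0x2000L, 42L)
    assert(occupiedKeyIndices(toyArray).toList == List(1, 3))   // buckets 1 and 3 are occupied

This is also why maxNumKeysIndex is longArray.size() / 2: it bounds the key index rather than counting the keys actually present.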
+ */ private int keyIndex = 0; + private int numRecords; private final Location loc; - private MapIteratorWithKeyIndex(int numRecords, Location loc) { - this.numRecords = numRecords; - this.loc = loc; + private MapIteratorWithKeyIndex() { + this.numRecords = numValues; + this.loc = new Location(); } @Override @@ -474,7 +478,7 @@ public Location next() { * the behavior of the returned iterator is undefined. */ public MapIteratorWithKeyIndex iteratorWithKeyIndex() { - return new MapIteratorWithKeyIndex(numValues, new Location()); + return new MapIteratorWithKeyIndex(); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6162b5e3e1dad..ca8ede12fcde3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -441,6 +441,8 @@ private[joins] object UnsafeHashedRelation { taskMemoryManager: TaskMemoryManager, isNullAware: Boolean = false, allowsNullKey: Boolean = false): HashedRelation = { + require(!(isNullAware && allowsNullKey), + "isNullAware and allowsNullKey cannot be enabled at same time") val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 5889e11930571..71f9bab796727 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -201,6 +201,11 @@ case class ShuffledHashJoinExec( * 2. Process rows from build side by iterating hash relation. * Filter out rows from build side being matched already, * by checking key index and value index from `HashSet`. + * + * The "value index" is defined as the index of the tuple in the chain + * of tuples having the same key. For example, if certain key is found thrice, + * the value indices of its tuples will be 0, 1 and 2. + * Note that value indices of tuples with different keys are incomparable. */ private def fullOuterJoinWithNonUniqueKey( streamIter: Iterator[InternalRow], From cf04e2f9edeb0364ffed180d49b64d5b6969ef36 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 14 Aug 2020 16:58:41 -0700 Subject: [PATCH 09/11] Address all new comments --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 1 + .../spark/sql/execution/joins/ShuffledHashJoinExec.scala | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ca8ede12fcde3..ea58f3a934381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -75,6 +75,7 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns key index and matched single row. + * This is for unique key case. * * Returns null if there is no matched rows. 
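The value-index bookkeeping can be shown in isolation: rows come back from valuesWithKeyIndex() grouped by key index, the n-th row seen for a key gets value index n-1, and the counter restarts at 0 for the next key, which is what the prevKeyIndex/valueIndex pair in the build-side scan maintains. A hypothetical, list-based rendering of that logic:

    // Assign value indices within each key-index group; assumes the input is grouped by
    // key index, which is the order valuesWithKeyIndex() yields rows in.
    def withValueIndex[R](rows: Seq[(Int, R)]): Seq[(Int, Int, R)] = {
      var prevKeyIndex = -1
      var valueIndex = -1
      rows.map { case (keyIndex, row) =>
        if (keyIndex != prevKeyIndex) {
          prevKeyIndex = keyIndex
          valueIndex = 0
        } else {
          valueIndex += 1
        }
        (keyIndex, valueIndex, row)
      }
    }

    // A key seen three times gets value indices 0, 1 and 2; indices restart for the next key.
    assert(withValueIndex(Seq((5, "a"), (5, "b"), (5, "c"), (9, "d"))) ==
      Seq((5, 0, "a"), (5, 1, "b"), (5, 2, "c"), (9, 0, "d")))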
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 71f9bab796727..1e9426eae5e6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -193,7 +193,7 @@ case class ShuffledHashJoinExec( } /** - * Full outer shuffled hash join with unique join keys: + * Full outer shuffled hash join with non-unique join keys: * 1. Process rows from stream side by looking up hash relation. * Mark the matched rows from build side be looked up. * A `HashSet[Long]` is used to track matched rows with @@ -253,9 +253,11 @@ case class ShuffledHashJoinExec( } // When we reach here, it means no match is found for this key. // So we need to return one row with build side NULL row, - // to match the full outer join semantic. + // to satisfy the full outer join semantic. if (!found) { joinRowWithBuild(buildNullRow) + // Set `found` to be true as we only need to return one row + // but no more. found = true return true } From 381cdbc2ec393a96cab1ab611dc461dbed9d7dc2 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Sat, 15 Aug 2020 13:51:13 -0700 Subject: [PATCH 10/11] Address all new comments --- .../org/apache/spark/unsafe/map/BytesToBytesMap.java | 2 +- .../org/apache/spark/sql/internal/SQLConf.scala | 4 +++- .../spark/sql/execution/joins/HashedRelation.scala | 2 +- .../test/scala/org/apache/spark/sql/JoinSuite.scala | 12 ++++++++++++ 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 56193ccc0a628..8eea9db393aff 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -429,7 +429,7 @@ public MapIterator destructiveIterator() { } /** - * Iterator for the entries of this map. This is to first iterate over key index array + * Iterator for the entries of this map. This is to first iterate over key indices in * `longArray` then accessing values in `dataPages`. NOTE: this is different from `MapIterator` * in the sense that key index is preserved here * (See `UnsafeHashedRelation` for example of usage). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 57fc1bd99be28..98ccf1ab0c01f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -329,7 +329,9 @@ object SQLConf { val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin") .internal() - .doc("When true, prefer sort merge join over shuffle hash join.") + .doc("When true, prefer sort merge join over shuffled hash join. " + + "Note that shuffled hash join supports all join types (e.g. 
full outer) " + + "that sort merge join supports.") .version("2.0.0") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ea58f3a934381..183d8a3ad04d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -137,7 +137,7 @@ private[execution] object HashedRelation { 0) } - if (!input.hasNext) { + if (!input.hasNext && !allowsNullKey) { EmptyHashedRelation } else if (key.length == 1 && key.head.dataType == LongType && !allowsNullKey) { // NOTE: LongHashedRelation does not support NULL keys. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a3168c2ce0f39..ac9299990d5e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1199,6 +1199,18 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan (spark.range(10).selectExpr("id % 5 as k1"), spark.range(30).selectExpr("id % 5 as k2"), $"k1" === $"k2"), + // Test empty build side + (spark.range(10).selectExpr("id as k1").filter("k1 < -1"), + spark.range(30).selectExpr("id as k2"), + $"k1" === $"k2"), + // Test empty stream side + (spark.range(10).selectExpr("id as k1"), + spark.range(30).selectExpr("id as k2").filter("k2 < -1"), + $"k1" === $"k2"), + // Test empty build and stream side + (spark.range(10).selectExpr("id as k1").filter("k1 < -1"), + spark.range(30).selectExpr("id as k2").filter("k2 < -1"), + $"k1" === $"k2"), // Test string join key (spark.range(10).selectExpr("cast(id * 3 as string) as k1"), spark.range(30).selectExpr("cast(id as string) as k2"), From 526709b73b87687f48f68486b9c8c7be0866291f Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Sun, 16 Aug 2020 11:23:49 -0700 Subject: [PATCH 11/11] Address all new comments --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 5 +++-- .../spark/sql/execution/joins/ShuffledHashJoinExec.scala | 4 ++++ sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 3 +++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 98ccf1ab0c01f..9f36e8702c830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -330,8 +330,9 @@ object SQLConf { val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin") .internal() .doc("When true, prefer sort merge join over shuffled hash join. " + - "Note that shuffled hash join supports all join types (e.g. full outer) " + - "that sort merge join supports.") + "Sort merge join consumes less memory than shuffled hash join and it works efficiently " + + "when both join tables are large. 
On the other hand, shuffled hash join can improve " + + "performance (e.g., of full outer joins) when one of join tables is much smaller.") .version("2.0.0") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 1e9426eae5e6e..133e964719682 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -150,6 +150,8 @@ case class ShuffledHashJoinExec( streamNullJoinRowWithBuild: => InternalRow => JoinedRow, buildNullRow: GenericInternalRow, streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + // TODO(SPARK-32629):record metrics of extra BitSet/HashSet + // in full outer shuffled hash join val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex) // Process stream side with looking up hash relation @@ -216,6 +218,8 @@ case class ShuffledHashJoinExec( streamNullJoinRowWithBuild: => InternalRow => JoinedRow, buildNullRow: GenericInternalRow, streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + // TODO(SPARK-32629):record metrics of extra BitSet/HashSet + // in full outer shuffled hash join val matchedRows = new mutable.HashSet[Long] def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ac9299990d5e2..e7629a21f787a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1223,6 +1223,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"), spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"), $"k1" === $"k2"), + (spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value as k1"), + spark.range(30).map(i => if (i % 5 == 0) i else null).selectExpr("value as k2"), + $"k1" === $"k2"), // Test multiple join keys (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr( "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"),
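As a quick end-to-end check against a build that carries this change, the shuffled hash join path can be forced with a join hint and then located in the executed plan. This is a spark-shell style sketch, not part of the test suite, and it assumes adaptive execution is left at its default (disabled) setting:

    import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
    import org.apache.spark.sql.functions.col

    val left = spark.range(10).selectExpr("id % 5 as k1")
    val right = spark.range(30).selectExpr("id % 5 as k2")

    // The SHUFFLE_HASH hint asks the planner for a shuffled hash join; with this patch the
    // hint can be honored for a full outer join instead of falling back to sort merge join.
    val joined = left.hint("SHUFFLE_HASH").join(right, col("k1") === col("k2"), "full")

    val shjNodes = joined.queryExecution.executedPlan.collect {
      case j: ShuffledHashJoinExec => j
    }
    assert(shjNodes.nonEmpty, "expected a ShuffledHashJoinExec node in the plan")

The collected result can then be cross-checked against the same query planned as a sort merge join.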