diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6e028886f2318..8eea9db393aff 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -428,6 +428,68 @@ public MapIterator destructiveIterator() { return new MapIterator(numValues, new Location(), true); } + /** + * Iterator for the entries of this map. It first iterates over the key indices in + * `longArray` and then accesses the values in `dataPages`. NOTE: unlike `MapIterator`, + * this iterator preserves the key index + * (see `UnsafeHashedRelation` for an example of usage). + */ + public final class MapIteratorWithKeyIndex implements Iterator<Location> { + + /** + * The index in `longArray` where the key is stored. + */ + private int keyIndex = 0; + + private int numRecords; + private final Location loc; + + private MapIteratorWithKeyIndex() { + this.numRecords = numValues; + this.loc = new Location(); + } + + @Override + public boolean hasNext() { + return numRecords > 0; + } + + @Override + public Location next() { + if (!loc.isDefined() || !loc.nextValue()) { + while (longArray.get(keyIndex * 2) == 0) { + keyIndex++; + } + loc.with(keyIndex, 0, true); + keyIndex++; + } + numRecords--; + return loc; + } + } + + /** + * Returns an iterator for iterating over the entries of this map, + * by first iterating over the key indices inside the map's `longArray`. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * The returned iterator is NOT thread-safe. If the map is modified while iterating over it, + * the behavior of the returned iterator is undefined. + */ + public MapIteratorWithKeyIndex iteratorWithKeyIndex() { + return new MapIteratorWithKeyIndex(); + } + + /** + * The maximum number of allowed key indices. + * + * A valid key index is in the range [0, maxNumKeysIndex - 1]. + */ + public int maxNumKeysIndex() { + return (int) (longArray.size() / 2); + } + /** * Looks up a key, and return a {@link Location} handle that can be used to test existence * and read/write values. @@ -601,6 +663,14 @@ public boolean isDefined() { return isDefined; } + /** + * Returns the key index for the key at this location. + */ + public int getKeyIndex() { + assert (isDefined); + return pos; + } + /** * Returns the base object for key.
*/ diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 6e995a3929a75..f4e952f465e54 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -172,6 +172,7 @@ public void emptyMap() { final byte[] key = getRandomByteArray(keyLengthInWords); Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); + Assert.assertFalse(map.iteratorWithKeyIndex().hasNext()); } finally { map.free(); } @@ -233,9 +234,10 @@ public void setAndRetrieveAKey() { } } - private void iteratorTestBase(boolean destructive) throws Exception { + private void iteratorTestBase(boolean destructive, boolean isWithKeyIndex) throws Exception { final int size = 4096; BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES); + Assert.assertEquals(size / 2, map.maxNumKeysIndex()); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -267,6 +269,8 @@ private void iteratorTestBase(boolean destructive) throws Exception { final Iterator iter; if (destructive) { iter = map.destructiveIterator(); + } else if (isWithKeyIndex) { + iter = map.iteratorWithKeyIndex(); } else { iter = map.iterator(); } @@ -291,6 +295,12 @@ private void iteratorTestBase(boolean destructive) throws Exception { countFreedPages++; } } + if (keyLength != 0 && isWithKeyIndex) { + final BytesToBytesMap.Location expectedLoc = map.lookup( + loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()); + Assert.assertTrue(expectedLoc.isDefined() && + expectedLoc.getKeyIndex() == loc.getKeyIndex()); + } } if (destructive) { // Latest page is not freed by iterator but by map itself @@ -304,12 +314,17 @@ private void iteratorTestBase(boolean destructive) throws Exception { @Test public void iteratorTest() throws Exception { - iteratorTestBase(false); + iteratorTestBase(false, false); } @Test public void destructiveIteratorTest() throws Exception { - iteratorTestBase(true); + iteratorTestBase(true, false); + } + + @Test + public void iteratorWithKeyIndexTest() throws Exception { + iteratorTestBase(false, true); } @Test @@ -603,6 +618,12 @@ public void multipleValuesForSameKey() { final BytesToBytesMap.Location loc = iter.next(); assert loc.isDefined(); } + BytesToBytesMap.MapIteratorWithKeyIndex iterWithKeyIndex = map.iteratorWithKeyIndex(); + for (i = 0; i < 2048; i++) { + assert iterWithKeyIndex.hasNext(); + final BytesToBytesMap.Location loc = iterWithKeyIndex.next(); + assert loc.isDefined() && loc.getKeyIndex() >= 0; + } } finally { map.free(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 85c6600685bd1..57c3f3dbd050d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -235,8 +235,8 @@ trait JoinSelectionHelper { canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) } getBuildSide( - canBuildLeft(joinType) && buildLeft, - canBuildRight(joinType) && buildRight, + canBuildBroadcastLeft(joinType) && buildLeft, + canBuildBroadcastRight(joinType) && buildRight, left, right ) @@ -260,8 +260,8 @@ trait 
JoinSelectionHelper { canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) } getBuildSide( - canBuildLeft(joinType) && buildLeft, - canBuildRight(joinType) && buildRight, + canBuildShuffledHashJoinLeft(joinType) && buildLeft, + canBuildShuffledHashJoinRight(joinType) && buildRight, left, right ) @@ -278,20 +278,35 @@ trait JoinSelectionHelper { plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold } - def canBuildLeft(joinType: JoinType): Boolean = { + def canBuildBroadcastLeft(joinType: JoinType): Boolean = { joinType match { case _: InnerLike | RightOuter => true case _ => false } } - def canBuildRight(joinType: JoinType): Boolean = { + def canBuildBroadcastRight(joinType: JoinType): Boolean = { joinType match { case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } } + def canBuildShuffledHashJoinLeft(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | RightOuter | FullOuter => true + case _ => false + } + } + + def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | LeftOuter | FullOuter | + LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } + } + def hintToBroadcastLeft(hint: JoinHint): Boolean = { hint.leftHint.exists(_.strategy.contains(BROADCAST)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 57fc1bd99be28..9f36e8702c830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -329,7 +329,10 @@ object SQLConf { val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin") .internal() - .doc("When true, prefer sort merge join over shuffle hash join.") + .doc("When true, prefer sort merge join over shuffled hash join. " + + "Sort merge join consumes less memory than shuffled hash join and works efficiently " + + "when both join tables are large. On the other hand, shuffled hash join can improve " + + "performance (e.g., of full outer joins) when one of the join tables is much smaller.") .version("2.0.0") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eb32bfcecae7b..391e2524e4794 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -116,7 +116,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * * - Shuffle hash join: * Only supported for equi-joins, while the join keys do not need to be sortable. - * Supported for all join types except full outer joins. + * Supported for all join types. + * Building a hash map from the build-side table is memory-intensive and could cause OOM + * when the build side is large. * * - Shuffle sort merge join (SMJ): * Only supported for equi-joins and the join keys have to be sortable. @@ -260,7 +262,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // it's a right join, and broadcast right side if it's a left join. // TODO: revisit it. If left side is much smaller than the right side, it may be better // to broadcast the left side even if it's a left join.
- if (canBuildLeft(joinType)) BuildLeft else BuildRight + if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight } def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 2154e370a1596..1a7554c905c6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -114,7 +114,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport { } } - @transient private lazy val (buildOutput, streamedOutput) = { + @transient protected lazy val (buildOutput, streamedOutput) = { buildSide match { case BuildLeft => (left.output, right.output) case BuildRight => (right.output, left.output) @@ -133,7 +133,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport { protected def streamSideKeyGenerator(): UnsafeProjection = UnsafeProjection.create(streamedBoundKeys) - @transient private[this] lazy val boundCondition = if (condition.isDefined) { + @transient protected[this] lazy val boundCondition = if (condition.isDefined) { Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _ } else { (r: InternalRow) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0d40520ae71a0..183d8a3ad04d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -66,6 +66,31 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { throw new UnsupportedOperationException } + /** + * Returns an iterator over the key index and matched rows. + * + * Returns null if there are no matched rows. + */ + def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] + + /** + * Returns the key index and the single matched row. + * This is for the unique key case. + * + * Returns null if there are no matched rows. + */ + def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex + + /** + * Returns an iterator over all key indices and rows of InternalRow type. + */ + def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] + + /** + * Returns the maximum number of allowed key indices. + */ + def maxNumKeysIndex: Int + /** * Returns true iff all the keys are unique. */ @@ -91,13 +116,17 @@ private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. + * + * @param allowsNullKey Allow NULL keys in HashedRelation. + * This is used for full outer join in `ShuffledHashJoinExec` only.
*/ def apply( input: Iterator[InternalRow], key: Seq[Expression], sizeEstimate: Int = 64, taskMemoryManager: TaskMemoryManager = null, - isNullAware: Boolean = false): HashedRelation = { + isNullAware: Boolean = false, + allowsNullKey: Boolean = false): HashedRelation = { val mm = Option(taskMemoryManager).getOrElse { new TaskMemoryManager( new UnifiedMemoryManager( @@ -108,16 +137,53 @@ private[execution] object HashedRelation { 0) } - if (!input.hasNext) { + if (!input.hasNext && !allowsNullKey) { EmptyHashedRelation - } else if (key.length == 1 && key.head.dataType == LongType) { + } else if (key.length == 1 && key.head.dataType == LongType && !allowsNullKey) { + // NOTE: LongHashedRelation does not support NULL keys. LongHashedRelation(input, key, sizeEstimate, mm, isNullAware) } else { - UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware) + UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, allowsNullKey) } } } +/** + * A wrapper for key index and value in InternalRow type. + * Designed to be instantiated once per thread and reused. + */ +private[execution] class ValueRowWithKeyIndex { + private var keyIndex: Int = _ + private var value: InternalRow = _ + + /** Updates this ValueRowWithKeyIndex by updating its key index. Returns itself. */ + def withNewKeyIndex(newKeyIndex: Int): ValueRowWithKeyIndex = { + keyIndex = newKeyIndex + this + } + + /** Updates this ValueRowWithKeyIndex by updating its value. Returns itself. */ + def withNewValue(newValue: InternalRow): ValueRowWithKeyIndex = { + value = newValue + this + } + + /** Updates this ValueRowWithKeyIndex. Returns itself. */ + def update(newKeyIndex: Int, newValue: InternalRow): ValueRowWithKeyIndex = { + keyIndex = newKeyIndex + value = newValue + this + } + + def getKeyIndex: Int = { + keyIndex + } + + def getValue: InternalRow = { + value + } +} + /** * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap. 
* @@ -141,9 +207,12 @@ private[joins] class UnsafeHashedRelation( override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption - // re-used in get()/getValue() + // re-used in get()/getValue()/getWithKeyIndex()/getValueWithKeyIndex()/valuesWithKeyIndex() var resultRow = new UnsafeRow(numFields) + // re-used in getWithKeyIndex()/getValueWithKeyIndex()/valuesWithKeyIndex() + var valueRowWithKeyIndex = new ValueRowWithKeyIndex + override def get(key: InternalRow): Iterator[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] val map = binaryMap // avoid the compiler error @@ -179,6 +248,63 @@ private[joins] class UnsafeHashedRelation( } } + override def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + valueRowWithKeyIndex.withNewKeyIndex(loc.getKeyIndex) + new Iterator[ValueRowWithKeyIndex] { + private var _hasNext = true + override def hasNext: Boolean = _hasNext + override def next(): ValueRowWithKeyIndex = { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + _hasNext = loc.nextValue() + valueRowWithKeyIndex.withNewValue(resultRow) + } + } + } else { + null + } + } + + override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + valueRowWithKeyIndex.update(loc.getKeyIndex, resultRow) + } else { + null + } + } + + override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = { + val iter = binaryMap.iteratorWithKeyIndex() + + new Iterator[ValueRowWithKeyIndex] { + override def hasNext: Boolean = iter.hasNext + + override def next(): ValueRowWithKeyIndex = { + if (!hasNext) { + throw new NoSuchElementException("End of the iterator") + } + val loc = iter.next() + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + valueRowWithKeyIndex.update(loc.getKeyIndex, resultRow) + } + } + } + + override def maxNumKeysIndex: Int = { + binaryMap.maxNumKeysIndex + } + override def keys(): Iterator[InternalRow] = { val iter = binaryMap.iterator() @@ -314,7 +440,10 @@ private[joins] object UnsafeHashedRelation { key: Seq[Expression], sizeEstimate: Int, taskMemoryManager: TaskMemoryManager, - isNullAware: Boolean = false): HashedRelation = { + isNullAware: Boolean = false, + allowsNullKey: Boolean = false): HashedRelation = { + require(!(isNullAware && allowsNullKey), + "isNullAware and allowsNullKey cannot be enabled at same time") val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024)) @@ -331,7 +460,7 @@ private[joins] object UnsafeHashedRelation { val row = input.next().asInstanceOf[UnsafeRow] numFields = row.numFields() val key = keyGenerator(row) - if (!key.anyNull) { + if (!key.anyNull || allowsNullKey) { val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) val 
success = loc.append( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, @@ -885,6 +1014,22 @@ class LongHashedRelation( * Returns an iterator for keys of InternalRow type. */ override def keys(): Iterator[InternalRow] = map.keys() + + override def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] = { + throw new UnsupportedOperationException + } + + override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = { + throw new UnsupportedOperationException + } + + override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = { + throw new UnsupportedOperationException + } + + override def maxNumKeysIndex: Int = { + throw new UnsupportedOperationException + } } /** @@ -933,6 +1078,22 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable { throw new UnsupportedOperationException } + override def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex] = { + throw new UnsupportedOperationException + } + + override def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex = { + throw new UnsupportedOperationException + } + + override def valuesWithKeyIndex(): Iterator[ValueRowWithKeyIndex] = { + throw new UnsupportedOperationException + } + + override def maxNumKeysIndex: Int = { + throw new UnsupportedOperationException + } + override def keyIsUnique: Boolean = true override def keys(): Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 41cefd03dd931..133e964719682 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -19,16 +19,19 @@ package org.apache.spark.sql.execution.joins import java.util.concurrent.TimeUnit._ +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.util.collection.BitSet /** * Performs a hash join of two child relations by first shuffling the data using the join keys. @@ -48,8 +51,15 @@ case class ShuffledHashJoinExec( "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) + override def output: Seq[Attribute] = super[ShuffledJoin].output + override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning + override def outputOrdering: Seq[SortOrder] = joinType match { + case FullOuter => Nil + case _ => super.outputOrdering + } + /** * This is called by generated Java class, should be public. 
*/ @@ -59,7 +69,11 @@ case class ShuffledHashJoinExec( val start = System.nanoTime() val context = TaskContext.get() val relation = HashedRelation( - iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager()) + iter, + buildBoundKeys, + taskMemoryManager = context.taskMemoryManager(), + // Full outer join needs support for NULL key in HashedRelation. + allowsNullKey = joinType == FullOuter) buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. @@ -71,8 +85,221 @@ case class ShuffledHashJoinExec( val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) - join(streamIter, hashed, numOutputRows) + joinType match { + case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows) + case _ => join(streamIter, hashed, numOutputRows) + } + } + } + + private def fullOuterJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + numOutputRows: SQLMetric): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val joinRow = new JoinedRow + val (joinRowWithStream, joinRowWithBuild) = { + buildSide match { + case BuildLeft => (joinRow.withRight _, joinRow.withLeft _) + case BuildRight => (joinRow.withLeft _, joinRow.withRight _) + } + } + val buildNullRow = new GenericInternalRow(buildOutput.length) + val streamNullRow = new GenericInternalRow(streamedOutput.length) + lazy val streamNullJoinRowWithBuild = { + buildSide match { + case BuildLeft => + joinRow.withRight(streamNullRow) + joinRow.withLeft _ + case BuildRight => + joinRow.withLeft(streamNullRow) + joinRow.withRight _ + } + } + + val iter = if (hashedRelation.keyIsUnique) { + fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, joinRowWithStream, + joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, streamNullRow) + } else { + fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, joinRowWithStream, + joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, streamNullRow) } + + val resultProj = UnsafeProjection.create(output, output) + iter.map { r => + numOutputRows += 1 + resultProj(r) + } + } + + /** + * Full outer shuffled hash join with unique join keys: + * 1. Process rows from the stream side by looking up the hash relation, + * and mark the matched build-side rows as looked up. + * A `BitSet` is used to track matched rows by key index. + * 2. Process rows from the build side by iterating over the hash relation, + * and filter out build-side rows that have already been matched, + * by checking the key index against the `BitSet`.
+ */ + private def fullOuterJoinWithUniqueKey( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + joinKeys: UnsafeProjection, + joinRowWithStream: InternalRow => JoinedRow, + joinRowWithBuild: InternalRow => JoinedRow, + streamNullJoinRowWithBuild: => InternalRow => JoinedRow, + buildNullRow: GenericInternalRow, + streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + // TODO(SPARK-32629):record metrics of extra BitSet/HashSet + // in full outer shuffled hash join + val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex) + + // Process stream side with looking up hash relation + val streamResultIter = streamIter.map { srow => + joinRowWithStream(srow) + val keys = joinKeys(srow) + if (keys.anyNull) { + joinRowWithBuild(buildNullRow) + } else { + val matched = hashedRelation.getValueWithKeyIndex(keys) + if (matched != null) { + val keyIndex = matched.getKeyIndex + val buildRow = matched.getValue + val joinRow = joinRowWithBuild(buildRow) + if (boundCondition(joinRow)) { + matchedKeys.set(keyIndex) + joinRow + } else { + joinRowWithBuild(buildNullRow) + } + } else { + joinRowWithBuild(buildNullRow) + } + } + } + + // Process build side with filtering out the matched rows + val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { + valueRowWithKeyIndex => + val keyIndex = valueRowWithKeyIndex.getKeyIndex + val isMatched = matchedKeys.get(keyIndex) + if (!isMatched) { + val buildRow = valueRowWithKeyIndex.getValue + Some(streamNullJoinRowWithBuild(buildRow)) + } else { + None + } + } + + streamResultIter ++ buildResultIter + } + + /** + * Full outer shuffled hash join with non-unique join keys: + * 1. Process rows from the stream side by looking up the hash relation, + * and mark the matched build-side rows as looked up. + * A `HashSet[Long]` is used to track matched rows, with the + * key index (Int) and value index (Int) encoded together. + * 2. Process rows from the build side by iterating over the hash relation, + * and filter out build-side rows that have already been matched, + * by checking the key index and value index against the `HashSet`. + * + * The "value index" is defined as the index of the tuple in the chain + * of tuples having the same key. For example, if a certain key is found three times, + * the value indices of its tuples will be 0, 1 and 2. + * Note that value indices of tuples with different keys are incomparable.
+ */ + private def fullOuterJoinWithNonUniqueKey( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + joinKeys: UnsafeProjection, + joinRowWithStream: InternalRow => JoinedRow, + joinRowWithBuild: InternalRow => JoinedRow, + streamNullJoinRowWithBuild: => InternalRow => JoinedRow, + buildNullRow: GenericInternalRow, + streamNullRow: GenericInternalRow): Iterator[InternalRow] = { + // TODO(SPARK-32629):record metrics of extra BitSet/HashSet + // in full outer shuffled hash join + val matchedRows = new mutable.HashSet[Long] + + def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = { + val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex + matchedRows.add(rowIndex) + } + + def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = { + val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex + matchedRows.contains(rowIndex) + } + + // Process stream side with looking up hash relation + val streamResultIter = streamIter.flatMap { srow => + val joinRow = joinRowWithStream(srow) + val keys = joinKeys(srow) + if (keys.anyNull) { + Iterator.single(joinRowWithBuild(buildNullRow)) + } else { + val buildIter = hashedRelation.getWithKeyIndex(keys) + new RowIterator { + private var found = false + private var valueIndex = -1 + override def advanceNext(): Boolean = { + while (buildIter != null && buildIter.hasNext) { + val buildRowWithKeyIndex = buildIter.next() + val keyIndex = buildRowWithKeyIndex.getKeyIndex + val buildRow = buildRowWithKeyIndex.getValue + valueIndex += 1 + if (boundCondition(joinRowWithBuild(buildRow))) { + markRowMatched(keyIndex, valueIndex) + found = true + return true + } + } + // When we reach here, it means no match is found for this key. + // So we need to return one row with build side NULL row, + // to satisfy the full outer join semantic. + if (!found) { + joinRowWithBuild(buildNullRow) + // Set `found` to be true as we only need to return one row + // but no more. 
+ found = true + return true + } + false + } + override def getRow: InternalRow = joinRow + }.toScala + } + } + + // Process build side with filtering out the matched rows + var prevKeyIndex = -1 + var valueIndex = -1 + val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { + valueRowWithKeyIndex => + val keyIndex = valueRowWithKeyIndex.getKeyIndex + if (prevKeyIndex == -1 || keyIndex != prevKeyIndex) { + valueIndex = 0 + prevKeyIndex = keyIndex + } else { + valueIndex += 1 + } + + val isMatched = isRowMatched(keyIndex, valueIndex) + if (!isMatched) { + val buildRow = valueRowWithKeyIndex.getValue + Some(streamNullJoinRowWithBuild(buildRow)) + } else { + None + } + } + + streamResultIter ++ buildResultIter + } + + // TODO(SPARK-32567): support full outer shuffled hash join code-gen + override def supportCodegen: Boolean = { + joinType != FullOuter } override def inputRDDs(): Seq[RDD[InternalRow]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 7035ddc35be9c..92bfc1ed4861d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, Partitioning, PartitioningCollection, UnknownPartitioning} /** @@ -40,4 +41,24 @@ trait ShuffledJoin extends BaseJoinExec { throw new IllegalArgumentException( s"ShuffledJoin should not take $x as the JoinType") } + + override def output: Seq[Attribute] = { + joinType match { + case _: InnerLike => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b9f6684447dd8..6e7bcb8825488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -52,26 +52,6 @@ case class SortMergeJoinExec( override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator - override def output: Seq[Attribute] = { - joinType match { - case _: InnerLike => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - (left.output ++ right.output).map(_.withNullability(true)) - case j: ExistenceJoin => - left.output :+ j.exists - case LeftExistence(_) => - left.output - case x => - throw
new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - } - override def requiredChildDistribution: Seq[Distribution] = { if (isSkewJoin) { // We re-arrange the shuffle partitions to deal with skew join, and the new children diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index bedfbffc789ac..e7629a21f787a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1188,4 +1188,70 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan classOf[BroadcastNestedLoopJoinExec])) } } + + test("SPARK-32399: Full outer shuffled hash join") { + val inputDFs = Seq( + // Test unique join key + (spark.range(10).selectExpr("id as k1"), + spark.range(30).selectExpr("id as k2"), + $"k1" === $"k2"), + // Test non-unique join key + (spark.range(10).selectExpr("id % 5 as k1"), + spark.range(30).selectExpr("id % 5 as k2"), + $"k1" === $"k2"), + // Test empty build side + (spark.range(10).selectExpr("id as k1").filter("k1 < -1"), + spark.range(30).selectExpr("id as k2"), + $"k1" === $"k2"), + // Test empty stream side + (spark.range(10).selectExpr("id as k1"), + spark.range(30).selectExpr("id as k2").filter("k2 < -1"), + $"k1" === $"k2"), + // Test empty build and stream side + (spark.range(10).selectExpr("id as k1").filter("k1 < -1"), + spark.range(30).selectExpr("id as k2").filter("k2 < -1"), + $"k1" === $"k2"), + // Test string join key + (spark.range(10).selectExpr("cast(id * 3 as string) as k1"), + spark.range(30).selectExpr("cast(id as string) as k2"), + $"k1" === $"k2"), + // Test build side at right + (spark.range(30).selectExpr("cast(id / 3 as string) as k1"), + spark.range(10).selectExpr("cast(id as string) as k2"), + $"k1" === $"k2"), + // Test NULL join key + (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"), + spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"), + $"k1" === $"k2"), + (spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value as k1"), + spark.range(30).map(i => if (i % 5 == 0) i else null).selectExpr("value as k2"), + $"k1" === $"k2"), + // Test multiple join keys + (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr( + "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"), + spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr( + "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"), + $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6") + ) + inputDFs.foreach { case (df1, df2, joinExprs) => + withSQLConf( + // Set broadcast join threshold and number of shuffle partitions, + // as shuffled hash join depends on these two configs. 
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + val smjDF = df1.join(df2, joinExprs, "full") + assert(smjDF.queryExecution.executedPlan.collect { + case _: SortMergeJoinExec => true }.size === 1) + val smjResult = smjDF.collect() + + withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + val shjDF = df1.join(df2, joinExprs, "full") + assert(shjDF.queryExecution.executedPlan.collect { + case _: ShuffledHashJoinExec => true }.size === 1) + // Same result between shuffled hash join and sort merge join + checkAnswer(shjDF, smjResult) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b270bd5a2636..72e921deab933 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -592,4 +593,82 @@ class HashedRelationSuite extends SharedSparkSession { assert(hashed.getValue(0L) == null) assert(hashed.getValue(key) == null) } + + test("SPARK-32399: test methods related to key index") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val key = Seq(BoundReference(0, IntegerType, true)) + val row = Seq(BoundReference(0, IntegerType, true), BoundReference(1, IntegerType, true)) + val unsafeProj = UnsafeProjection.create(row) + var rows = (0 until 100).map(i => { + val k = if (i % 10 == 0) null else i % 10 + unsafeProj(InternalRow(k, i)).copy() + }) + rows = unsafeProj(InternalRow(-1, -1)).copy() +: rows + val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm, allowsNullKey = true) + val keyIndexToKeyMap = new mutable.HashMap[Int, String] + val keyIndexToValueMap = new mutable.HashMap[Int, Seq[Int]] + + // test getWithKeyIndex() + (0 until 10).foreach(i => { + val key = if (i == 0) InternalRow(null) else InternalRow(i) + val valuesWithKeyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key)).map( + v => (v.getKeyIndex, v.getValue.getInt(1))).toArray + val keyIndex = valuesWithKeyIndex.head._1 + val actualValues = valuesWithKeyIndex.map(_._2) + val expectedValues = (0 until 10).map(j => j * 10 + i) + if (i == 0) { + keyIndexToKeyMap(keyIndex) = "null" + } else { + keyIndexToKeyMap(keyIndex) = i.toString + } + keyIndexToValueMap(keyIndex) = actualValues + // key index is non-negative + assert(keyIndex >= 0) + // values are expected + assert(actualValues.sortWith(_ < _) === expectedValues) + }) + // key index is unique per key + val numUniqueKeyIndex = (0 until 10).flatMap(i => { + val key = if (i == 0) InternalRow(null) else InternalRow(i) + val keyIndex = unsafeRelation.getWithKeyIndex(toUnsafe(key)).map(_.getKeyIndex).toSeq + keyIndex + }).distinct.size + assert(numUniqueKeyIndex == 10) + // NULL for non-existing key + assert(unsafeRelation.getWithKeyIndex(toUnsafe(InternalRow(100))) == null) + + // test getValueWithKeyIndex() + val valuesWithKeyIndex = unsafeRelation.getValueWithKeyIndex(toUnsafe(InternalRow(-1))) + val keyIndex = valuesWithKeyIndex.getKeyIndex + keyIndexToKeyMap(keyIndex) = "-1" + 
keyIndexToValueMap(keyIndex) = Seq(-1) + // key index is non-negative + assert(valuesWithKeyIndex.getKeyIndex >= 0) + // value is expected + assert(valuesWithKeyIndex.getValue.getInt(1) == -1) + // NULL for non-existing key + assert(unsafeRelation.getValueWithKeyIndex(toUnsafe(InternalRow(100))) == null) + + // test valuesWithKeyIndex() + val keyIndexToRowMap = unsafeRelation.valuesWithKeyIndex().map( + v => (v.getKeyIndex, v.getValue.copy())).toSeq.groupBy(_._1) + assert(keyIndexToRowMap.size == 11) + keyIndexToRowMap.foreach { + case (keyIndex, row) => + val expectedKey = keyIndexToKeyMap(keyIndex) + val expectedValues = keyIndexToValueMap(keyIndex) + // key index returned from valuesWithKeyIndex() + // should be the same as returned from getWithKeyIndex() + if (expectedKey == "null") { + assert(row.head._2.isNullAt(0)) + } else { + assert(row.head._2.getInt(0).toString == expectedKey) + } + // values returned from valuesWithKeyIndex() + // should have same value and order as returned from getWithKeyIndex() + val actualValues = row.map(_._2.getInt(1)) + assert(actualValues === expectedValues) + } + } }
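
Note on trying the new path end-to-end: for a full outer equi-join the planner effectively chooses between sort merge join and shuffled hash join (the patch keeps canBuildBroadcastLeft/Right false for FullOuter), so spark.sql.join.preferSortMergeJoin is the switch that matters. The sketch below mirrors the JoinSuite test above; the object name, the local[2] master, and the threshold/partition values (80 and 2) are the test's illustrative choices rather than requirements, and whether ShuffledHashJoinExec is actually picked still depends on the computed size statistics, so explain() is the thing to check.

import org.apache.spark.sql.SparkSession

object FullOuterShjDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("full-outer-shj").getOrCreate()
    import spark.implicits._

    // Same knobs as the JoinSuite test: too big to broadcast, small enough per
    // partition to build a local hash map, and no preference for sort merge join.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "80")
    spark.conf.set("spark.sql.shuffle.partitions", "2")
    spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

    val left = spark.range(10).selectExpr("id as k1")
    val right = spark.range(30).selectExpr("id as k2")
    val joined = left.join(right, $"k1" === $"k2", "full")

    joined.explain() // expect ShuffledHashJoinExec; SortMergeJoinExec if preferSortMergeJoin stays true
    joined.show(40)
    spark.stop()
  }
}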
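
The HashSet[Long] bookkeeping in fullOuterJoinWithNonUniqueKey packs the key index and the value index into a single Long. A small standalone illustration of that encoding; the helper names packRowIndex/unpackRowIndex are made up here, the patch inlines the same expressions in markRowMatched/isRowMatched.

object RowIndexPacking {
  // Same packing expression as markRowMatched/isRowMatched in the patch.
  def packRowIndex(keyIndex: Int, valueIndex: Int): Long =
    (keyIndex.toLong << 32) | valueIndex

  def unpackRowIndex(rowIndex: Long): (Int, Int) =
    ((rowIndex >>> 32).toInt, rowIndex.toInt)

  def main(args: Array[String]): Unit = {
    // Key index 3, third value under that key (value index 2).
    val packed = packRowIndex(3, 2)
    assert(packed == 0x0000000300000002L)
    assert(unpackRowIndex(packed) == ((3, 2)))
    // Both halves are non-negative Ints in the join code, so widening valueIndex
    // to Long never sets the upper 32 bits.
  }
}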
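
For readers who want the unique-key algorithm in isolation: below is a toy model of the two phases of fullOuterJoinWithUniqueKey, using plain Scala collections and java.util.BitSet instead of HashedRelation and Spark's BitSet, and leaving out the extra join condition (which in the real code can turn a hash hit into a null-padded row). Phase 2 works only because the lazy ++ guarantees the stream side is fully consumed, and the bit set fully populated, before the build side is scanned, which is the same reason streamResultIter ++ buildResultIter is safe in the patch.

import java.util.{BitSet => JBitSet}

object FullOuterUniqueKeySketch {
  // `build` stands in for the hashed relation (unique keys); `stream` is the probe side.
  def fullOuterUniqueKey[K, V, S](
      build: IndexedSeq[(K, V)],
      stream: Iterator[(K, S)]): Iterator[(Option[S], Option[V])] = {
    val index = build.zipWithIndex.map { case ((k, v), i) => k -> ((i, v)) }.toMap
    val matched = new JBitSet(build.length)

    // Phase 1: probe with stream rows, recording every matched build entry by its index.
    val streamSide = stream.map { case (k, s) =>
      index.get(k) match {
        case Some((i, v)) => matched.set(i); (Some(s), Some(v))
        case None => (Some(s), None) // no build match: null-padded build side
      }
    }
    // Phase 2: build entries never marked in phase 1 come out with a null-padded stream side.
    val buildSide = build.iterator.zipWithIndex.collect {
      case ((_, v), i) if !matched.get(i) => (None, Some(v))
    }
    streamSide ++ buildSide
  }

  def main(args: Array[String]): Unit = {
    val out = fullOuterUniqueKey(Vector("a" -> 1, "b" -> 2), Iterator("a" -> 10, "c" -> 30)).toList
    println(out) // List((Some(10),Some(1)), (Some(30),None), (None,Some(2)))
  }
}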
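
One caveat worth calling out for callers of the new iterators: MapIteratorWithKeyIndex returns the same Location from every next(), and ValueRowWithKeyIndex is likewise a single mutable holder reused per relation, which is why the HashedRelationSuite test above copies rows (v.getValue.copy()) before grouping them. A toy model of that contract, with made-up class and method names:

final class ReusedHolder(var keyIndex: Int = -1, var value: String = "") {
  def update(k: Int, v: String): ReusedHolder = { keyIndex = k; value = v; this }
  def copy(): ReusedHolder = new ReusedHolder(keyIndex, value)
}

object ReuseDemo {
  def main(args: Array[String]): Unit = {
    val holder = new ReusedHolder
    val rows = Seq(0 -> "a", 1 -> "b", 2 -> "c")

    // Buffering the holder itself: every buffered element aliases the same object.
    val aliased = rows.iterator.map { case (k, v) => holder.update(k, v) }.toList
    assert(aliased.map(_.value) == List("c", "c", "c"))

    // Copying (or extracting the fields) per element keeps the buffered values distinct.
    val copied = rows.iterator.map { case (k, v) => holder.update(k, v).copy() }.toList
    assert(copied.map(_.value) == List("a", "b", "c"))
  }
}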