Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.joins.EmptyHashedRelationWithAllNullKeys
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys

/**
* This optimization rule detects a NAAJ and converts it to an empty LocalRelation
* when buildSide is EmptyHashedRelationWithAllNullKeys.
* when buildSide is HashedRelationWithAllNullKeys.
*/
object EliminateNullAwareAntiJoin extends Rule[LogicalPlan] {

/**
* Reports whether the given plan is a broadcast stage for which the NAAJ can be eliminated.
*
* Returns true only for a LogicalQueryStage wrapping a BroadcastQueryStageExec whose
* resultOption holds a defined value and whose broadcast relation value equals the
* all-null-keys sentinel relation. Every other plan returns false.
*
* NOTE(review): the two `&&` guard lines below appear to be the removed and added sides
* of the EmptyHashedRelationWithAllNullKeys -> HashedRelationWithAllNullKeys rename as
* rendered by the diff view; confirm which one is present in the actual source file.
*/
private def canEliminate(plan: LogicalPlan): Boolean = plan match {
// The relation value is only compared after the stage result is checked to be defined.
case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
&& stage.broadcast.relationFuture.get().value == EmptyHashedRelationWithAllNullKeys => true
&& stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys => true
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ case class BroadcastHashJoinExec(
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
if (hashed == EmptyHashedRelation) {
streamedIter
} else if (hashed == EmptyHashedRelationWithAllNullKeys) {
} else if (hashed == HashedRelationWithAllNullKeys) {
Iterator.empty
} else {
val keyGenerator = UnsafeProjection.create(
Expand Down Expand Up @@ -228,7 +228,6 @@ case class BroadcastHashJoinExec(
if (isNullAwareAntiJoin) {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val (matched, _, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")

if (broadcastRelation.value == EmptyHashedRelation) {
Expand All @@ -237,26 +236,15 @@ case class BroadcastHashJoinExec(
|$numOutput.add(1);
|${consume(ctx, input)}
""".stripMargin
} else if (broadcastRelation.value == EmptyHashedRelationWithAllNullKeys) {
} else if (broadcastRelation.value == HashedRelationWithAllNullKeys) {
s"""
|// If the right side contains any all-null key, NAAJ simply returns Nothing.
""".stripMargin
} else {
val found = ctx.freshName("found")
s"""
|boolean $found = false;
|// generate join key for stream side
|${keyEv.code}
|if ($anyNull) {
| $found = true;
|} else {
| UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
| if ($matched != null) {
| $found = true;
| }
|}
|
|if (!$found) {
|if (!$anyNull && $relationTerm.getValue(${keyEv.value}) == null) {
| $numOutput.add(1);
| ${consume(ctx, input)}
|}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ private[joins] object UnsafeHashedRelation {
// scalastyle:on throwerror
}
} else if (isNullAware) {
return EmptyHashedRelationWithAllNullKeys
return HashedRelationWithAllNullKeys
}
}

Expand Down Expand Up @@ -1056,7 +1056,7 @@ private[joins] object LongHashedRelation {
val key = rowKey.getLong(0)
map.append(key, unsafeRow)
} else if (isNullAware) {
return EmptyHashedRelationWithAllNullKeys
return HashedRelationWithAllNullKeys
}
}
map.optimize()
Expand All @@ -1067,7 +1067,7 @@ private[joins] object LongHashedRelation {
/**
* Common trait with dummy implementations for the NAAJ special HashedRelations:
* EmptyHashedRelation
* EmptyHashedRelationWithAllNullKeys
* HashedRelationWithAllNullKeys
*/
trait NullAwareHashedRelation extends HashedRelation with Externalizable {
override def get(key: InternalRow): Iterator[InternalRow] = {
Expand Down Expand Up @@ -1130,8 +1130,8 @@ object EmptyHashedRelation extends NullAwareHashedRelation {
* A special HashedRelation indicating that it was built from a non-empty
* input: Iterator[InternalRow] which contains a key whose columns are all null.
*/
object EmptyHashedRelationWithAllNullKeys extends NullAwareHashedRelation {
// The read-only copy of this singleton object is the object itself.
override def asReadOnlyCopy(): EmptyHashedRelationWithAllNullKeys.type = this
object HashedRelationWithAllNullKeys extends NullAwareHashedRelation {
// The read-only copy of this singleton object is the object itself.
override def asReadOnlyCopy(): HashedRelationWithAllNullKeys.type = this
}

/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ class AdaptiveQueryExecSuite
}
}

test("SPARK-32573: Eliminate NAAJ when BuildSide is EmptyHashedRelationWithAllNullKeys") {
test("SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) {
Expand Down