Skip to content

Commit 95e1ab2

Browse files
Davies Liudavies
authored andcommitted
[SPARK-13237] [SQL] generated broadcast outer join
This PR support codegen for broadcast outer join. In order to reduce the duplicated codes, this PR merge HashJoin and HashOuterJoin together (also BroadcastHashJoin and BroadcastHashOuterJoin). Author: Davies Liu <[email protected]> Closes #11130 from davies/gen_out.
1 parent 26f38bb commit 95e1ab2

File tree

12 files changed

+448
-371
lines changed

12 files changed

+448
-371
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
108108
// --- Inner joins --------------------------------------------------------------------------
109109

110110
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
111-
joins.BroadcastHashJoin(
112-
leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
111+
Seq(joins.BroadcastHashJoin(
112+
leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
113113

114114
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
115-
joins.BroadcastHashJoin(
116-
leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
115+
Seq(joins.BroadcastHashJoin(
116+
leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
117117

118118
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
119119
if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
124124

125125
case ExtractEquiJoinKeys(
126126
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
127-
joins.BroadcastHashOuterJoin(
128-
leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
127+
Seq(joins.BroadcastHashJoin(
128+
leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
129129

130130
case ExtractEquiJoinKeys(
131131
RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
132-
joins.BroadcastHashOuterJoin(
133-
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
132+
Seq(joins.BroadcastHashJoin(
133+
leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
134134

135135
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
136136
if RowOrdering.isOrderable(leftKeys) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
2828
import org.apache.spark.sql.catalyst.rules.Rule
2929
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
3030
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
31-
import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
31+
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
3232

3333
/**
3434
* An interface for those physical operators that support codegen.
@@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
3838
/** Prefix used in the current operator's variable names. */
3939
private def variablePrefix: String = this match {
4040
case _: TungstenAggregate => "agg"
41-
case _: BroadcastHashJoin => "bhj"
41+
case _: BroadcastHashJoin => "join"
4242
case _ => nodeName.toLowerCase
4343
}
4444

@@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
391391
var inputs = ArrayBuffer[SparkPlan]()
392392
val combined = plan.transform {
393393
// The build side can't be compiled together
394-
case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
394+
case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
395395
b.copy(left = apply(left))
396-
case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
396+
case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
397397
b.copy(right = apply(right))
398398
case p if !supportCodegen(p) =>
399399
val input = apply(p) // collapse them recursively

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 199 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
2424
import org.apache.spark.broadcast.Broadcast
2525
import org.apache.spark.rdd.RDD
2626
import org.apache.spark.sql.catalyst.InternalRow
27-
import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
27+
import org.apache.spark.sql.catalyst.expressions._
2828
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
29+
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
2930
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
3031
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
3132
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
4142
case class BroadcastHashJoin(
4243
leftKeys: Seq[Expression],
4344
rightKeys: Seq[Expression],
45+
joinType: JoinType,
4446
buildSide: BuildSide,
4547
condition: Option[Expression],
4648
left: SparkPlan,
@@ -105,75 +107,144 @@ case class BroadcastHashJoin(
105107
val broadcastRelation = Await.result(broadcastFuture, timeout)
106108

107109
streamedPlan.execute().mapPartitions { streamedIter =>
108-
val hashedRelation = broadcastRelation.value
109-
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
110-
hashJoin(streamedIter, hashedRelation, numOutputRows)
110+
val joinedRow = new JoinedRow()
111+
val hashTable = broadcastRelation.value
112+
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
113+
val keyGenerator = streamSideKeyGenerator
114+
val resultProj = createResultProjection
115+
116+
joinType match {
117+
case Inner =>
118+
hashJoin(streamedIter, hashTable, numOutputRows)
119+
120+
case LeftOuter =>
121+
streamedIter.flatMap { currentRow =>
122+
val rowKey = keyGenerator(currentRow)
123+
joinedRow.withLeft(currentRow)
124+
leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
125+
}
126+
127+
case RightOuter =>
128+
streamedIter.flatMap { currentRow =>
129+
val rowKey = keyGenerator(currentRow)
130+
joinedRow.withRight(currentRow)
131+
rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
132+
}
133+
134+
case x =>
135+
throw new IllegalArgumentException(
136+
s"BroadcastHashJoin should not take $x as the JoinType")
137+
}
111138
}
112139
}
113140

114-
private var broadcastRelation: Broadcast[HashedRelation] = _
115-
// the term for hash relation
116-
private var relationTerm: String = _
117-
118141
override def upstream(): RDD[InternalRow] = {
119142
streamedPlan.asInstanceOf[CodegenSupport].upstream()
120143
}
121144

122145
override def doProduce(ctx: CodegenContext): String = {
146+
streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
147+
}
148+
149+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
150+
if (joinType == Inner) {
151+
codegenInner(ctx, input)
152+
} else {
153+
// LeftOuter and RightOuter
154+
codegenOuter(ctx, input)
155+
}
156+
}
157+
158+
/**
159+
* Returns a tuple of Broadcast of HashedRelation and the variable name for it.
160+
*/
161+
private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
123162
// create a name for HashedRelation
124-
broadcastRelation = Await.result(broadcastFuture, timeout)
163+
val broadcastRelation = Await.result(broadcastFuture, timeout)
125164
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
126-
relationTerm = ctx.freshName("relation")
165+
val relationTerm = ctx.freshName("relation")
127166
val clsName = broadcastRelation.value.getClass.getName
128167
ctx.addMutableState(clsName, relationTerm,
129168
s"""
130169
| $relationTerm = ($clsName) $broadcast.value();
131170
| incPeakExecutionMemory($relationTerm.getMemorySize());
132171
""".stripMargin)
133-
134-
s"""
135-
| ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
136-
""".stripMargin
172+
(broadcastRelation, relationTerm)
137173
}
138174

139-
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
140-
// generate the key as UnsafeRow or Long
175+
/**
176+
* Returns the code for generating join key for stream side, and expression of whether the key
177+
* has any null in it or not.
178+
*/
179+
private def genStreamSideJoinKey(
180+
ctx: CodegenContext,
181+
input: Seq[ExprCode]): (ExprCode, String) = {
141182
ctx.currentVars = input
142-
val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
183+
if (canJoinKeyFitWithinLong) {
184+
// generate the join key as Long
143185
val expr = rewriteKeyExpr(streamedKeys).head
144186
val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
145187
(ev, ev.isNull)
146188
} else {
189+
// generate the join key as UnsafeRow
147190
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
148191
val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
149192
(ev, s"${ev.value}.anyNull()")
150193
}
194+
}
151195

152-
// find the matches from HashedRelation
153-
val matched = ctx.freshName("matched")
154-
155-
// create variables for output
196+
/**
197+
* Generates the code for variable of build side.
198+
*/
199+
private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
156200
ctx.currentVars = null
157201
ctx.INPUT_ROW = matched
158-
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
159-
BoundReference(i, a.dataType, a.nullable).gen(ctx)
202+
buildPlan.output.zipWithIndex.map { case (a, i) =>
203+
val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
204+
if (joinType == Inner) {
205+
ev
206+
} else {
207+
// the variables are needed even there is no matched rows
208+
val isNull = ctx.freshName("isNull")
209+
val value = ctx.freshName("value")
210+
val code = s"""
211+
|boolean $isNull = true;
212+
|${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
213+
|if ($matched != null) {
214+
| ${ev.code}
215+
| $isNull = ${ev.isNull};
216+
| $value = ${ev.value};
217+
|}
218+
""".stripMargin
219+
ExprCode(code, isNull, value)
220+
}
160221
}
222+
}
223+
224+
/**
225+
* Generates the code for Inner join.
226+
*/
227+
private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
228+
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
229+
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
230+
val matched = ctx.freshName("matched")
231+
val buildVars = genBuildSideVars(ctx, matched)
161232
val resultVars = buildSide match {
162-
case BuildLeft => buildColumns ++ input
163-
case BuildRight => input ++ buildColumns
233+
case BuildLeft => buildVars ++ input
234+
case BuildRight => input ++ buildVars
164235
}
165-
166236
val numOutput = metricTerm(ctx, "numOutputRows")
237+
167238
val outputCode = if (condition.isDefined) {
168239
// filter the output via condition
169240
ctx.currentVars = resultVars
170241
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
171242
s"""
172-
| ${ev.code}
173-
| if (!${ev.isNull} && ${ev.value}) {
174-
| $numOutput.add(1);
175-
| ${consume(ctx, resultVars)}
176-
| }
243+
|${ev.code}
244+
|if (!${ev.isNull} && ${ev.value}) {
245+
| $numOutput.add(1);
246+
| ${consume(ctx, resultVars)}
247+
|}
177248
""".stripMargin
178249
} else {
179250
s"""
@@ -184,36 +255,110 @@ case class BroadcastHashJoin(
184255

185256
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
186257
s"""
187-
| // generate join key
188-
| ${keyVal.code}
189-
| // find matches from HashedRelation
190-
| UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
191-
| if ($matched != null) {
192-
| ${buildColumns.map(_.code).mkString("\n")}
193-
| $outputCode
194-
| }
195-
""".stripMargin
258+
|// generate join key for stream side
259+
|${keyEv.code}
260+
|// find matches from HashedRelation
261+
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
262+
|if ($matched != null) {
263+
| ${buildVars.map(_.code).mkString("\n")}
264+
| $outputCode
265+
|}
266+
""".stripMargin
267+
268+
} else {
269+
val matches = ctx.freshName("matches")
270+
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
271+
val i = ctx.freshName("i")
272+
val size = ctx.freshName("size")
273+
s"""
274+
|// generate join key for stream side
275+
|${keyEv.code}
276+
|// find matches from HashRelation
277+
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
278+
|if ($matches != null) {
279+
| int $size = $matches.size();
280+
| for (int $i = 0; $i < $size; $i++) {
281+
| UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
282+
| ${buildVars.map(_.code).mkString("\n")}
283+
| $outputCode
284+
| }
285+
|}
286+
""".stripMargin
287+
}
288+
}
289+
290+
291+
/**
292+
* Generates the code for left or right outer join.
293+
*/
294+
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
295+
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
296+
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
297+
val matched = ctx.freshName("matched")
298+
val buildVars = genBuildSideVars(ctx, matched)
299+
val resultVars = buildSide match {
300+
case BuildLeft => buildVars ++ input
301+
case BuildRight => input ++ buildVars
302+
}
303+
val numOutput = metricTerm(ctx, "numOutputRows")
304+
305+
// filter the output via condition
306+
val conditionPassed = ctx.freshName("conditionPassed")
307+
val checkCondition = if (condition.isDefined) {
308+
ctx.currentVars = resultVars
309+
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
310+
s"""
311+
|boolean $conditionPassed = true;
312+
|if ($matched != null) {
313+
| ${ev.code}
314+
| $conditionPassed = !${ev.isNull} && ${ev.value};
315+
|}
316+
""".stripMargin
317+
} else {
318+
s"final boolean $conditionPassed = true;"
319+
}
320+
321+
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
322+
s"""
323+
|// generate join key for stream side
324+
|${keyEv.code}
325+
|// find matches from HashedRelation
326+
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
327+
|${buildVars.map(_.code).mkString("\n")}
328+
|${checkCondition.trim}
329+
|if (!$conditionPassed) {
330+
| // reset to null
331+
| ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
332+
|}
333+
|$numOutput.add(1);
334+
|${consume(ctx, resultVars)}
335+
""".stripMargin
196336

197337
} else {
198338
val matches = ctx.freshName("matches")
199339
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
200340
val i = ctx.freshName("i")
201341
val size = ctx.freshName("size")
342+
val found = ctx.freshName("found")
202343
s"""
203-
| // generate join key
204-
| ${keyVal.code}
205-
| // find matches from HashRelation
206-
| $bufferType $matches = ${anyNull} ? null :
207-
| ($bufferType) $relationTerm.get(${keyVal.value});
208-
| if ($matches != null) {
209-
| int $size = $matches.size();
210-
| for (int $i = 0; $i < $size; $i++) {
211-
| UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
212-
| ${buildColumns.map(_.code).mkString("\n")}
213-
| $outputCode
214-
| }
215-
| }
216-
""".stripMargin
344+
|// generate join key for stream side
345+
|${keyEv.code}
346+
|// find matches from HashRelation
347+
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
348+
|int $size = $matches != null ? $matches.size() : 0;
349+
|boolean $found = false;
350+
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
351+
|for (int $i = 0; $i <= $size; $i++) {
352+
| UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
353+
| ${buildVars.map(_.code).mkString("\n")}
354+
| ${checkCondition.trim}
355+
| if ($conditionPassed && ($i < $size || !$found)) {
356+
| $found = true;
357+
| $numOutput.add(1);
358+
| ${consume(ctx, resultVars)}
359+
| }
360+
|}
361+
""".stripMargin
217362
}
218363
}
219364
}

0 commit comments

Comments
 (0)