Skip to content

Commit 4fb0bc8

Browse files
author
Davies Liu
committed
fix broadcast hash join
1 parent 4faf5f9 commit 4fb0bc8

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -190,25 +190,25 @@ case class BroadcastHashJoin(
190190
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
191191
val matched = ctx.freshName("matched")
192192
val buildVars = genBuildSideVars(ctx, matched)
193-
val resultVars = buildSide match {
194-
case BuildLeft => buildVars ++ input
195-
case BuildRight => input ++ buildVars
196-
}
197193
val numOutput = metricTerm(ctx, "numOutputRows")
198194

199195
val checkCondition = if (condition.isDefined) {
200196
// filter the output via condition
201-
ctx.currentVars = resultVars
202-
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
197+
ctx.currentVars = input ++ buildVars
198+
val ev = BindReferences.bindReference(
199+
condition.get, streamedPlan.output ++ buildPlan.output).gen(ctx)
203200
s"""
204-
|${evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)}
205201
|${ev.code}
206202
|if (${ev.isNull} || !${ev.value}) continue;
207203
""".stripMargin
208204
} else {
209205
""
210206
}
211207

208+
val resultVars = buildSide match {
209+
case BuildLeft => buildVars ++ input
210+
case BuildRight => input ++ buildVars
211+
}
212212
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
213213
s"""
214214
|// generate join key for stream side
@@ -252,29 +252,31 @@ case class BroadcastHashJoin(
252252
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
253253
val matched = ctx.freshName("matched")
254254
val buildVars = genBuildSideVars(ctx, matched)
255-
val resultVars = buildSide match {
256-
case BuildLeft => buildVars ++ input
257-
case BuildRight => input ++ buildVars
258-
}
259255
val numOutput = metricTerm(ctx, "numOutputRows")
260256

261257
// filter the output via condition
262258
val conditionPassed = ctx.freshName("conditionPassed")
263259
val checkCondition = if (condition.isDefined) {
264-
ctx.currentVars = resultVars
265-
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
260+
val eval = evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)
261+
ctx.currentVars = input ++ buildVars
262+
val ev = BindReferences.bindReference(condition.get,
263+
streamedPlan.output ++ buildPlan.output).gen(ctx)
266264
s"""
267265
|boolean $conditionPassed = true;
268-
|${evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)}
266+
|${eval.trim}
267+
|${ev.code}
269268
|if ($matched != null) {
270-
| ${ev.code}
271269
| $conditionPassed = !${ev.isNull} && ${ev.value};
272270
|}
273271
""".stripMargin
274272
} else {
275273
s"final boolean $conditionPassed = true;"
276274
}
277275

276+
val resultVars = buildSide match {
277+
case BuildLeft => buildVars ++ input
278+
case BuildRight => input ++ buildVars
279+
}
278280
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
279281
s"""
280282
|// generate join key for stream side

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,11 @@ case class SortMergeJoin(
306306
val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
307307
condRefs.contains(a)
308308
}
309-
val beforeCond = used.map(_._2.code).mkString("\n")
310-
val afterCond = notUsed.map(_._2.code).mkString("\n")
309+
val beforeCond = evaluateVariables(used.map(_._2))
310+
val afterCond = evaluateVariables(notUsed.map(_._2))
311311
(beforeCond, afterCond)
312312
} else {
313-
(variables.map(_.code).mkString("\n"), "")
313+
(evaluateVariables(variables), "")
314314
}
315315
}
316316

@@ -326,41 +326,48 @@ case class SortMergeJoin(
326326
val leftVars = createLeftVars(ctx, leftRow)
327327
val rightRow = ctx.freshName("rightRow")
328328
val rightVars = createRightVar(ctx, rightRow)
329-
val resultVars = leftVars ++ rightVars
330-
331-
// Check condition
332-
ctx.currentVars = resultVars
333-
val cond = if (condition.isDefined) {
334-
BindReferences.bindReference(condition.get, output).gen(ctx)
335-
} else {
336-
ExprCode("", "false", "true")
337-
}
338-
// Split the code of creating variables based on whether it's used by condition or not.
339-
val loaded = ctx.freshName("loaded")
340-
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
341-
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
342-
343329

344330
val size = ctx.freshName("size")
345331
val i = ctx.freshName("i")
346332
val numOutput = metricTerm(ctx, "numOutputRows")
333+
val (beforeLoop, condCheck) = if (condition.isDefined) {
334+
// Split the code of creating variables based on whether it's used by condition or not.
335+
val loaded = ctx.freshName("loaded")
336+
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
337+
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
338+
// Generate code for condition
339+
ctx.currentVars = leftVars ++ rightVars
340+
val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
341+
// evaluate the columns those used by condition before loop
342+
val before = s"""
343+
|boolean $loaded = false;
344+
|$leftBefore
345+
""".stripMargin
346+
347+
val checking = s"""
348+
|$rightBefore
349+
|${cond.code}
350+
|if (${cond.isNull} || !${cond.value}) continue;
351+
|if (!$loaded) {
352+
| $loaded = true;
353+
| $leftAfter
354+
|}
355+
|$rightAfter
356+
""".stripMargin
357+
(before, checking)
358+
} else {
359+
(evaluateVariables(leftVars), "")
360+
}
361+
347362
s"""
348363
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
349364
| int $size = $matches.size();
350-
| boolean $loaded = false;
351-
| $leftBefore
365+
| ${beforeLoop.trim}
352366
| for (int $i = 0; $i < $size; $i ++) {
353367
| InternalRow $rightRow = (InternalRow) $matches.get($i);
354-
| $rightBefore
355-
| ${cond.code}
356-
| if (${cond.isNull} || !${cond.value}) continue;
357-
| if (!$loaded) {
358-
| $loaded = true;
359-
| $leftAfter
360-
| }
361-
| $rightAfter
368+
| ${condCheck.trim}
362369
| $numOutput.add(1);
363-
| ${consume(ctx, resultVars)}
370+
| ${consume(ctx, leftVars ++ rightVars)}
364371
| }
365372
| if (shouldStop()) return;
366373
|}

0 commit comments

Comments
 (0)