@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
2424import org .apache .spark .broadcast .Broadcast
2525import org .apache .spark .rdd .RDD
2626import org .apache .spark .sql .catalyst .InternalRow
27- import org .apache .spark .sql .catalyst .expressions .{ BindReferences , BoundReference , Expression , UnsafeRow }
27+ import org .apache .spark .sql .catalyst .expressions ._
2828import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , ExprCode , GenerateUnsafeProjection }
29+ import org .apache .spark .sql .catalyst .plans .{Inner , JoinType , LeftOuter , RightOuter }
2930import org .apache .spark .sql .catalyst .plans .physical .{Distribution , Partitioning , UnspecifiedDistribution }
3031import org .apache .spark .sql .execution .{BinaryNode , CodegenSupport , SparkPlan , SQLExecution }
3132import org .apache .spark .sql .execution .metric .SQLMetrics
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
4142case class BroadcastHashJoin (
4243 leftKeys : Seq [Expression ],
4344 rightKeys : Seq [Expression ],
45+ joinType : JoinType ,
4446 buildSide : BuildSide ,
4547 condition : Option [Expression ],
4648 left : SparkPlan ,
@@ -105,75 +107,144 @@ case class BroadcastHashJoin(
105107 val broadcastRelation = Await .result(broadcastFuture, timeout)
106108
107109 streamedPlan.execute().mapPartitions { streamedIter =>
108- val hashedRelation = broadcastRelation.value
109- TaskContext .get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
110- hashJoin(streamedIter, hashedRelation, numOutputRows)
110+ val joinedRow = new JoinedRow ()
111+ val hashTable = broadcastRelation.value
112+ TaskContext .get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
113+ val keyGenerator = streamSideKeyGenerator
114+ val resultProj = createResultProjection
115+
116+ joinType match {
117+ case Inner =>
118+ hashJoin(streamedIter, hashTable, numOutputRows)
119+
120+ case LeftOuter =>
121+ streamedIter.flatMap { currentRow =>
122+ val rowKey = keyGenerator(currentRow)
123+ joinedRow.withLeft(currentRow)
124+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
125+ }
126+
127+ case RightOuter =>
128+ streamedIter.flatMap { currentRow =>
129+ val rowKey = keyGenerator(currentRow)
130+ joinedRow.withRight(currentRow)
131+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
132+ }
133+
134+ case x =>
135+ throw new IllegalArgumentException (
136+ s " BroadcastHashJoin should not take $x as the JoinType " )
137+ }
111138 }
112139 }
113140
114- private var broadcastRelation : Broadcast [HashedRelation ] = _
115- // the term for hash relation
116- private var relationTerm : String = _
117-
/**
 * Returns the upstream RDD of the streamed side, which is cast to [[CodegenSupport]].
 */
override def upstream(): RDD[InternalRow] = {
  val streamed = streamedPlan.asInstanceOf[CodegenSupport]
  streamed.upstream()
}
121144
/**
 * Asks the streamed side (cast to [[CodegenSupport]]) to produce its code, passing this
 * operator as the one that consumes the produced rows.
 */
override def doProduce(ctx: CodegenContext): String = {
  val streamed = streamedPlan.asInstanceOf[CodegenSupport]
  streamed.produce(ctx, this)
}
148+
/**
 * Generates the code that joins one streamed row, dispatching on the join type.
 *
 * Inner joins go to `codegenInner`; left and right outer joins go to `codegenOuter`.
 * Any other join type is rejected explicitly instead of silently falling through to the
 * outer-join code path, mirroring the error raised by the non-codegen execution path.
 *
 * @param ctx the codegen context
 * @param input the generated variables for the columns of the streamed row
 * @return the generated Java source for the join
 * @throws IllegalArgumentException if the join type is not Inner, LeftOuter or RightOuter
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  joinType match {
    case Inner =>
      codegenInner(ctx, input)

    case LeftOuter | RightOuter =>
      codegenOuter(ctx, input)

    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }
}
157+
/**
 * Waits for the broadcast build side and registers it with the codegen context as a
 * mutable state variable, whose initialization also reports the relation's memory size
 * via `incPeakExecutionMemory`.
 *
 * @return the Broadcast of HashedRelation together with the name of the variable holding it
 */
private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
  // blocks until the broadcast future completes (or the timeout expires)
  val relation = Await.result(broadcastFuture, timeout)
  val broadcastRef = ctx.addReferenceObj("broadcast", relation)
  // fresh name for the variable that holds the HashedRelation
  val termName = ctx.freshName("relation")
  val relationClass = relation.value.getClass.getName
  val initCode =
    s"""
       | $termName = ($relationClass) $broadcastRef.value();
       | incPeakExecutionMemory($termName.getMemorySize());
     """.stripMargin
  ctx.addMutableState(relationClass, termName, initCode)
  (relation, termName)
}
138174
139- override def doConsume (ctx : CodegenContext , input : Seq [ExprCode ]): String = {
140- // generate the key as UnsafeRow or Long
/**
 * Generates the join key of the streamed side from the given input variables.
 *
 * @return the generated code for the key, paired with a Java expression that tells whether
 *         the key contains any null
 */
private def genStreamSideJoinKey(
    ctx: CodegenContext,
    input: Seq[ExprCode]): (ExprCode, String) = {
  ctx.currentVars = input
  if (canJoinKeyFitWithinLong) {
    // the key fits in a Long: evaluate the single rewritten key expression
    val longKey = rewriteKeyExpr(streamedKeys).head
    val keyCode = BindReferences.bindReference(longKey, streamedPlan.output).gen(ctx)
    (keyCode, keyCode.isNull)
  } else {
    // otherwise project all key expressions into an UnsafeRow
    val boundKeys = streamedKeys.map(k => BindReferences.bindReference(k, streamedPlan.output))
    val rowCode = GenerateUnsafeProjection.createCode(ctx, boundKeys)
    (rowCode, s"${rowCode.value}.anyNull()")
  }
}
151195
152- // find the matches from HashedRelation
153- val matched = ctx.freshName( " matched " )
154-
155- // create variables for output
/**
 * Generates the variables for the columns of the build side, read from the row named
 * by `matched`.
 *
 * For an inner join the column code is returned as is. For other join types each column
 * is wrapped so that its variables are always declared (null / default value) and only
 * filled in when `matched` is not null.
 */
private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
  ctx.currentVars = null
  ctx.INPUT_ROW = matched
  buildPlan.output.zipWithIndex.map { case (attr, ordinal) =>
    val column = BoundReference(ordinal, attr.dataType, attr.nullable).gen(ctx)
    joinType match {
      case Inner =>
        column

      case _ =>
        // the variables must exist even when there is no matched row
        val isNull = ctx.freshName("isNull")
        val value = ctx.freshName("value")
        val javaType = ctx.javaType(attr.dataType)
        val defaultValue = ctx.defaultValue(attr.dataType)
        val wrapped =
          s"""
             |boolean $isNull = true;
             |$javaType $value = $defaultValue;
             |if ($matched != null) {
             |  ${column.code}
             |  $isNull = ${column.isNull};
             |  $value = ${column.value};
             |}
           """.stripMargin
        ExprCode(wrapped, isNull, value)
    }
  }
}
223+
224+ /**
225+ * Generates the code for Inner join.
226+ */
227+ private def codegenInner (ctx : CodegenContext , input : Seq [ExprCode ]): String = {
228+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
229+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
230+ val matched = ctx.freshName(" matched" )
231+ val buildVars = genBuildSideVars(ctx, matched)
161232 val resultVars = buildSide match {
162- case BuildLeft => buildColumns ++ input
163- case BuildRight => input ++ buildColumns
233+ case BuildLeft => buildVars ++ input
234+ case BuildRight => input ++ buildVars
164235 }
165-
166236 val numOutput = metricTerm(ctx, " numOutputRows" )
237+
167238 val outputCode = if (condition.isDefined) {
168239 // filter the output via condition
169240 ctx.currentVars = resultVars
170241 val ev = BindReferences .bindReference(condition.get, this .output).gen(ctx)
171242 s """
172- | ${ev.code}
173- | if (! ${ev.isNull} && ${ev.value}) {
174- | $numOutput.add(1);
175- | ${consume(ctx, resultVars)}
176- | }
243+ | ${ev.code}
244+ |if (! ${ev.isNull} && ${ev.value}) {
245+ | $numOutput.add(1);
246+ | ${consume(ctx, resultVars)}
247+ |}
177248 """ .stripMargin
178249 } else {
179250 s """
@@ -184,36 +255,110 @@ case class BroadcastHashJoin(
184255
185256 if (broadcastRelation.value.isInstanceOf [UniqueHashedRelation ]) {
186257 s """
187- | // generate join key
188- | ${keyVal.code}
189- | // find matches from HashedRelation
190- | UnsafeRow $matched = $anyNull ? null: (UnsafeRow) $relationTerm.getValue( ${keyVal.value});
191- | if ( $matched != null) {
192- | ${buildColumns.map(_.code).mkString(" \n " )}
193- | $outputCode
194- | }
195- """ .stripMargin
258+ |// generate join key for stream side
259+ | ${keyEv.code}
260+ |// find matches from HashedRelation
261+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow) $relationTerm.getValue( ${keyEv.value});
262+ |if ( $matched != null) {
263+ | ${buildVars.map(_.code).mkString(" \n " )}
264+ | $outputCode
265+ |}
266+ """ .stripMargin
267+
268+ } else {
269+ val matches = ctx.freshName(" matches" )
270+ val bufferType = classOf [CompactBuffer [UnsafeRow ]].getName
271+ val i = ctx.freshName(" i" )
272+ val size = ctx.freshName(" size" )
273+ s """
274+ |// generate join key for stream side
275+ | ${keyEv.code}
276+ |// find matches from HashRelation
277+ | $bufferType $matches = $anyNull ? null : ( $bufferType) $relationTerm.get( ${keyEv.value});
278+ |if ( $matches != null) {
279+ | int $size = $matches.size();
280+ | for (int $i = 0; $i < $size; $i++) {
281+ | UnsafeRow $matched = (UnsafeRow) $matches.apply( $i);
282+ | ${buildVars.map(_.code).mkString(" \n " )}
283+ | $outputCode
284+ | }
285+ |}
286+ """ .stripMargin
287+ }
288+ }
289+
290+
/**
 * Generates the code for a left or right outer join.
 *
 * Every streamed row produces at least one output row. With a unique-key relation, the
 * build-side variables are reset to null when the join condition fails, and the row is
 * always emitted. Otherwise the generated loop runs one extra iteration with a null
 * matched row, which emits the null-padded row only if no earlier match was emitted.
 */
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  val (relation, relationTerm) = prepareBroadcast(ctx)
  val (streamKey, keyHasNull) = genStreamSideJoinKey(ctx, input)
  val matched = ctx.freshName("matched")
  val buildVars = genBuildSideVars(ctx, matched)
  val resultVars = buildSide match {
    case BuildLeft => buildVars ++ input
    case BuildRight => input ++ buildVars
  }
  val numOutput = metricTerm(ctx, "numOutputRows")

  // the condition is only evaluated for a real match; a missing match always passes
  val conditionPassed = ctx.freshName("conditionPassed")
  val checkCondition = condition.map { cond =>
    ctx.currentVars = resultVars
    val ev = BindReferences.bindReference(cond, this.output).gen(ctx)
    s"""
       |boolean $conditionPassed = true;
       |if ($matched != null) {
       |  ${ev.code}
       |  $conditionPassed = !${ev.isNull} && ${ev.value};
       |}
     """.stripMargin
  }.getOrElse {
    s"final boolean $conditionPassed = true;"
  }

  val buildVarsCode = buildVars.map(_.code).mkString("\n")
  val conditionCode = checkCondition.trim

  relation.value match {
    case _: UniqueHashedRelation =>
      // at most one build row per key: look it up directly
      val resetToNull = buildVars.map(v => s"${v.isNull} = true;").mkString("\n")
      s"""
         |// generate join key for stream side
         |${streamKey.code}
         |// find matches from HashedRelation
         |UnsafeRow $matched = $keyHasNull ? null: (UnsafeRow)$relationTerm.getValue(${streamKey.value});
         |$buildVarsCode
         |$conditionCode
         |if (!$conditionPassed) {
         |  // reset to null
         |  $resetToNull
         |}
         |$numOutput.add(1);
         |${consume(ctx, resultVars)}
       """.stripMargin

    case _ =>
      val matches = ctx.freshName("matches")
      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
      val i = ctx.freshName("i")
      val size = ctx.freshName("size")
      val found = ctx.freshName("found")
      s"""
         |// generate join key for stream side
         |${streamKey.code}
         |// find matches from HashRelation
         |$bufferType $matches = $keyHasNull ? null : ($bufferType)$relationTerm.get(${streamKey.value});
         |int $size = $matches != null ? $matches.size() : 0;
         |boolean $found = false;
         |// the last iteration of this loop is to emit an empty row if there is no matched rows.
         |for (int $i = 0; $i <= $size; $i++) {
         |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
         |  $buildVarsCode
         |  $conditionCode
         |  if ($conditionPassed && ($i < $size || !$found)) {
         |    $found = true;
         |    $numOutput.add(1);
         |    ${consume(ctx, resultVars)}
         |  }
         |}
       """.stripMargin
  }
}
219364}
0 commit comments