@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
2727import org .apache .spark .sql .catalyst .plans .physical .Partitioning
2828import org .apache .spark .sql .catalyst .rules .Rule
2929import org .apache .spark .sql .execution .aggregate .TungstenAggregate
30- import org .apache .spark .sql .execution .joins .{BroadcastHashJoin , BuildLeft , BuildRight }
31- import org .apache .spark .sql .execution .metric .{ LongSQLMetric , LongSQLMetricValue , SQLMetric }
30+ import org .apache .spark .sql .execution .joins .{BroadcastHashJoin , BuildLeft , BuildRight , SortMergeJoin }
31+ import org .apache .spark .sql .execution .metric .LongSQLMetricValue
3232
3333/**
3434 * An interface for those physical operators that support codegen.
@@ -39,6 +39,7 @@ trait CodegenSupport extends SparkPlan {
3939 private def variablePrefix : String = this match {
4040 case _ : TungstenAggregate => " agg"
4141 case _ : BroadcastHashJoin => " bhj"
42+ case _ : SortMergeJoin => " smj"
4243 case _ => nodeName.toLowerCase
4344 }
4445
@@ -66,9 +67,11 @@ trait CodegenSupport extends SparkPlan {
6667 private var parent : CodegenSupport = null
6768
6869 /**
69- * Returns the RDD of InternalRow which generates the input rows.
70+ * Returns all the RDDs of InternalRow which generate the input rows.
71+ *
72+ * Note: right now we support up to two RDDs.
7073 */
71- def upstream (): RDD [InternalRow ]
74+ def upstreams (): Seq [ RDD [InternalRow ] ]
7275
7376 /**
7477 * Returns Java source code to process the rows from upstream.
@@ -172,19 +175,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
172175
173176 override def supportCodegen : Boolean = false
174177
175- override def upstream (): RDD [InternalRow ] = {
176- child.execute()
178+ override def upstreams (): Seq [ RDD [InternalRow ] ] = {
179+ child.execute() :: Nil
177180 }
178181
179182 override def doProduce (ctx : CodegenContext ): String = {
183+ val input = ctx.freshName(" input" )
184+ // Right now, InputAdapter is only used when there is one upstream.
185+ ctx.addMutableState(" scala.collection.Iterator" , input, s " $input = inputs[0]; " )
186+
180187 val exprs = output.zipWithIndex.map(x => new BoundReference (x._2, x._1.dataType, true ))
181188 val row = ctx.freshName(" row" )
182189 ctx.INPUT_ROW = row
183190 ctx.currentVars = null
184191 val columns = exprs.map(_.gen(ctx))
185192 s """
186- | while (input.hasNext()) {
187- | InternalRow $row = (InternalRow) input.next();
193+ | while ( $ input.hasNext()) {
194+ | InternalRow $row = (InternalRow) $ input.next();
188195 | ${columns.map(_.code).mkString(" \n " ).trim}
189196 | ${consume(ctx, columns).trim}
190197 | if (shouldStop()) {
@@ -208,7 +215,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
208215 *
209216 * -> execute()
210217 * |
211- * doExecute() ---------> upstream () -------> upstream () ------> execute()
218+ * doExecute() ---------> upstreams () -------> upstreams () ------> execute()
212219 * |
213220 * -----------------> produce()
214221 * |
@@ -260,6 +267,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
260267
261268 public GeneratedIterator(Object[] references) {
262269 this.references = references;
270+ }
271+
272+ public void init(scala.collection.Iterator inputs[]) {
263273 ${ctx.initMutableStates()}
264274 }
265275
@@ -276,19 +286,32 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
276286 // println(s"${CodeFormatter.format(cleanedSource)}")
277287 CodeGenerator .compile(cleanedSource)
278288
279- plan.upstream().mapPartitions { iter =>
280-
281- val clazz = CodeGenerator .compile(source)
282- val buffer = clazz.generate(references).asInstanceOf [BufferedRowIterator ]
283- buffer.setInput(iter)
284- new Iterator [InternalRow ] {
285- override def hasNext : Boolean = buffer.hasNext
286- override def next : InternalRow = buffer.next()
289+ val rdds = plan.upstreams()
290+ if (rdds.length == 1 ) {
291+ rdds.head.mapPartitions { iter =>
292+ val clazz = CodeGenerator .compile(cleanedSource)
293+ val buffer = clazz.generate(references).asInstanceOf [BufferedRowIterator ]
294+ buffer.init(Array (iter))
295+ new Iterator [InternalRow ] {
296+ override def hasNext : Boolean = buffer.hasNext
297+ override def next : InternalRow = buffer.next()
298+ }
299+ }
300+ } else {
301+ // Right now, we support up to two upstreams.
302+ rdds.head.zipPartitions(rdds(1 )) { (leftIter, rightIter) =>
303+ val clazz = CodeGenerator .compile(cleanedSource)
304+ val buffer = clazz.generate(references).asInstanceOf [BufferedRowIterator ]
305+ buffer.init(Array (leftIter, rightIter))
306+ new Iterator [InternalRow ] {
307+ override def hasNext : Boolean = buffer.hasNext
308+ override def next : InternalRow = buffer.next()
309+ }
287310 }
288311 }
289312 }
290313
291- override def upstream (): RDD [InternalRow ] = {
314+ override def upstreams (): Seq [ RDD [InternalRow ] ] = {
292315 throw new UnsupportedOperationException
293316 }
294317
@@ -305,7 +328,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
305328 if (row != null ) {
306329 // There is an UnsafeRow already
307330 s """
308- | currentRows.add ( $row.copy());
331+ |append ( $row.copy());
309332 """ .stripMargin
310333 } else {
311334 assert(input != null )
@@ -317,13 +340,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
317340 ctx.currentVars = input
318341 val code = GenerateUnsafeProjection .createCode(ctx, colExprs, false )
319342 s """
320- | ${code.code.trim}
321- | currentRows.add ( ${code.value}.copy());
343+ | ${code.code.trim}
344+ |append ( ${code.value}.copy());
322345 """ .stripMargin
323346 } else {
324347 // There are no columns
325348 s """
326- | currentRows.add (unsafeRow);
349+ |append (unsafeRow);
327350 """ .stripMargin
328351 }
329352 }
@@ -395,6 +418,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
395418 b.copy(left = apply(left))
396419 case b @ BroadcastHashJoin (_, _, BuildRight , _, left, right) =>
397420 b.copy(right = apply(right))
421+ case j @ SortMergeJoin (_, _, _, left, right) =>
422+ // The children of SortMergeJoin should do codegen separately.
423+ j.copy(left = apply(left), right = apply(right))
398424 case p if ! supportCodegen(p) =>
399425 val input = apply(p) // collapse them recursively
400426 inputs += input
0 commit comments