diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e..44f63e21e93b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -119,6 +119,7 @@ object JavaCode { * A trait representing a block of java code. */ trait Block extends JavaCode { + import Block._ // The expressions to be evaluated inside this block. def exprValues: Set[ExprValue] @@ -148,14 +149,17 @@ trait Block extends JavaCode { } // Concatenates this block with other block. - def + (other: Block): Block + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } } object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 - implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { def code(args: Any*): Block = { @@ -190,18 +194,17 @@ object Block { while (strings.hasNext) { val input = inputs.next input match { - case _: ExprValue | _: Block => + case _: ExprValue | _: CodeBlock => codeParts += buf.toString buf.clear blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => case _ => buf.append(input) } buf.append(strings.next) } - if (buf.nonEmpty) { - codeParts += buf.toString - } + codeParts += buf.toString (codeParts.toSeq, blockInputs.toSeq) } @@ -209,7 +212,11 @@ object Block { /** * A block of java code. Including a sequence of code parts and some inputs to this block. - * The actual java code is generated by embedding the inputs into the code parts. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { override lazy val exprValues: Set[ExprValue] = { @@ -230,30 +237,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(Seq(this, c)) - case b: Blocks => Blocks(Seq(this) ++ b.blocks) - case EmptyBlock => this - } -} - -case class Blocks(blocks: Seq[Block]) extends Block { - override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override lazy val code: String = blocks.map(_.toString).mkString("\n") - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(blocks :+ c) - case b: Blocks => Blocks(blocks ++ b.blocks) - case EmptyBlock => this - } } object EmptyBlock extends Block with Serializable { override val code: String = "" override val exprValues: Set[ExprValue] = Set.empty - - override def + (other: Block): Block = other } /**