Skip to content

Commit 32cfd3e

Browse files
viiryacloud-fan
authored andcommitted
[SPARK-24361][SQL] Polish code block manipulation API
## What changes were proposed in this pull request? Current code block manipulation API is immature and hacky. We need a formal API to manipulate code blocks. The basic idea is making `JavaCode` as `TreeNode`. So we can use familiar `transform` API to manipulate code blocks and expressions in code blocks. For example, we can replace `SimpleExprValue` in a code block like this: ```scala code.transformExprValues { case SimpleExprValue("1 + 1", _) => aliasedParam } ``` The example use case is splitting code to methods. For example, we have an `ExprCode` containing generated code. But it is too long and we need to split it as method. Because statement-based expressions can't be directly passed into. We need to transform them as variables first: ```scala def getExprValues(block: Block): Set[ExprValue] = block match { case c: CodeBlock => c.blockInputs.collect { case e: ExprValue => e }.toSet case _ => Set.empty } def currentCodegenInputs(ctx: CodegenContext): Set[ExprValue] = { // Collects current variables in ctx.currentVars and ctx.INPUT_ROW. // It looks roughly like... ctx.currentVars.flatMap { v => getExprValues(v.code) ++ Set(v.value, v.isNull) }.toSet + ctx.INPUT_ROW } // A code block of an expression contains too long code, making it as method if (eval.code.length > 1024) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { ... } else { "" } // Pick up variables and statements necessary to pass in. val currentVars = currentCodegenInputs(ctx) val varsPassIn = getExprValues(eval.code).intersect(currentVars) val aliasedExprs = HashMap.empty[SimpleExprValue, VariableValue] // Replace statement-based expressions which can't be directly passed in the method. val newCode = eval.code.transform { case block => block.transformExprValues { case s: SimpleExprValue(_, javaType) if varsPassIn.contains(s) => if (aliasedExprs.contains(s)) { aliasedExprs(s) } else { val aliasedVariable = JavaCode.variable(ctx.freshName("aliasedVar"), javaType) aliasedExprs += s -> aliasedVariable varsPassIn += aliasedVariable aliasedVariable } } } val params = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue])).map { variable => s"${variable.javaType.getName} ${variable.variableName}" }.mkString(", ") val funcName = ctx.freshName("nodeName") val javaType = CodeGenerator.javaType(dataType) val newValue = JavaCode.variable(ctx.freshName("value"), dataType) val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName($params) { | $newCode | $setIsNull | return ${eval.value}; |} """.stripMargin)) eval.value = newValue val args = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue])).map { variable => s"${variable.variableName}" } // Create a code block to assign statements to aliased variables. val createVariables = aliasedExprs.foldLeft(EmptyBlock) { (block, (statement, variable)) => block + code"${statement.javaType.getName} $variable = $statement;" } eval.code = createVariables + code"$javaType $newValue = $funcFullName($args);" } ``` ## How was this patch tested? Added unite tests. Author: Liang-Chi Hsieh <[email protected]> Closes #21405 from viirya/codeblock-api.
1 parent 4be9f0c commit 32cfd3e

File tree

2 files changed

+104
-19
lines changed
  • sql/catalyst/src

2 files changed

+104
-19
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool}
2222
import scala.collection.mutable.ArrayBuffer
2323
import scala.language.{existentials, implicitConversions}
2424

25+
import org.apache.spark.sql.catalyst.trees.TreeNode
2526
import org.apache.spark.sql.types.{BooleanType, DataType}
2627

2728
/**
@@ -118,12 +119,9 @@ object JavaCode {
118119
/**
119120
* A trait representing a block of java code.
120121
*/
121-
trait Block extends JavaCode {
122+
trait Block extends TreeNode[Block] with JavaCode {
122123
import Block._
123124

124-
// The expressions to be evaluated inside this block.
125-
def exprValues: Set[ExprValue]
126-
127125
// Returns java code string for this code block.
128126
override def toString: String = _marginChar match {
129127
case Some(c) => code.stripMargin(c).trim
@@ -148,11 +146,41 @@ trait Block extends JavaCode {
148146
this
149147
}
150148

149+
/**
150+
* Apply a map function to each java expression codes present in this java code, and return a new
151+
* java code based on the mapped java expression codes.
152+
*/
153+
def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = {
154+
var changed = false
155+
156+
@inline def transform(e: ExprValue): ExprValue = {
157+
val newE = f lift e
158+
if (!newE.isDefined || newE.get.equals(e)) {
159+
e
160+
} else {
161+
changed = true
162+
newE.get
163+
}
164+
}
165+
166+
def doTransform(arg: Any): AnyRef = arg match {
167+
case e: ExprValue => transform(e)
168+
case Some(value) => Some(doTransform(value))
169+
case seq: Traversable[_] => seq.map(doTransform)
170+
case other: AnyRef => other
171+
}
172+
173+
val newArgs = mapProductIterator(doTransform)
174+
if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
175+
}
176+
151177
// Concatenates this block with other block.
152178
def + (other: Block): Block = other match {
153179
case EmptyBlock => this
154180
case _ => code"$this\n$other"
155181
}
182+
183+
override def verboseString: String = toString
156184
}
157185

158186
object Block {
@@ -219,12 +247,8 @@ object Block {
219247
* method splitting.
220248
*/
221249
case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
222-
override lazy val exprValues: Set[ExprValue] = {
223-
blockInputs.flatMap {
224-
case b: Block => b.exprValues
225-
case e: ExprValue => Set(e)
226-
}.toSet
227-
}
250+
override def children: Seq[Block] =
251+
blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]]
228252

229253
override lazy val code: String = {
230254
val strings = codeParts.iterator
@@ -239,9 +263,9 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends
239263
}
240264
}
241265

242-
object EmptyBlock extends Block with Serializable {
266+
case object EmptyBlock extends Block with Serializable {
243267
override val code: String = ""
244-
override val exprValues: Set[ExprValue] = Set.empty
268+
override def children: Seq[Block] = Seq.empty
245269
}
246270

247271
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite {
6565
|boolean $isNull = false;
6666
|int $value = -1;
6767
""".stripMargin
68-
val exprValues = code.exprValues
68+
val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect {
69+
case e: ExprValue => e
70+
}.toSet
6971
assert(exprValues.size == 2)
7072
assert(exprValues === Set(value, isNull))
7173
}
@@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite {
9496

9597
assert(code.toString == expected)
9698

97-
val exprValues = code.exprValues
99+
val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect {
100+
case e: ExprValue => e
101+
}).toSet
98102
assert(exprValues.size == 5)
99103
assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
100104
}
@@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite {
107111
assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
108112
}
109113

110-
test("replace expr values in code block") {
114+
test("transform expr in code block") {
111115
val expr = JavaCode.expression("1 + 1", IntegerType)
112116
val isNull = JavaCode.isNullVariable("expr1_isNull")
113117
val exprInFunc = JavaCode.variable("expr1", IntegerType)
@@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite {
120124
|}""".stripMargin
121125

122126
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
123-
val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
124-
case _: SimpleExprValue => aliasedParam
125-
case other => other
127+
128+
// We want to replace all occurrences of `expr` with the variable `aliasedParam`.
129+
val aliasedCode = code.transformExprValues {
130+
case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
126131
}
127-
val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
128132
val expected =
129133
code"""
130134
|callFunc(int $aliasedParam) {
@@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite {
133137
|}""".stripMargin
134138
assert(aliasedCode.toString == expected.toString)
135139
}
140+
141+
test ("transform expr in nested blocks") {
142+
val expr = JavaCode.expression("1 + 1", IntegerType)
143+
val isNull = JavaCode.isNullVariable("expr1_isNull")
144+
val exprInFunc = JavaCode.variable("expr1", IntegerType)
145+
146+
val funcs = Seq("callFunc1", "callFunc2", "callFunc3")
147+
val subBlocks = funcs.map { funcName =>
148+
code"""
149+
|$funcName(int $expr) {
150+
| boolean $isNull = false;
151+
| int $exprInFunc = $expr + 1;
152+
|}""".stripMargin
153+
}
154+
155+
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
156+
157+
val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}"
158+
val transformedBlock = block.transform {
159+
case b: Block => b.transformExprValues {
160+
case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
161+
}
162+
}.asInstanceOf[CodeBlock]
163+
164+
val expected1 =
165+
code"""
166+
|callFunc1(int aliased) {
167+
| boolean expr1_isNull = false;
168+
| int expr1 = aliased + 1;
169+
|}""".stripMargin
170+
171+
val expected2 =
172+
code"""
173+
|callFunc2(int aliased) {
174+
| boolean expr1_isNull = false;
175+
| int expr1 = aliased + 1;
176+
|}""".stripMargin
177+
178+
val expected3 =
179+
code"""
180+
|callFunc3(int aliased) {
181+
| boolean expr1_isNull = false;
182+
| int expr1 = aliased + 1;
183+
|}""".stripMargin
184+
185+
val exprValues = transformedBlock.children.flatMap { block =>
186+
block.asInstanceOf[CodeBlock].blockInputs.collect {
187+
case e: ExprValue => e
188+
}
189+
}.toSet
190+
191+
assert(transformedBlock.children(0).toString == expected1.toString)
192+
assert(transformedBlock.children(1).toString == expected2.toString)
193+
assert(transformedBlock.children(2).toString == expected3.toString)
194+
assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString)
195+
assert(exprValues === Set(isNull, exprInFunc, aliasedParam))
196+
}
136197
}

0 commit comments

Comments
 (0)