Skip to content

Commit f9f055a

Browse files
viirya authored and cloud-fan committed
[SPARK-24121][SQL] Add API for handling expression code generation
## What changes were proposed in this pull request?

This patch tries to implement this [proposal](#19813 (comment)) to add an API for handling expression code generation. It should allow us to manipulate how code is generated for expressions.

In detail, this adds a new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and the inputs for generating the actual Java code.

For example, in the following Java code:

```java
int ${variable} = 1;
boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)};
```

`variable` and `isNull` are two `VariableValue`s, and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and are held by the `CodeBlock` representing this code.

For codegen, we provide a dedicated string interpolator `code`, so you can define a code block like this:

```scala
val codeBlock =
  code"""
    |int ${variable} = 1;
    |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)};
  """.stripMargin
// Generates actual java code.
codeBlock.toString
```

Because those inputs are held separately in `CodeBlock` before the code is generated, we can safely manipulate them, e.g., replacing statements with aliased variables.

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <[email protected]>

Closes #21193 from viirya/SPARK-24121.
1 parent 8086acc commit f9f055a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+479
-172
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.errors.attachTree
2323
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
24+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2425
import org.apache.spark.sql.types._
2526

2627
/**
@@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
5657
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
5758
if (nullable) {
5859
ev.copy(code =
59-
s"""
60+
code"""
6061
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
6162
|$javaType ${ev.value} = ${ev.isNull} ?
6263
| ${CodeGenerator.defaultValue(dataType)} : ($value);
6364
""".stripMargin)
6465
} else {
65-
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
66+
ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
6667
}
6768
}
6869
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
2323
import org.apache.spark.sql.catalyst.InternalRow
2424
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2525
import org.apache.spark.sql.catalyst.expressions.codegen._
26+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.catalyst.util._
2728
import org.apache.spark.sql.types._
2829
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
623624
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
624625
val eval = child.genCode(ctx)
625626
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
626-
ev.copy(code = eval.code +
627-
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
627+
628+
ev.copy(code =
629+
code"""
630+
${eval.code}
631+
// This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull}
632+
${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
633+
""")
628634
}
629635

630636
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.Locale
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2424
import org.apache.spark.sql.catalyst.expressions.codegen._
25+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2526
import org.apache.spark.sql.catalyst.trees.TreeNode
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.util.Utils
@@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] {
108109
JavaCode.isNullVariable(isNull),
109110
JavaCode.variable(value, dataType)))
110111
reduceCodeSize(ctx, eval)
111-
if (eval.code.nonEmpty) {
112+
if (eval.code.toString.nonEmpty) {
112113
// Add `this` in the comment.
113-
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
114+
eval.copy(code = ctx.registerComment(this.toString) + eval.code)
114115
} else {
115116
eval
116117
}
@@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] {
119120

120121
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
121122
// TODO: support whole stage codegen too
122-
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
123+
if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
123124
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
124125
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
125126
val localIsNull = eval.isNull
@@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] {
136137
val funcFullName = ctx.addNewFunction(funcName,
137138
s"""
138139
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
139-
| ${eval.code.trim}
140+
| ${eval.code}
140141
| $setIsNull
141142
| return ${eval.value};
142143
|}
143144
""".stripMargin)
144145

145146
eval.value = JavaCode.variable(newValue, dataType)
146-
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
147+
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
147148
}
148149
}
149150

@@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression {
437438

438439
if (nullable) {
439440
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
440-
ev.copy(code = s"""
441+
ev.copy(code = code"""
441442
${childGen.code}
442443
boolean ${ev.isNull} = ${childGen.isNull};
443444
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
444445
$nullSafeEval
445446
""")
446447
} else {
447-
ev.copy(code = s"""
448-
boolean ${ev.isNull} = false;
448+
ev.copy(code = code"""
449449
${childGen.code}
450450
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
451451
$resultCode""", isNull = FalseLiteral)
@@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression {
537537
}
538538
}
539539

540-
ev.copy(code = s"""
540+
ev.copy(code = code"""
541541
boolean ${ev.isNull} = true;
542542
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
543543
$nullSafeEval
544544
""")
545545
} else {
546-
ev.copy(code = s"""
547-
boolean ${ev.isNull} = false;
546+
ev.copy(code = code"""
548547
${leftGen.code}
549548
${rightGen.code}
550549
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression {
681680
}
682681
}
683682

684-
ev.copy(code = s"""
683+
ev.copy(code = code"""
685684
boolean ${ev.isNull} = true;
686685
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
687686
$nullSafeEval""")
688687
} else {
689-
ev.copy(code = s"""
690-
boolean ${ev.isNull} = false;
688+
ev.copy(code = code"""
691689
${leftGen.code}
692690
${midGen.code}
693691
${rightGen.code}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2223
import org.apache.spark.sql.types.{DataType, LongType}
2324

2425
/**
@@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
7273
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
7374
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
7475

75-
ev.copy(code = s"""
76+
ev.copy(code = code"""
7677
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
7778
$countTerm++;""", isNull = FalseLiteral)
7879
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkException
2121
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
23+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2324
import org.apache.spark.sql.types.DataType
2425

2526
/**
@@ -1030,7 +1031,7 @@ case class ScalaUDF(
10301031
""".stripMargin
10311032

10321033
ev.copy(code =
1033-
s"""
1034+
code"""
10341035
|$evalCode
10351036
|${initArgs.mkString("\n")}
10361037
|$callFunc

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2324
import org.apache.spark.sql.types._
2425
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
2526

@@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
181182
}
182183

183184
ev.copy(code = childCode.code +
184-
s"""
185+
code"""
185186
|long ${ev.value} = 0L;
186187
|boolean ${ev.isNull} = ${childCode.isNull};
187188
|if (!${childCode.isNull}) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2223
import org.apache.spark.sql.types.{DataType, IntegerType}
2324

2425
/**
@@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
4647
val idTerm = "partitionId"
4748
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
4849
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
49-
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
50+
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
5051
isNull = FalseLiteral)
5152
}
5253
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2424
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2525
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
26+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.unsafe.types.CalendarInterval
2829

@@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
164165
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
165166
val eval = child.genCode(ctx)
166167
ev.copy(code = eval.code +
167-
s"""boolean ${ev.isNull} = ${eval.isNull};
168+
code"""boolean ${ev.isNull} = ${eval.isNull};
168169
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
169170
""".stripMargin)
170171
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
23+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2324
import org.apache.spark.sql.catalyst.util.TypeUtils
2425
import org.apache.spark.sql.types._
2526
import org.apache.spark.unsafe.types.CalendarInterval
@@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic {
259260
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
260261
}
261262
if (!left.nullable && !right.nullable) {
262-
ev.copy(code = s"""
263+
ev.copy(code = code"""
263264
${eval2.code}
264265
boolean ${ev.isNull} = false;
265266
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic {
270271
${ev.value} = $operation;
271272
}""")
272273
} else {
273-
ev.copy(code = s"""
274+
ev.copy(code = code"""
274275
${eval2.code}
275276
boolean ${ev.isNull} = false;
276277
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
436437
}
437438

438439
if (!left.nullable && !right.nullable) {
439-
ev.copy(code = s"""
440+
ev.copy(code = code"""
440441
${eval2.code}
441442
boolean ${ev.isNull} = false;
442443
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
447448
$result
448449
}""")
449450
} else {
450-
ev.copy(code = s"""
451+
ev.copy(code = code"""
451452
${eval2.code}
452453
boolean ${ev.isNull} = false;
453454
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression {
569570
""".stripMargin,
570571
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
571572
ev.copy(code =
572-
s"""
573+
code"""
573574
|${ev.isNull} = true;
574575
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
575576
|$codes
@@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
644645
""".stripMargin,
645646
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
646647
ev.copy(code =
647-
s"""
648+
code"""
648649
|${ev.isNull} = true;
649650
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
650651
|$codes

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
3838
import org.apache.spark.metrics.source.CodegenMetrics
3939
import org.apache.spark.sql.catalyst.InternalRow
4040
import org.apache.spark.sql.catalyst.expressions._
41+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
4142
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
4243
import org.apache.spark.sql.internal.SQLConf
4344
import org.apache.spark.sql.types._
@@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
5758
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
5859
* valid if `isNull` is set to `true`.
5960
*/
60-
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
61+
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
6162

6263
object ExprCode {
6364
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
64-
ExprCode(code = "", isNull, value)
65+
ExprCode(code = EmptyBlock, isNull, value)
6566
}
6667

6768
def forNullValue(dataType: DataType): ExprCode = {
68-
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
69+
ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
6970
}
7071

7172
def forNonNullValue(value: ExprValue): ExprCode = {
72-
ExprCode(code = "", isNull = FalseLiteral, value = value)
73+
ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
7374
}
7475
}
7576

@@ -330,9 +331,9 @@ class CodegenContext {
330331
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
331332
val value = addMutableState(javaType(dataType), variableName)
332333
val code = dataType match {
333-
case StringType => s"$value = $initCode.clone();"
334-
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
335-
case _ => s"$value = $initCode;"
334+
case StringType => code"$value = $initCode.clone();"
335+
case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
336+
case _ => code"$value = $initCode;"
336337
}
337338
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
338339
}
@@ -1056,7 +1057,7 @@ class CodegenContext {
10561057
val eval = expr.genCode(this)
10571058
val state = SubExprEliminationState(eval.isNull, eval.value)
10581059
e.foreach(localSubExprEliminationExprs.put(_, state))
1059-
eval.code.trim
1060+
eval.code.toString
10601061
}
10611062
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
10621063
}
@@ -1084,7 +1085,7 @@ class CodegenContext {
10841085
val fn =
10851086
s"""
10861087
|private void $fnName(InternalRow $INPUT_ROW) {
1087-
| ${eval.code.trim}
1088+
| ${eval.code}
10881089
| $isNull = ${eval.isNull};
10891090
| $value = ${eval.value};
10901091
|}
@@ -1141,7 +1142,7 @@ class CodegenContext {
11411142
def registerComment(
11421143
text: => String,
11431144
placeholderId: String = "",
1144-
force: Boolean = false): String = {
1145+
force: Boolean = false): Block = {
11451146
// By default, disable comments in generated code because computing the comments themselves can
11461147
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
11471148
// inputs with wide schemas. For more details on the performance issues that motivated this
@@ -1160,9 +1161,9 @@ class CodegenContext {
11601161
s"// $text"
11611162
}
11621163
placeHolderToComments += (name -> comment)
1163-
s"/*$name*/"
1164+
code"/*$name*/"
11641165
} else {
1165-
""
1166+
EmptyBlock
11661167
}
11671168
}
11681169
}

0 commit comments

Comments
 (0)