Skip to content

Commit 5fe425c

Browse files
committed
Add new abstraction for expression codegen.
1 parent 1df9943 commit 5fe425c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+276
-279
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.errors.attachTree
2323
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
24+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2425
import org.apache.spark.sql.types._
2526

2627
/**
@@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
5657
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
5758
if (nullable) {
5859
ev.copy(code =
59-
s"""
60+
code"""
6061
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
6162
|$javaType ${ev.value} = ${ev.isNull} ?
6263
| ${CodeGenerator.defaultValue(dataType)} : ($value);
6364
""".stripMargin)
6465
} else {
65-
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
66+
ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
6667
}
6768
}
6869
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
2323
import org.apache.spark.sql.catalyst.InternalRow
2424
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2525
import org.apache.spark.sql.catalyst.expressions.codegen._
26+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.catalyst.util._
2728
import org.apache.spark.sql.types._
2829
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -623,8 +624,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
623624
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
624625
val eval = child.genCode(ctx)
625626
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
627+
628+
// Below the code comment including `eval.value` and `eval.isNull` is a trick. It makes the two
629+
// expr values are referred by this code block.
626630
ev.copy(code = eval.code +
627-
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
631+
code"""
632+
// Cast from ${eval.value}, ${eval.isNull}
633+
${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
634+
""")
628635
}
629636

630637
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.Locale
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2424
import org.apache.spark.sql.catalyst.expressions.codegen._
25+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2526
import org.apache.spark.sql.catalyst.trees.TreeNode
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.util.Utils
@@ -100,17 +101,18 @@ abstract class Expression extends TreeNode[Expression] {
100101
ctx.subExprEliminationExprs.get(this).map { subExprState =>
101102
// This expression is repeated which means that the code to evaluate it has already been added
102103
// as a function before. In that case, we just re-use it.
103-
ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value)
104+
ExprCode(JavaCode.block(ctx.registerComment(this.toString)), subExprState.isNull,
105+
subExprState.value)
104106
}.getOrElse {
105107
val isNull = ctx.freshName("isNull")
106108
val value = ctx.freshName("value")
107109
val eval = doGenCode(ctx, ExprCode(
108110
JavaCode.isNullVariable(isNull),
109111
JavaCode.variable(value, dataType)))
110112
reduceCodeSize(ctx, eval)
111-
if (eval.code.nonEmpty) {
113+
if (eval.code.toString.nonEmpty) {
112114
// Add `this` in the comment.
113-
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
115+
eval.copy(code = JavaCode.block(s"${ctx.registerComment(this.toString)}\n") + eval.code)
114116
} else {
115117
eval
116118
}
@@ -143,7 +145,7 @@ abstract class Expression extends TreeNode[Expression] {
143145
""".stripMargin)
144146

145147
eval.value = JavaCode.variable(newValue, dataType)
146-
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
148+
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
147149
}
148150
}
149151

@@ -437,14 +439,14 @@ abstract class UnaryExpression extends Expression {
437439

438440
if (nullable) {
439441
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
440-
ev.copy(code = s"""
442+
ev.copy(code = code"""
441443
${childGen.code}
442444
boolean ${ev.isNull} = ${childGen.isNull};
443445
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
444446
$nullSafeEval
445447
""")
446448
} else {
447-
ev.copy(code = s"""
449+
ev.copy(code = code"""
448450
${childGen.code}
449451
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
450452
$resultCode""", isNull = FalseLiteral)
@@ -536,13 +538,13 @@ abstract class BinaryExpression extends Expression {
536538
}
537539
}
538540

539-
ev.copy(code = s"""
541+
ev.copy(code = code"""
540542
boolean ${ev.isNull} = true;
541543
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
542544
$nullSafeEval
543545
""")
544546
} else {
545-
ev.copy(code = s"""
547+
ev.copy(code = code"""
546548
${leftGen.code}
547549
${rightGen.code}
548550
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -679,12 +681,12 @@ abstract class TernaryExpression extends Expression {
679681
}
680682
}
681683

682-
ev.copy(code = s"""
684+
ev.copy(code = code"""
683685
boolean ${ev.isNull} = true;
684686
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
685687
$nullSafeEval""")
686688
} else {
687-
ev.copy(code = s"""
689+
ev.copy(code = code"""
688690
${leftGen.code}
689691
${midGen.code}
690692
${rightGen.code}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2223
import org.apache.spark.sql.types.{DataType, LongType}
2324

2425
/**
@@ -71,7 +72,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
7172
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
7273
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
7374

74-
ev.copy(code = s"""
75+
ev.copy(code = code"""
7576
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
7677
$countTerm++;""", isNull = FalseLiteral)
7778
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkException
2121
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
23+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2324
import org.apache.spark.sql.types.DataType
2425

2526
/**
@@ -1030,7 +1031,7 @@ case class ScalaUDF(
10301031
""".stripMargin
10311032

10321033
ev.copy(code =
1033-
s"""
1034+
code"""
10341035
|$evalCode
10351036
|${initArgs.mkString("\n")}
10361037
|$callFunc

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2324
import org.apache.spark.sql.types._
2425
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
2526

@@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
181182
}
182183

183184
ev.copy(code = childCode.code +
184-
s"""
185+
code"""
185186
|long ${ev.value} = 0L;
186187
|boolean ${ev.isNull} = ${childCode.isNull};
187188
|if (!${childCode.isNull}) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2223
import org.apache.spark.sql.types.{DataType, IntegerType}
2324

2425
/**
@@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
4647
val idTerm = "partitionId"
4748
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
4849
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
49-
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
50+
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
5051
isNull = FalseLiteral)
5152
}
5253
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2424
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2525
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
26+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.unsafe.types.CalendarInterval
2829

@@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
164165
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
165166
val eval = child.genCode(ctx)
166167
ev.copy(code = eval.code +
167-
s"""boolean ${ev.isNull} = ${eval.isNull};
168+
code"""boolean ${ev.isNull} = ${eval.isNull};
168169
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
169170
""".stripMargin)
170171
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
23+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2324
import org.apache.spark.sql.catalyst.util.TypeUtils
2425
import org.apache.spark.sql.types._
2526
import org.apache.spark.unsafe.types.CalendarInterval
@@ -275,7 +276,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
275276
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
276277
}
277278
if (!left.nullable && !right.nullable) {
278-
ev.copy(code = s"""
279+
ev.copy(code = code"""
279280
${eval2.code}
280281
boolean ${ev.isNull} = false;
281282
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -286,7 +287,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
286287
${ev.value} = $divide;
287288
}""")
288289
} else {
289-
ev.copy(code = s"""
290+
ev.copy(code = code"""
290291
${eval2.code}
291292
boolean ${ev.isNull} = false;
292293
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -362,7 +363,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
362363
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
363364
}
364365
if (!left.nullable && !right.nullable) {
365-
ev.copy(code = s"""
366+
ev.copy(code = code"""
366367
${eval2.code}
367368
boolean ${ev.isNull} = false;
368369
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -373,7 +374,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
373374
${ev.value} = $remainder;
374375
}""")
375376
} else {
376-
ev.copy(code = s"""
377+
ev.copy(code = code"""
377378
${eval2.code}
378379
boolean ${ev.isNull} = false;
379380
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -479,7 +480,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
479480
}
480481

481482
if (!left.nullable && !right.nullable) {
482-
ev.copy(code = s"""
483+
ev.copy(code = code"""
483484
${eval2.code}
484485
boolean ${ev.isNull} = false;
485486
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -490,7 +491,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
490491
$result
491492
}""")
492493
} else {
493-
ev.copy(code = s"""
494+
ev.copy(code = code"""
494495
${eval2.code}
495496
boolean ${ev.isNull} = false;
496497
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -612,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression {
612613
""".stripMargin,
613614
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
614615
ev.copy(code =
615-
s"""
616+
code"""
616617
|${ev.isNull} = true;
617618
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
618619
|$codes
@@ -687,7 +688,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
687688
""".stripMargin,
688689
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
689690
ev.copy(code =
690-
s"""
691+
code"""
691692
|${ev.isNull} = true;
692693
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
693694
|$codes

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
3838
import org.apache.spark.metrics.source.CodegenMetrics
3939
import org.apache.spark.sql.catalyst.InternalRow
4040
import org.apache.spark.sql.catalyst.expressions._
41+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
4142
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
4243
import org.apache.spark.sql.internal.SQLConf
4344
import org.apache.spark.sql.types._
@@ -56,19 +57,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
5657
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
5758
* valid if `isNull` is set to `true`.
5859
*/
59-
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
60+
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
6061

6162
object ExprCode {
6263
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
63-
ExprCode(code = "", isNull, value)
64+
ExprCode(code = code"", isNull, value)
6465
}
6566

6667
def forNullValue(dataType: DataType): ExprCode = {
67-
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
68+
ExprCode(code = code"", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
6869
}
6970

7071
def forNonNullValue(value: ExprValue): ExprCode = {
71-
ExprCode(code = "", isNull = FalseLiteral, value = value)
72+
ExprCode(code = code"", isNull = FalseLiteral, value = value)
7273
}
7374
}
7475

@@ -329,9 +330,9 @@ class CodegenContext {
329330
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
330331
val value = addMutableState(javaType(dataType), variableName)
331332
val code = dataType match {
332-
case StringType => s"$value = $initCode.clone();"
333-
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
334-
case _ => s"$value = $initCode;"
333+
case StringType => code"$value = $initCode.clone();"
334+
case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
335+
case _ => code"$value = $initCode;"
335336
}
336337
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
337338
}

0 commit comments

Comments
 (0)