From 5ace8b83b7c90cd5a6a451812ac9c1087aaa1c29 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Dec 2017 04:22:10 +0000 Subject: [PATCH 1/9] Add wrappers for codegen output. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 14 +++---- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 5 ++- .../expressions/codegen/CodeGenerator.scala | 39 ++++++++++++++++--- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 6 +-- .../codegen/GenerateSafeProjection.scala | 14 +++---- .../codegen/GenerateUnsafeProjection.scala | 5 ++- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeCreator.scala | 8 ++-- .../expressions/datetimeExpressions.scala | 6 +-- .../sql/catalyst/expressions/generators.scala | 4 +- .../spark/sql/catalyst/expressions/hash.scala | 4 +- .../catalyst/expressions/inputFileBlock.scala | 8 ++-- .../sql/catalyst/expressions/literals.scala | 18 ++++----- .../spark/sql/catalyst/expressions/misc.scala | 5 ++- .../expressions/nullExpressions.scala | 23 ++++++++--- .../expressions/objects/objects.scala | 23 +++++------ .../sql/catalyst/expressions/predicates.scala | 10 ++--- .../expressions/randomExpressions.scala | 8 ++-- .../sql/execution/ColumnarBatchScan.scala | 10 +++-- .../spark/sql/execution/ExpandExec.scala | 4 +- .../spark/sql/execution/GenerateExec.scala | 10 ++--- .../sql/execution/WholeStageCodegenExec.scala | 6 +-- .../aggregate/HashAggregateExec.scala | 2 +- .../aggregate/HashMapGenerator.scala | 4 +- .../execution/basicPhysicalOperators.scala | 6 +-- .../joins/BroadcastHashJoinExec.scala | 6 +-- .../execution/joins/SortMergeJoinExec.scala | 6 +-- 30 files changed, 160 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef2..74fbe634ae5b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} import org.apache.spark.sql.types._ /** @@ -75,7 +75,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = LiteralValue("false")) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4568714933095..932dd18404352 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,7 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", isNull, value)) + val eval = doGenCode(ctx, ExprCode("", VariableValue(isNull), VariableValue(value))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -118,10 +118,10 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val setIsNull = if ("false" != eval.isNull && "true" != eval.isNull) { val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = globalIsNull + eval.isNull = GlobalValue(globalIsNull) s"$globalIsNull = $localIsNull;" } else { "" @@ -140,7 +140,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = newValue + eval.value = VariableValue(newValue) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -419,7 +419,7 @@ abstract class UnaryExpression extends Expression { boolean ${ev.isNull} = false; ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = LiteralValue("false")) } } } @@ -519,7 +519,7 @@ abstract class BinaryExpression extends Expression { ${leftGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = LiteralValue("false")) } } } @@ -663,7 +663,7 @@ abstract class TernaryExpression extends Expression { ${midGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = LiteralValue("false")) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 784eaf8195194..c60a4bd560110 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -72,7 +72,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + $countTerm++;""", isNull = LiteralValue("false")) } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 736ca37c6d54a..131d255c775bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,6 +45,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId") ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 41a920ba3d677..d793ce04dffd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -23,7 +23,7 @@ import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.language.existentials +import scala.language.{existentials, implicitConversions} import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -56,7 +56,36 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) + + +// An abstraction that represents the evaluation result of [[ExprCode]]. +abstract class ExprValue + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + +// A literal evaluation of [[ExprCode]]. +case class LiteralValue(val value: String) extends ExprValue { + override def toString: String = value +} + +// A variable evaluation of [[ExprCode]]. +case class VariableValue(val variableName: String) extends ExprValue { + override def toString: String = variableName +} + +// A statement evaluation of [[ExprCode]]. +case class StatementValue(val statement: String) extends ExprValue { + override def toString: String = statement +} + +// A global variable evaluation of [[ExprCode]]. +case class GlobalValue(val value: String) extends ExprValue { + override def toString: String = value +} + /** * State used for subexpression elimination. @@ -66,7 +95,7 @@ case class ExprCode(var code: String, var isNull: String, var value: String) * @param value A term for a value of a common sub-expression. Not valid if `isNull` * is set to `true`. */ -case class SubExprEliminationState(isNull: String, value: String) +case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) /** * Codes and common subexpressions mapping used for subexpression elimination. @@ -264,7 +293,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, LiteralValue("false"), GlobalValue(value)) } def declareMutableStates(): String = { @@ -1144,7 +1173,7 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = SubExprEliminationState(GlobalValue(isNull), GlobalValue(value)) e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a9ff..fed01ee51df61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -58,7 +58,7 @@ trait CodegenFallback extends Expression { $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - """, isNull = "false") + """, isNull = LiteralValue("false")) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7e2d..21c9c605ac607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { + val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, isNull, value, i) + """.stripMargin, GlobalValue(isNull), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, value) + val ev = ExprCode("", isNull, GlobalValue(value)) ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba42a..c6b9f27fbb7c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +74,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, "false", output) + ExprCode(code, LiteralValue("false"), VariableValue(output)) } private def createCodeForArray( @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, StatementValue(ctx.getValue(tmpInput, elementType, index)), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +104,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, "false", output) + ExprCode(code, LiteralValue("false"), VariableValue(output)) } private def createCodeForMap( @@ -125,19 +125,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, "false", output) + ExprCode(code, LiteralValue("false"), VariableValue(output)) } @tailrec private def convertToSafe( ctx: CodegenContext, - input: String, + input: ExprValue, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", "false", input) + case _ => ExprCode("", LiteralValue("false"), input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b6..5b1219d77f44f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)"), + StatementValue(ctx.getValue(tmpInput, dt, i.toString))) } s""" @@ -347,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, LiteralValue("false"), GlobalValue(result)) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..231077358e94a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType boolean ${ev.isNull} = false; ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = "false") + (${childGen.value}).numElements();""", isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3dc2ee03a86e3..ce83bc6493f10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = arrayData, - isNull = "false") + value = VariableValue(arrayData), + isNull = LiteralValue("false")) } override def prettyName: String = "array" @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = LiteralValue("false")) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = "false", value = eval.value) + ExprCode(code = eval.code, isNull = LiteralValue("false"), value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index cfec7f82951a7..6bf0ec49d492f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -673,7 +673,7 @@ abstract class UnixTime case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode("", LiteralValue("true"), LiteralValue(ctx.defaultValue(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) @@ -808,7 +808,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", "true", "(UTF8String) null") + ExprCode("", LiteralValue("true"), LiteralValue("(UTF8String) null")) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 1cd73a92a8635..d1308e23aa9f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s"$wrapperClass", ev.value, v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false) - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da54..1461f283f8eb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -269,7 +269,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = LiteralValue("false") val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -632,7 +632,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = LiteralValue("false") val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed1757..15fc341867120 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = "false") + s"$className.getInputFilePath();", isNull = LiteralValue("false")) } } @@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = "false") + s"$className.getStartOffset();", isNull = LiteralValue("false")) } } @@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = "false") + s"$className.getLength();", isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a209833..a781461d291c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -280,13 +280,13 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { val javaType = ctx.javaType(dataType) // change the isNull and primitive to consts, to inline them if (value == null) { - ev.isNull = "true" + ev.isNull = LiteralValue("true") ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") } else { - ev.isNull = "false" + ev.isNull = LiteralValue("false") dataType match { case BooleanType | IntegerType | DateType => - ev.copy(code = "", value = value.toString) + ev.copy(code = "", value = LiteralValue(value.toString)) case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { @@ -294,7 +294,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" ev.copy(code = code) } else { - ev.copy(code = "", value = s"${value}f") + ev.copy(code = "", value = LiteralValue(s"${value}f")) } case DoubleType => val v = value.asInstanceOf[Double] @@ -303,15 +303,15 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" ev.copy(code = code) } else { - ev.copy(code = "", value = s"${value}D") + ev.copy(code = "", value = LiteralValue(s"${value}D")) } case ByteType | ShortType => - ev.copy(code = "", value = s"($javaType)$value") + ev.copy(code = "", value = LiteralValue(s"($javaType)$value")) case TimestampType | LongType => - ev.copy(code = "", value = s"${value}L") + ev.copy(code = "", value = LiteralValue(s"${value}L")) case _ => - ev.copy(code = "", value = ctx.addReferenceObj("literal", value, - ctx.javaType(dataType))) + ev.copy(code = "", value = GlobalValue(ctx.addReferenceObj("literal", value, + ctx.javaType(dataType)))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4b9006ab5b423..0feea6197de8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -85,7 +85,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = LiteralValue("true"), value = LiteralValue("null")) } override def sql: String = s"assert_true(${child.sql})" @@ -129,6 +129,7 @@ case class Uuid() extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = s"final UTF8String ${ev.value} = " + - s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", + isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b4f895fffda38..0bbd938b217f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -236,7 +236,8 @@ case class IsNaN(child: Expression) extends UnaryExpression ev.copy(code = s""" ${eval.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", + isNull = LiteralValue("false")) } } } @@ -321,7 +322,12 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = eval.isNull) + val value = if ("true" == eval.isNull || "false" == eval.isNull) { + LiteralValue(eval.isNull) + } else { + VariableValue(eval.isNull) + } + ExprCode(code = eval.code, isNull = LiteralValue("false"), value = value) } override def sql: String = s"(${child.sql} IS NULL)" @@ -347,7 +353,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") + val value = if ("true" == eval.isNull) { + LiteralValue("false") + } else if ("false" == eval.isNull) { + LiteralValue("true") + } else { + StatementValue(s"(!(${eval.isNull}))") + } + ExprCode(code = eval.code, isNull = LiteralValue("false"), value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -442,6 +455,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate | $codes |} while (false); |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a59aad5be8715..a44dead124e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -59,13 +59,13 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, String) = { + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") - resultIsNull + GlobalValue(resultIsNull) } else { - "false" + LiteralValue("false") } val argValues = arguments.map { e => val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") @@ -146,7 +146,7 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = "false" + ev.isNull = LiteralValue("false") "" } @@ -427,7 +427,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = LiteralValue("false")) } } @@ -443,7 +443,8 @@ case class LambdaVariable( with Unevaluable with NonSQLExpression { override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + ExprCode(code = "", value = VariableValue(value), + isNull = if (nullable) VariableValue(isNull) else LiteralValue("false")) } } @@ -634,7 +635,7 @@ case class MapObjects private( // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { + val genFunctionValue: String = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) @@ -1131,7 +1132,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = LiteralValue("false")) } } @@ -1317,7 +1318,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = "false", value = childGen.value) + ev.copy(code = code, isNull = LiteralValue("false"), value = childGen.value) } } @@ -1360,7 +1361,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac9f56f78eb2e..5f55ceff10eb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, LiteralValue, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -404,7 +404,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = LiteralValue("false")) } else { ev.copy(code = s""" ${eval1.code} @@ -460,7 +460,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" + ev.isNull = LiteralValue("false") ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -468,7 +468,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = LiteralValue("false")) } else { ev.copy(code = s""" ${eval1.code} @@ -614,7 +614,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false") + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = LiteralValue("false")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbfc31..7dd3a231514e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = LiteralValue("false")) } } @@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = LiteralValue("false")) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292ba..8a6fce3876c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType @@ -48,7 +48,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val isNullVar = if (nullable) { + VariableValue(ctx.freshName("isNull")) + } else { + LiteralValue("false") + } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { @@ -59,7 +63,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, valueVar) + ExprCode(code, isNullVar, VariableValue(valueVar)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93ecd..350d6622aaaa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -156,7 +156,7 @@ case class ExpandExec( |boolean $isNull = true; |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull), VariableValue(value)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index e1562befe14f9..5b8b119cdf102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -169,9 +169,9 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", s"$index == -1", index)) + Seq(ExprCode("", StatementValue(s"$index == -1"), VariableValue(index))) } else { - Seq(ExprCode("", "false", index)) + Seq(ExprCode("", LiteralValue("false"), VariableValue(index))) } } else { Seq.empty @@ -314,9 +314,9 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull), VariableValue(value)) } else { - ExprCode(s"$javaType $value = $getter;", "false", value) + ExprCode(s"$javaType $value = $getter;", LiteralValue("false"), VariableValue(value)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9e7008d1e0c31..1d0ca752f9682 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -127,7 +127,7 @@ trait CodegenSupport extends SparkPlan { } val rowVar = if (row != null) { - ExprCode("", "false", row) + ExprCode("", LiteralValue("false"), VariableValue(row)) } else { if (outputVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -142,10 +142,10 @@ trait CodegenSupport extends SparkPlan { |$evaluateInputs |${ev.code.trim} """.stripMargin.trim - ExprCode(code, "false", ev.value) + ExprCode(code, LiteralValue("false"), ev.value) } else { // There is no columns - ExprCode("", "false", "unsafeRow") + ExprCode("", LiteralValue("false"), VariableValue("unsafeRow")) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b1af360d85095..0f519fdbb3bc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,7 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull), GlobalValue(value)) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4ab1..bad106bf29677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,7 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull), GlobalValue(value)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 78137d3f97cfc..ccc601846958c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = "false" + ev.isNull = LiteralValue("false") } ev } @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", "false", value) + val ev = ExprCode("", LiteralValue("false"), VariableValue(value)) val BigInt = classOf[java.math.BigInteger].getName // inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index ee763e23415cf..26d7060e8f112 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -191,7 +191,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull), VariableValue(value)) } } } @@ -486,7 +486,7 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + val resultVar = input ++ Seq(ExprCode("", LiteralValue("false"), VariableValue(existsVar))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 073730462a75f..be0df66eb4823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -530,11 +530,11 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, isNull, value), leftVarsDecl) + (ExprCode(code, VariableValue(isNull), VariableValue(value)), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, "false", value), leftVarsDecl) + (ExprCode(code, LiteralValue("false"), VariableValue(value)), leftVarsDecl) } }.unzip } From 81c9b6e73ee64adcd8fc931d51f3faa98b892e0b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Dec 2017 07:27:49 +0000 Subject: [PATCH 2/9] Fix equality check with string. --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- .../spark/sql/catalyst/expressions/nullExpressions.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 932dd18404352..b82028c64d3e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -118,7 +118,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if ("false" != eval.isNull && "true" != eval.isNull) { + val setIsNull = if ("false" != s"${eval.isNull}" && "true" != s"${eval.isNull}") { val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = GlobalValue(globalIsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0bbd938b217f6..87ba37ea85f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -322,7 +322,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if ("true" == eval.isNull || "false" == eval.isNull) { + val value = if ("true" == s"${eval.isNull}" || "false" == s"${eval.isNull}") { LiteralValue(eval.isNull) } else { VariableValue(eval.isNull) @@ -353,9 +353,9 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if ("true" == eval.isNull) { + val value = if ("true" == s"${eval.isNull}") { LiteralValue("false") - } else if ("false" == eval.isNull) { + } else if ("false" == s"${eval.isNull}") { LiteralValue("true") } else { StatementValue(s"(!(${eval.isNull}))") From 53926ccb0795c09a90ba70a5c7862ec1cb126391 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Dec 2017 23:46:55 +0000 Subject: [PATCH 3/9] Address comments. --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- .../spark/sql/catalyst/expressions/nullExpressions.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b82028c64d3e0..44996114e5058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -118,7 +118,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if ("false" != s"${eval.isNull}" && "true" != s"${eval.isNull}") { + val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = GlobalValue(globalIsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 87ba37ea85f51..e2691e98fda55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -322,7 +322,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if ("true" == s"${eval.isNull}" || "false" == s"${eval.isNull}") { + val value = if (eval.isNull.isInstanceOf[LiteralValue]) { LiteralValue(eval.isNull) } else { VariableValue(eval.isNull) @@ -353,9 +353,9 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if ("true" == s"${eval.isNull}") { + val value = if (eval.isNull == LiteralValue("true")) { LiteralValue("false") - } else if ("false" == s"${eval.isNull}") { + } else if (eval.isNull == LiteralValue("false")) { LiteralValue("true") } else { StatementValue(s"(!(${eval.isNull}))") From 4384c84bd039addab34ba126447462ea87e13734 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Dec 2017 01:27:11 +0000 Subject: [PATCH 4/9] Fix merging. --- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 4 ++-- .../sql/catalyst/expressions/conditionalExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/nullExpressions.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6d7b..d530032ca26a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,7 +602,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -681,7 +681,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 142dfb02be0a8..d00a7fff70bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -190,7 +190,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = GlobalValue(ctx.addMutableState(ctx.javaType(dataType), ev.value)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 058a1566ba701..d1226a8ac95bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => From 8715d32af9a3c834647489a2333845fb72cd45a7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 21 Feb 2018 09:58:36 +0000 Subject: [PATCH 5/9] Add test. --- .../expressions/codegen/ExprValueSuite.scala | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala new file mode 100644 index 0000000000000..1f68fd28d8dd7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + +class ExprValueSuite extends SparkFunSuite { + + test("TrueLiteral and FalseLiteral should be LiteralValue") { + val trueLit = TrueLiteral + val falseLit = FalseLiteral + + assert(trueLit.value == "true") + assert(falseLit.value == "false") + + trueLit match { + case LiteralValue(value) => assert(value == "true") + case _ => fail() + } + + falseLit match { + case LiteralValue(value) => assert(value == "false") + case _ => fail() + } + } +} From f59bb19a3fd04b24ea3077a12283777be0af437d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Feb 2018 06:01:53 +0000 Subject: [PATCH 6/9] Use TrueLiteral/FalseLiteral. Add java type and access property to ExprValue. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 14 ++-- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 44 ++-------- .../expressions/codegen/CodegenFallback.scala | 2 +- .../expressions/codegen/ExprValue.scala | 82 +++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 14 ++-- .../codegen/GenerateUnsafeProjection.scala | 6 +- .../expressions/collectionOperations.scala | 2 +- .../expressions/complexTypeCreator.scala | 8 +- .../expressions/conditionalExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 4 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 27 +++--- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 25 +++--- .../expressions/objects/objects.scala | 22 +++-- .../sql/catalyst/expressions/predicates.scala | 10 +-- .../expressions/randomExpressions.scala | 6 +- .../expressions/codegen/ExprValueSuite.scala | 6 +- .../sql/execution/ColumnarBatchScan.scala | 8 +- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 11 ++- .../sql/execution/WholeStageCodegenExec.scala | 13 +-- .../aggregate/HashAggregateExec.scala | 3 +- .../aggregate/HashMapGenerator.scala | 5 +- .../execution/basicPhysicalOperators.scala | 4 +- .../joins/BroadcastHashJoinExec.scala | 6 +- .../execution/joins/SortMergeJoinExec.scala | 6 +- 34 files changed, 227 insertions(+), 151 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 74fbe634ae5b7..5a7f6d9feb5f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ /** @@ -75,7 +75,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = LiteralValue("false")) + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 44996114e5058..0285492948c8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", VariableValue(isNull), VariableValue(value))) + val eval = doGenCode(ctx, ExprCode("", + VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(value, ExprType(ctx, dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -121,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull) + eval.isNull = GlobalValue(globalIsNull, ExprType(ctx.JAVA_BOOLEAN, true)) s"$globalIsNull = $localIsNull;" } else { "" @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = VariableValue(newValue) + eval.value = VariableValue(newValue, ExprType(ctx, dataType)) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -419,7 +421,7 @@ abstract class UnaryExpression extends Expression { boolean ${ev.isNull} = false; ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = LiteralValue("false")) + $resultCode""", isNull = FalseLiteral) } } } @@ -519,7 +521,7 @@ abstract class BinaryExpression extends Expression { ${leftGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = LiteralValue("false")) + $resultCode""", isNull = FalseLiteral) } } } @@ -663,7 +665,7 @@ abstract class TernaryExpression extends Expression { ${midGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = LiteralValue("false")) + $resultCode""", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 2b2f2b8415e70..06eb5da9ae298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = LiteralValue("false")) + $countTerm++;""", isNull = FalseLiteral) } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index d329addbe822f..23953ce35e639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", - isNull = LiteralValue("false")) + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d530032ca26a4..53e6ff95c786d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,7 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), + ExprType(ctx.JAVA_BOOLEAN, true)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -681,7 +682,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), + ExprType(ctx.JAVA_BOOLEAN, true)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6a33be9c14192..a4ffe7f37601e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -23,7 +23,7 @@ import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.language.{existentials, implicitConversions} +import scala.language.existentials import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -58,42 +58,6 @@ import org.apache.spark.util.{ParentClassLoader, Utils} */ case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) - -// An abstraction that represents the evaluation result of [[ExprCode]]. -abstract class ExprValue - -object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString -} - -// A literal evaluation of [[ExprCode]]. -class LiteralValue(val value: String) extends ExprValue { - override def toString: String = value -} - -object LiteralValue { - def apply(value: String): LiteralValue = new LiteralValue(value) - def unapply(literal: LiteralValue): Option[String] = Some(literal.value) -} - -// A variable evaluation of [[ExprCode]]. -case class VariableValue(val variableName: String) extends ExprValue { - override def toString: String = variableName -} - -// A statement evaluation of [[ExprCode]]. -case class StatementValue(val statement: String) extends ExprValue { - override def toString: String = statement -} - -// A global variable evaluation of [[ExprCode]]. -case class GlobalValue(val value: String) extends ExprValue { - override def toString: String = value -} - -case object TrueLiteral extends LiteralValue("true") -case object FalseLiteral extends LiteralValue("false") - object ExprCode { def forNonNullValue(value: ExprValue): ExprCode = { ExprCode(code = "", isNull = FalseLiteral, value = value) @@ -359,7 +323,8 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, LiteralValue("false"), GlobalValue(value)) + ExprCode(code, FalseLiteral, + GlobalValue(value, ExprType(this, dataType))) } def declareMutableStates(): String = { @@ -1244,7 +1209,8 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull), GlobalValue(value)) + val state = SubExprEliminationState(GlobalValue(isNull, ExprType(JAVA_BOOLEAN, true)), + GlobalValue(value, ExprType(this, expr.dataType))) e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index fed01ee51df61..bebdfd183cf00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -58,7 +58,7 @@ trait CodegenFallback extends Expression { $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - """, isNull = LiteralValue("false")) + """, isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala new file mode 100644 index 0000000000000..bcf1baa5569d1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.language.implicitConversions + +import org.apache.spark.sql.types.DataType + +// An abstraction that represents the evaluation result of [[ExprCode]]. +abstract class ExprValue { + + val javaType: ExprType + + // Whether we can directly access the evaluation value anywhere. + // For example, a variable created outside a method can not be accessed inside the method. + // For such cases, we may need to pass the evaluation as parameter. + val canDirectAccess: Boolean +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + +// A literal evaluation of [[ExprCode]]. +class LiteralValue(val value: String, val javaType: ExprType) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +object LiteralValue { + def apply(value: String, javaType: ExprType): LiteralValue = new LiteralValue(value, javaType) + def unapply(literal: LiteralValue): Option[(String, ExprType)] = + Some((literal.value, literal.javaType)) +} + +// A variable evaluation of [[ExprCode]]. +case class VariableValue( + val variableName: String, + val javaType: ExprType, + val canDirectAccess: Boolean = false) extends ExprValue { + override def toString: String = variableName +} + +// A statement evaluation of [[ExprCode]]. +case class StatementValue( + val statement: String, + val javaType: ExprType, + val canDirectAccess: Boolean = false) extends ExprValue { + override def toString: String = statement +} + +// A global variable evaluation of [[ExprCode]]. +case class GlobalValue(val value: String, val javaType: ExprType) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +case object TrueLiteral extends LiteralValue("true", ExprType("boolean", true)) +case object FalseLiteral extends LiteralValue("false", ExprType("boolean", true)) + +// Represents the java type of an evaluation. +case class ExprType(val typeName: String, val isPrimitive: Boolean) + +object ExprType { + def apply(ctx: CodegenContext, dataType: DataType): ExprType = ExprType(ctx.javaType(dataType), + ctx.isPrimitiveType(dataType)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 21c9c605ac607..9082997469b56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull), value, i) + """.stripMargin, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, GlobalValue(value)) + val ev = ExprCode("", isNull, GlobalValue(value, ExprType(ctx, e.dataType))) ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index c6b9f27fbb7c9..3110fd701c2e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString)), dt) + val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString), + ExprType(ctx, dt)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +75,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, LiteralValue("false"), VariableValue(output)) + ExprCode(code, FalseLiteral, VariableValue(output, ExprType("InternalRow", false))) } private def createCodeForArray( @@ -90,7 +91,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, StatementValue(ctx.getValue(tmpInput, elementType, index)), elementType) + ctx, StatementValue(ctx.getValue(tmpInput, elementType, index), ExprType(ctx, elementType)), + elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +106,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, LiteralValue("false"), VariableValue(output)) + ExprCode(code, FalseLiteral, VariableValue(output, ExprType("ArrayData", false))) } private def createCodeForMap( @@ -125,7 +127,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, LiteralValue("false"), VariableValue(output)) + ExprCode(code, FalseLiteral, VariableValue(output, ExprType("MapData", false))) } @tailrec @@ -137,7 +139,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", LiteralValue("false"), input) + case _ => ExprCode("", FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5b1219d77f44f..07a38d9c7cd66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,8 +52,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)"), - StatementValue(ctx.getValue(tmpInput, dt, i.toString))) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", ExprType(ctx.JAVA_BOOLEAN, true)), + StatementValue(ctx.getValue(tmpInput, dt, i.toString), ExprType(ctx, dt))) } s""" @@ -348,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, LiteralValue("false"), GlobalValue(result)) + ExprCode(code, FalseLiteral, GlobalValue(result, ExprType("UnsafeRow", false))) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 231077358e94a..ded302b20db18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType boolean ${ev.isNull} = false; ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = LiteralValue("false")) + (${childGen.value}).numElements();""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index caa9d3ba37406..b35ca20618c6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = VariableValue(arrayData), - isNull = LiteralValue("false")) + value = VariableValue(arrayData, ExprType(ctx, dataType)), + isNull = FalseLiteral) } override def prettyName: String = "array" @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = LiteralValue("false")) + """.stripMargin, isNull = FalseLiteral) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = LiteralValue("false"), value = eval.value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b7df9757aa7b4..d893b2f5ab158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,7 +191,8 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = GlobalValue(ctx.addMutableState(ctx.javaType(dataType), ev.value)) + ev.value = GlobalValue(ctx.addMutableState(ctx.javaType(dataType), ev.value), + ExprType(ctx, dataType)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 2fe541e730049..75d11db52b3f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -677,7 +677,8 @@ abstract class UnixTime case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", LiteralValue("true"), LiteralValue(ctx.defaultValue(dataType))) + ExprCode("", TrueLiteral, LiteralValue(ctx.defaultValue(dataType), + ExprType(ctx, dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) @@ -812,7 +813,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", LiteralValue("true"), LiteralValue("(UTF8String) null")) + ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", ExprType(ctx, dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 5cb0307a76960..3af4bfebad45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); - """.stripMargin, isNull = LiteralValue("false")) + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 1461f283f8eb3..85b44d136064e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -269,7 +269,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = LiteralValue("false") + ev.isNull = FalseLiteral val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -632,7 +632,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = LiteralValue("false") + ev.isNull = FalseLiteral val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 15fc341867120..13793ff1ef032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = LiteralValue("false")) + s"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = LiteralValue("false")) + s"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = LiteralValue("false")) + s"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ea107bee31997..bebc86a5bf4a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -278,45 +278,46 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) + val exprType = ExprType(ctx, dataType) if (value == null) { val defaultValueLiteral = ctx.defaultValue(javaType) match { case "null" => s"(($javaType)null)" case lit => lit } - ExprCode(code = "", isNull = LiteralValue("true"), value = LiteralValue(defaultValueLiteral)) + ExprCode(code = "", isNull = TrueLiteral, value = LiteralValue(defaultValueLiteral, exprType)) } else { dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(LiteralValue(value.toString)) + ExprCode.forNonNullValue(LiteralValue(value.toString, exprType)) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Float.NaN")) + ExprCode.forNonNullValue(LiteralValue("Float.NaN", exprType)) case Float.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY")) + ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", exprType)) case Float.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY")) + ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", exprType)) case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}F")) + ExprCode.forNonNullValue(LiteralValue(s"${value}F", exprType)) } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Double.NaN")) + ExprCode.forNonNullValue(LiteralValue("Double.NaN", exprType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY")) + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", exprType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY")) + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", exprType)) case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}D")) + ExprCode.forNonNullValue(LiteralValue(s"${value}D", exprType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value")) + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", exprType)) case TimestampType | LongType => - ExprCode.forNonNullValue(LiteralValue(s"${value}L")) + ExprCode.forNonNullValue(LiteralValue(s"${value}L", exprType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(GlobalValue(constRef)) + ExprCode.forNonNullValue(GlobalValue(constRef, exprType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0feea6197de8a..d9bacd07a9286 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -85,7 +85,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = LiteralValue("true"), value = LiteralValue("null")) + |}""".stripMargin, isNull = TrueLiteral, + value = LiteralValue("null", ExprType(ctx, dataType))) } override def sql: String = s"assert_true(${child.sql})" @@ -130,6 +131,6 @@ case class Uuid() extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = s"final UTF8String ${ev.value} = " + s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", - isNull = LiteralValue("false")) + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index d1226a8ac95bd..0e326640664a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), + ExprType(ctx.JAVA_BOOLEAN, true)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -236,7 +237,7 @@ case class IsNaN(child: Expression) extends UnaryExpression ${eval.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", - isNull = LiteralValue("false")) + isNull = FalseLiteral) } } } @@ -322,11 +323,11 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull) + LiteralValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN, true)) } else { - VariableValue(eval.isNull) + VariableValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN, true)) } - ExprCode(code = eval.code, isNull = LiteralValue("false"), value = value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NULL)" @@ -352,14 +353,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull == LiteralValue("true")) { - LiteralValue("false") - } else if (eval.isNull == LiteralValue("false")) { - LiteralValue("true") + val value = if (eval.isNull == TrueLiteral) { + FalseLiteral + } else if (eval.isNull == FalseLiteral) { + TrueLiteral } else { - StatementValue(s"(!(${eval.isNull}))") + StatementValue(s"(!(${eval.isNull}))", ExprType(ctx, dataType)) } - ExprCode(code = eval.code, isNull = LiteralValue("false"), value = value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -454,6 +455,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate | $codes |} while (false); |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = LiteralValue("false")) + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index da076a22b228c..79212685899b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -63,9 +63,9 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull) + GlobalValue(resultIsNull, ExprType(ctx.JAVA_BOOLEAN, true)) } else { - LiteralValue("false") + FalseLiteral } val argValues = arguments.map { e => val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") @@ -146,7 +146,7 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = LiteralValue("false") + ev.isNull = FalseLiteral "" } @@ -428,7 +428,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = LiteralValue("false")) + ev.copy(code = code, isNull = FalseLiteral) } } @@ -444,8 +444,12 @@ case class LambdaVariable( with Unevaluable with NonSQLExpression { override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = VariableValue(value), - isNull = if (nullable) VariableValue(isNull) else LiteralValue("false")) + val isNullValue = if (nullable) { + VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)) + } else { + FalseLiteral + } + ExprCode(code = "", value = VariableValue(value, ExprType(ctx, dataType)), isNull = isNullValue) } } @@ -1133,7 +1137,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = LiteralValue("false")) + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1327,7 +1331,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = LiteralValue("false"), value = childGen.value) + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } } @@ -1370,7 +1374,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = LiteralValue("false")) + ev.copy(code = code, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f609956c6676e..c31423ea4a078 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, LiteralValue, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -405,7 +405,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = LiteralValue("false")) + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -461,7 +461,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = LiteralValue("false") + ev.isNull = FalseLiteral ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -469,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = LiteralValue("false")) + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -615,7 +615,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = LiteralValue("false")) + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 7dd3a231514e7..9e23e263cb3c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -83,7 +83,7 @@ case class Rand(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", - isNull = LiteralValue("false")) + isNull = FalseLiteral) } } @@ -118,7 +118,7 @@ case class Randn(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", - isNull = LiteralValue("false")) + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index 1f68fd28d8dd7..b0fa4b6fc741f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -29,12 +29,14 @@ class ExprValueSuite extends SparkFunSuite { assert(falseLit.value == "false") trueLit match { - case LiteralValue(value) => assert(value == "true") + case LiteralValue(value, javaType) => + assert(value == "true" && javaType.typeName == "boolean" && javaType.isPrimitive) case _ => fail() } falseLit match { - case LiteralValue(value) => assert(value == "false") + case LiteralValue(value, javaType) => + assert(value == "false" && javaType.typeName == "boolean" && javaType.isPrimitive) case _ => fail() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index aefb7d61ac9c1..b37a7383ade70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, FalseLiteral, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -52,9 +52,9 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = ctx.javaType(dataType) val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull")) + VariableValue(ctx.freshName("isNull"), ExprType(ctx.JAVA_BOOLEAN, true)) } else { - LiteralValue("false") + FalseLiteral } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" @@ -66,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, VariableValue(valueVar)) + ExprCode(code, isNullVar, VariableValue(valueVar, ExprType(ctx, dataType))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 350d6622aaaa5..892b174aaa713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -156,7 +156,8 @@ case class ExpandExec( |boolean $isNull = true; |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull), VariableValue(value)) + ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(value, ExprType(ctx, firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index bd2b6d62c4357..6cb4c6caf4fa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -170,9 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1"), VariableValue(index))) + Seq(ExprCode("", StatementValue(s"$index == -1", ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(index, ExprType(ctx.JAVA_INT, true)))) } else { - Seq(ExprCode("", LiteralValue("false"), VariableValue(index))) + Seq(ExprCode("", FalseLiteral, VariableValue(index, ExprType(ctx.JAVA_INT, true)))) } } else { Seq.empty @@ -315,9 +316,11 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull), VariableValue(value)) + ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(value, ExprType(ctx, dt))) } else { - ExprCode(s"$javaType $value = $getter;", LiteralValue("false"), VariableValue(value)) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, + VariableValue(value, ExprType(ctx, dt))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index a24b07d69b345..a0f2710952053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", LiteralValue("false"), VariableValue(row)) + ExprCode("", FalseLiteral, VariableValue(row, ExprType("UnsafeRow", false))) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -126,10 +126,10 @@ trait CodegenSupport extends SparkPlan { |$evaluateInputs |${ev.code.trim} """.stripMargin.trim - ExprCode(code, LiteralValue("false"), ev.value) + ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", LiteralValue("false"), VariableValue("unsafeRow")) + ExprCode("", FalseLiteral, VariableValue("unsafeRow", ExprType("UnsafeRow", false))) } } } @@ -240,15 +240,16 @@ trait CodegenSupport extends SparkPlan { parameters += s"$paramType $paramName" val paramIsNull = if (!attributes(i).nullable) { // Use constant `false` without passing `isNull` for non-nullable variable. - LiteralValue("false") + FalseLiteral } else { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull) + VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)) } - paramVars += ExprCode("", paramIsNull, VariableValue(paramName)) + paramVars += ExprCode("", paramIsNull, + VariableValue(paramName, ExprType(ctx, attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4541edd40634c..4c99bfba3e3bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull), GlobalValue(value)) + ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + GlobalValue(value, ExprType(ctx, e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index bad106bf29677..61491bdc784c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GlobalValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull), GlobalValue(value)) + ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + GlobalValue(value, ExprType(ctx, e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index eaa43da728c25..a4168a8e0ec89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = LiteralValue("false") + ev.isNull = FalseLiteral } ev } @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", LiteralValue("false"), VariableValue(value)) + val ev = ExprCode("", FalseLiteral, VariableValue(value, ExprType(ctx.JAVA_LONG, true))) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 559f659d66d29..ec8539f5baa17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -191,7 +191,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull), VariableValue(value)) + ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(value, ExprType(ctx, a.dataType))) } } } @@ -486,7 +487,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", LiteralValue("false"), VariableValue(existsVar))) + val resultVar = input ++ Seq(ExprCode("", FalseLiteral, + VariableValue(existsVar, ExprType(ctx.JAVA_BOOLEAN, true)))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 31af22d64f23f..6618bccb70d03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -531,11 +531,13 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull), VariableValue(value)), leftVarsDecl) + (ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(value, ExprType(ctx, a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, LiteralValue("false"), VariableValue(value)), leftVarsDecl) + (ExprCode(code, FalseLiteral, + VariableValue(value, ExprType(ctx, a.dataType))), leftVarsDecl) } }.unzip } From 37ae9b0e217de323dbc73c9e1247ebe9bf2c278c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Mar 2018 02:22:25 +0000 Subject: [PATCH 7/9] Address comment. --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 3 +-- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a4ffe7f37601e..d06936e894d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -323,8 +323,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, FalseLiteral, - GlobalValue(value, ExprType(this, dataType))) + ExprCode(code, FalseLiteral, GlobalValue(value, ExprType(this, dataType))) } def declareMutableStates(): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6618bccb70d03..7aa9f46277702 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, FalseLiteral, VariableValue} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, From 0841c4afdb1808e3a1281d7612d1954b486c22dd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Mar 2018 12:51:26 +0000 Subject: [PATCH 8/9] Address comment. --- .../sql/catalyst/expressions/Expression.scala | 4 ++-- .../sql/catalyst/expressions/arithmetic.scala | 4 ++-- .../expressions/codegen/CodeGenerator.scala | 2 +- .../expressions/codegen/ExprValue.scala | 17 ++++++++++------- .../codegen/GenerateMutableProjection.scala | 2 +- .../codegen/GenerateSafeProjection.scala | 6 +++--- .../codegen/GenerateUnsafeProjection.scala | 4 ++-- .../catalyst/expressions/nullExpressions.scala | 6 +++--- .../catalyst/expressions/objects/objects.scala | 4 ++-- .../expressions/codegen/ExprValueSuite.scala | 6 ++++-- .../spark/sql/execution/ColumnarBatchScan.scala | 2 +- .../apache/spark/sql/execution/ExpandExec.scala | 2 +- .../spark/sql/execution/GenerateExec.scala | 8 ++++---- .../sql/execution/WholeStageCodegenExec.scala | 6 +++--- .../execution/aggregate/HashAggregateExec.scala | 2 +- .../execution/aggregate/HashMapGenerator.scala | 2 +- .../sql/execution/basicPhysicalOperators.scala | 2 +- .../execution/joins/BroadcastHashJoinExec.scala | 4 ++-- .../sql/execution/joins/SortMergeJoinExec.scala | 2 +- 19 files changed, 45 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0285492948c8e..d81e87c8591b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,7 +105,7 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", - VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), VariableValue(value, ExprType(ctx, dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull, ExprType(ctx.JAVA_BOOLEAN, true)) + eval.isNull = GlobalValue(globalIsNull, ExprType(ctx.JAVA_BOOLEAN)) s"$globalIsNull = $localIsNull;" } else { "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 53e6ff95c786d..06ef86f19b76f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -603,7 +603,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), - ExprType(ctx.JAVA_BOOLEAN, true)) + ExprType(ctx.JAVA_BOOLEAN)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -683,7 +683,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), - ExprType(ctx.JAVA_BOOLEAN, true)) + ExprType(ctx.JAVA_BOOLEAN)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d06936e894d1c..7190b8568d7ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1208,7 +1208,7 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull, ExprType(JAVA_BOOLEAN, true)), + val state = SubExprEliminationState(GlobalValue(isNull, ExprType(JAVA_BOOLEAN)), GlobalValue(value, ExprType(this, expr.dataType))) e.foreach(subExprEliminationExprs.put(_, state)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala index bcf1baa5569d1..d589fde09e0ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -30,6 +30,8 @@ abstract class ExprValue { // For example, a variable created outside a method can not be accessed inside the method. // For such cases, we may need to pass the evaluation as parameter. val canDirectAccess: Boolean + + def isPrimitive(ctx: CodegenContext): Boolean = javaType.isPrimitive(ctx) } object ExprValue { @@ -51,9 +53,9 @@ object LiteralValue { // A variable evaluation of [[ExprCode]]. case class VariableValue( val variableName: String, - val javaType: ExprType, - val canDirectAccess: Boolean = false) extends ExprValue { + val javaType: ExprType) extends ExprValue { override def toString: String = variableName + override val canDirectAccess: Boolean = false } // A statement evaluation of [[ExprCode]]. @@ -70,13 +72,14 @@ case class GlobalValue(val value: String, val javaType: ExprType) extends ExprVa override val canDirectAccess: Boolean = true } -case object TrueLiteral extends LiteralValue("true", ExprType("boolean", true)) -case object FalseLiteral extends LiteralValue("false", ExprType("boolean", true)) +case object TrueLiteral extends LiteralValue("true", ExprType("boolean")) +case object FalseLiteral extends LiteralValue("false", ExprType("boolean")) // Represents the java type of an evaluation. -case class ExprType(val typeName: String, val isPrimitive: Boolean) +case class ExprType(val typeName: String) { + def isPrimitive(ctx: CodegenContext): Boolean = ctx.isPrimitiveType(typeName) +} object ExprType { - def apply(ctx: CodegenContext, dataType: DataType): ExprType = ExprType(ctx.javaType(dataType), - ctx.isPrimitiveType(dataType)) + def apply(ctx: CodegenContext, dataType: DataType): ExprType = ExprType(ctx.javaType(dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 9082997469b56..37c49eabbf5cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), value, i) + """.stripMargin, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), value, i) } else { (s""" |${ev.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3110fd701c2e5..b79f6dbff66fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -75,7 +75,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, FalseLiteral, VariableValue(output, ExprType("InternalRow", false))) + ExprCode(code, FalseLiteral, VariableValue(output, ExprType("InternalRow"))) } private def createCodeForArray( @@ -106,7 +106,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, FalseLiteral, VariableValue(output, ExprType("ArrayData", false))) + ExprCode(code, FalseLiteral, VariableValue(output, ExprType("ArrayData"))) } private def createCodeForMap( @@ -127,7 +127,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, FalseLiteral, VariableValue(output, ExprType("MapData", false))) + ExprCode(code, FalseLiteral, VariableValue(output, ExprType("MapData"))) } @tailrec diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 07a38d9c7cd66..3a9e1c5b9a16f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", ExprType(ctx.JAVA_BOOLEAN, true)), + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", ExprType(ctx.JAVA_BOOLEAN)), StatementValue(ctx.getValue(tmpInput, dt, i.toString), ExprType(ctx, dt))) } @@ -348,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, FalseLiteral, GlobalValue(result, ExprType("UnsafeRow", false))) + ExprCode(code, FalseLiteral, GlobalValue(result, ExprType("UnsafeRow"))) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0e326640664a8..60edd00117709 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -73,7 +73,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), - ExprType(ctx.JAVA_BOOLEAN, true)) + ExprType(ctx.JAVA_BOOLEAN)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -323,9 +323,9 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN, true)) + LiteralValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN)) } else { - VariableValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN, true)) + VariableValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN)) } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 79212685899b0..d5299dc407221 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -63,7 +63,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull, ExprType(ctx.JAVA_BOOLEAN, true)) + GlobalValue(resultIsNull, ExprType(ctx.JAVA_BOOLEAN)) } else { FalseLiteral } @@ -445,7 +445,7 @@ case class LambdaVariable( override def genCode(ctx: CodegenContext): ExprCode = { val isNullValue = if (nullable) { - VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)) + VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)) } else { FalseLiteral } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index b0fa4b6fc741f..634d16c0cf811 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -28,15 +28,17 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit.value == "true") assert(falseLit.value == "false") + val ctx = new CodegenContext() + trueLit match { case LiteralValue(value, javaType) => - assert(value == "true" && javaType.typeName == "boolean" && javaType.isPrimitive) + assert(value == "true" && javaType.typeName == "boolean" && javaType.isPrimitive(ctx)) case _ => fail() } falseLit match { case LiteralValue(value, javaType) => - assert(value == "false" && javaType.typeName == "boolean" && javaType.isPrimitive) + assert(value == "false" && javaType.typeName == "boolean" && javaType.isPrimitive(ctx)) case _ => fail() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index b37a7383ade70..6a67fd2a826b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -52,7 +52,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = ctx.javaType(dataType) val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull"), ExprType(ctx.JAVA_BOOLEAN, true)) + VariableValue(ctx.freshName("isNull"), ExprType(ctx.JAVA_BOOLEAN)) } else { FalseLiteral } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 892b174aaa713..b55b9ea80d9aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -156,7 +156,7 @@ case class ExpandExec( |boolean $isNull = true; |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), VariableValue(value, ExprType(ctx, firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 6cb4c6caf4fa9..d77811985107b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -170,10 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1", ExprType(ctx.JAVA_BOOLEAN, true)), - VariableValue(index, ExprType(ctx.JAVA_INT, true)))) + Seq(ExprCode("", StatementValue(s"$index == -1", ExprType(ctx.JAVA_BOOLEAN)), + VariableValue(index, ExprType(ctx.JAVA_INT)))) } else { - Seq(ExprCode("", FalseLiteral, VariableValue(index, ExprType(ctx.JAVA_INT, true)))) + Seq(ExprCode("", FalseLiteral, VariableValue(index, ExprType(ctx.JAVA_INT)))) } } else { Seq.empty @@ -316,7 +316,7 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), VariableValue(value, ExprType(ctx, dt))) } else { ExprCode(s"$javaType $value = $getter;", FalseLiteral, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index a0f2710952053..b3582f4075efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", FalseLiteral, VariableValue(row, ExprType("UnsafeRow", false))) + ExprCode("", FalseLiteral, VariableValue(row, ExprType("UnsafeRow"))) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -129,7 +129,7 @@ trait CodegenSupport extends SparkPlan { ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", FalseLiteral, VariableValue("unsafeRow", ExprType("UnsafeRow", false))) + ExprCode("", FalseLiteral, VariableValue("unsafeRow", ExprType("UnsafeRow"))) } } } @@ -245,7 +245,7 @@ trait CodegenSupport extends SparkPlan { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)) + VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)) } paramVars += ExprCode("", paramIsNull, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4c99bfba3e3bd..276e0b4e04e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,7 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), GlobalValue(value, ExprType(ctx, e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 61491bdc784c6..fd1eb98c506f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -54,7 +54,7 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), GlobalValue(value, ExprType(ctx, e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a4168a8e0ec89..097a00674ddf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", FalseLiteral, VariableValue(value, ExprType(ctx.JAVA_LONG, true))) + val ev = ExprCode("", FalseLiteral, VariableValue(value, ExprType(ctx.JAVA_LONG))) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index ec8539f5baa17..22bc5b820f719 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -191,7 +191,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), VariableValue(value, ExprType(ctx, a.dataType))) } } @@ -488,7 +488,7 @@ case class BroadcastHashJoinExec( } val resultVar = input ++ Seq(ExprCode("", FalseLiteral, - VariableValue(existsVar, ExprType(ctx.JAVA_BOOLEAN, true)))) + VariableValue(existsVar, ExprType(ctx.JAVA_BOOLEAN)))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 7aa9f46277702..3780d1d64207e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -531,7 +531,7 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), + (ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), VariableValue(value, ExprType(ctx, a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" From e530f01f74c359aa4b21017393fd9d72d289a252 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Mar 2018 13:33:46 +0000 Subject: [PATCH 9/9] Remove ExprType. --- .../sql/catalyst/expressions/Expression.scala | 8 ++--- .../sql/catalyst/expressions/arithmetic.scala | 6 ++-- .../expressions/codegen/CodeGenerator.scala | 6 ++-- .../expressions/codegen/ExprValue.scala | 29 +++++++------------ .../codegen/GenerateMutableProjection.scala | 4 +-- .../codegen/GenerateSafeProjection.scala | 10 +++---- .../codegen/GenerateUnsafeProjection.scala | 6 ++-- .../expressions/complexTypeCreator.scala | 2 +- .../expressions/conditionalExpressions.scala | 2 +- .../expressions/datetimeExpressions.scala | 4 +-- .../sql/catalyst/expressions/literals.scala | 27 +++++++++-------- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/nullExpressions.scala | 9 +++--- .../expressions/objects/objects.scala | 6 ++-- .../expressions/codegen/ExprValueSuite.scala | 7 +++-- .../sql/execution/ColumnarBatchScan.scala | 6 ++-- .../spark/sql/execution/ExpandExec.scala | 6 ++-- .../spark/sql/execution/GenerateExec.scala | 12 ++++---- .../sql/execution/WholeStageCodegenExec.scala | 8 ++--- .../aggregate/HashAggregateExec.scala | 4 +-- .../aggregate/HashMapGenerator.scala | 6 ++-- .../execution/basicPhysicalOperators.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 6 ++-- .../execution/joins/SortMergeJoinExec.scala | 8 ++--- 24 files changed, 88 insertions(+), 98 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d81e87c8591b9..e68ea46026e93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,8 +105,8 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", - VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - VariableValue(value, ExprType(ctx, dataType)))) + VariableValue(isNull, ctx.JAVA_BOOLEAN), + VariableValue(value, ctx.javaType(dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull, ExprType(ctx.JAVA_BOOLEAN)) + eval.isNull = GlobalValue(globalIsNull, ctx.JAVA_BOOLEAN) s"$globalIsNull = $localIsNull;" } else { "" @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = VariableValue(newValue, ExprType(ctx, dataType)) + eval.value = VariableValue(newValue, javaType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 06ef86f19b76f..78b80f635ff5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), - ExprType(ctx.JAVA_BOOLEAN)) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), ctx.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -682,8 +681,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), - ExprType(ctx.JAVA_BOOLEAN)) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), ctx.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7190b8568d7ee..f725a91297fd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -323,7 +323,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, FalseLiteral, GlobalValue(value, ExprType(this, dataType))) + ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) } def declareMutableStates(): String = { @@ -1208,8 +1208,8 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull, ExprType(JAVA_BOOLEAN)), - GlobalValue(value, ExprType(this, expr.dataType))) + val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), + GlobalValue(value, javaType(expr.dataType))) e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala index d589fde09e0ed..6f46d927ccd9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -24,14 +24,14 @@ import org.apache.spark.sql.types.DataType // An abstraction that represents the evaluation result of [[ExprCode]]. abstract class ExprValue { - val javaType: ExprType + val javaType: String // Whether we can directly access the evaluation value anywhere. // For example, a variable created outside a method can not be accessed inside the method. // For such cases, we may need to pass the evaluation as parameter. val canDirectAccess: Boolean - def isPrimitive(ctx: CodegenContext): Boolean = javaType.isPrimitive(ctx) + def isPrimitive(ctx: CodegenContext): Boolean = ctx.isPrimitiveType(javaType) } object ExprValue { @@ -39,21 +39,21 @@ object ExprValue { } // A literal evaluation of [[ExprCode]]. -class LiteralValue(val value: String, val javaType: ExprType) extends ExprValue { +class LiteralValue(val value: String, val javaType: String) extends ExprValue { override def toString: String = value override val canDirectAccess: Boolean = true } object LiteralValue { - def apply(value: String, javaType: ExprType): LiteralValue = new LiteralValue(value, javaType) - def unapply(literal: LiteralValue): Option[(String, ExprType)] = + def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) + def unapply(literal: LiteralValue): Option[(String, String)] = Some((literal.value, literal.javaType)) } // A variable evaluation of [[ExprCode]]. case class VariableValue( val variableName: String, - val javaType: ExprType) extends ExprValue { + val javaType: String) extends ExprValue { override def toString: String = variableName override val canDirectAccess: Boolean = false } @@ -61,25 +61,16 @@ case class VariableValue( // A statement evaluation of [[ExprCode]]. case class StatementValue( val statement: String, - val javaType: ExprType, + val javaType: String, val canDirectAccess: Boolean = false) extends ExprValue { override def toString: String = statement } // A global variable evaluation of [[ExprCode]]. -case class GlobalValue(val value: String, val javaType: ExprType) extends ExprValue { +case class GlobalValue(val value: String, val javaType: String) extends ExprValue { override def toString: String = value override val canDirectAccess: Boolean = true } -case object TrueLiteral extends LiteralValue("true", ExprType("boolean")) -case object FalseLiteral extends LiteralValue("false", ExprType("boolean")) - -// Represents the java type of an evaluation. -case class ExprType(val typeName: String) { - def isPrimitive(ctx: CodegenContext): Boolean = ctx.isPrimitiveType(typeName) -} - -object ExprType { - def apply(ctx: CodegenContext, dataType: DataType): ExprType = ExprType(ctx.javaType(dataType)) -} +case object TrueLiteral extends LiteralValue("true", "boolean") +case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 37c49eabbf5cd..2a96ee6969611 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), value, i) + """.stripMargin, GlobalValue(isNull, ctx.JAVA_BOOLEAN), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, GlobalValue(value, ExprType(ctx, e.dataType))) + val ev = ExprCode("", isNull, GlobalValue(value, ctx.javaType(e.dataType))) ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b79f6dbff66fd..b86192cc91f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -54,7 +54,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString), - ExprType(ctx, dt)), dt) + ctx.javaType(dt)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -75,7 +75,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, FalseLiteral, VariableValue(output, ExprType("InternalRow"))) + ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) } private def createCodeForArray( @@ -91,7 +91,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, StatementValue(ctx.getValue(tmpInput, elementType, index), ExprType(ctx, elementType)), + ctx, StatementValue(ctx.getValue(tmpInput, elementType, index), ctx.javaType(elementType)), elementType) val code = s""" final ArrayData $tmpInput = $input; @@ -106,7 +106,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, FalseLiteral, VariableValue(output, ExprType("ArrayData"))) + ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) } private def createCodeForMap( @@ -127,7 +127,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, FalseLiteral, VariableValue(output, ExprType("MapData"))) + ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) } @tailrec diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3a9e1c5b9a16f..20c8107f7884c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,8 +52,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", ExprType(ctx.JAVA_BOOLEAN)), - StatementValue(ctx.getValue(tmpInput, dt, i.toString), ExprType(ctx, dt))) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", ctx.JAVA_BOOLEAN), + StatementValue(ctx.getValue(tmpInput, dt, i.toString), ctx.javaType(dt))) } s""" @@ -348,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, FalseLiteral, GlobalValue(result, ExprType("UnsafeRow"))) + ExprCode(code, FalseLiteral, GlobalValue(result, "UnsafeRow")) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b35ca20618c6e..c6ed7ad85db11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = VariableValue(arrayData, ExprType(ctx, dataType)), + value = VariableValue(arrayData, ctx.javaType(dataType)), isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d893b2f5ab158..b6d3fc3eaa87d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -192,7 +192,7 @@ case class CaseWhen( // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") ev.value = GlobalValue(ctx.addMutableState(ctx.javaType(dataType), ev.value), - ExprType(ctx, dataType)) + ctx.javaType(dataType)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 75d11db52b3f6..488fcba7859c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -678,7 +678,7 @@ abstract class UnixTime val df = classOf[DateFormat].getName if (formatter == null) { ExprCode("", TrueLiteral, LiteralValue(ctx.defaultValue(dataType), - ExprType(ctx, dataType))) + ctx.javaType(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) @@ -813,7 +813,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", ExprType(ctx, dataType))) + ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", ctx.javaType(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index bebc86a5bf4a4..3b1c34c84f70b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -278,46 +278,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val exprType = ExprType(ctx, dataType) if (value == null) { val defaultValueLiteral = ctx.defaultValue(javaType) match { case "null" => s"(($javaType)null)" case lit => lit } - ExprCode(code = "", isNull = TrueLiteral, value = LiteralValue(defaultValueLiteral, exprType)) + ExprCode(code = "", isNull = TrueLiteral, value = LiteralValue(defaultValueLiteral, javaType)) } else { dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(LiteralValue(value.toString, exprType)) + ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Float.NaN", exprType)) + ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) case Float.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", exprType)) + ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) case Float.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", exprType)) + ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}F", exprType)) + ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Double.NaN", exprType)) + ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", exprType)) + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", exprType)) + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}D", exprType)) + ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", exprType)) + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) case TimestampType | LongType => - ExprCode.forNonNullValue(LiteralValue(s"${value}L", exprType)) + ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(GlobalValue(constRef, exprType)) + ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index d9bacd07a9286..fce6e4cdc947e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -86,7 +86,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, - value = LiteralValue("null", ExprType(ctx, dataType))) + value = LiteralValue("null", ctx.javaType(dataType))) } override def sql: String = s"assert_true(${child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 60edd00117709..f4d9477aebd6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), - ExprType(ctx.JAVA_BOOLEAN)) + ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull), ctx.JAVA_BOOLEAN) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -323,9 +322,9 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN)) + LiteralValue(eval.isNull, ctx.JAVA_BOOLEAN) } else { - VariableValue(eval.isNull, ExprType(ctx.JAVA_BOOLEAN)) + VariableValue(eval.isNull, ctx.JAVA_BOOLEAN) } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } @@ -358,7 +357,7 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } else if (eval.isNull == FalseLiteral) { TrueLiteral } else { - StatementValue(s"(!(${eval.isNull}))", ExprType(ctx, dataType)) + StatementValue(s"(!(${eval.isNull}))", ctx.javaType(dataType)) } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d5299dc407221..db3490da09c18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -63,7 +63,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull, ExprType(ctx.JAVA_BOOLEAN)) + GlobalValue(resultIsNull, ctx.JAVA_BOOLEAN) } else { FalseLiteral } @@ -445,11 +445,11 @@ case class LambdaVariable( override def genCode(ctx: CodegenContext): ExprCode = { val isNullValue = if (nullable) { - VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)) + VariableValue(isNull, ctx.JAVA_BOOLEAN) } else { FalseLiteral } - ExprCode(code = "", value = VariableValue(value, ExprType(ctx, dataType)), isNull = isNullValue) + ExprCode(code = "", value = VariableValue(value, ctx.javaType(dataType)), isNull = isNullValue) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index 634d16c0cf811..9640e8d3ba7a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -30,15 +30,18 @@ class ExprValueSuite extends SparkFunSuite { val ctx = new CodegenContext() + assert(trueLit.isPrimitive(ctx)) + assert(falseLit.isPrimitive(ctx)) + trueLit match { case LiteralValue(value, javaType) => - assert(value == "true" && javaType.typeName == "boolean" && javaType.isPrimitive(ctx)) + assert(value == "true" && javaType == "boolean") case _ => fail() } falseLit match { case LiteralValue(value, javaType) => - assert(value == "false" && javaType.typeName == "boolean" && javaType.isPrimitive(ctx)) + assert(value == "false" && javaType == "boolean") case _ => fail() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 6a67fd2a826b5..ba84870e00b0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -52,7 +52,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = ctx.javaType(dataType) val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull"), ExprType(ctx.JAVA_BOOLEAN)) + VariableValue(ctx.freshName("isNull"), ctx.JAVA_BOOLEAN) } else { FalseLiteral } @@ -66,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, VariableValue(valueVar, ExprType(ctx, dataType))) + ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index b55b9ea80d9aa..8415e88220c7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -156,8 +156,8 @@ case class ExpandExec( |boolean $isNull = true; |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - VariableValue(value, ExprType(ctx, firstExpr.dataType))) + ExprCode(code, VariableValue(isNull, ctx.JAVA_BOOLEAN), + VariableValue(value, ctx.javaType(firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index d77811985107b..bd0313bc5bdfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -170,10 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1", ExprType(ctx.JAVA_BOOLEAN)), - VariableValue(index, ExprType(ctx.JAVA_INT)))) + Seq(ExprCode("", StatementValue(s"$index == -1", ctx.JAVA_BOOLEAN), + VariableValue(index, ctx.JAVA_INT))) } else { - Seq(ExprCode("", FalseLiteral, VariableValue(index, ExprType(ctx.JAVA_INT)))) + Seq(ExprCode("", FalseLiteral, VariableValue(index, ctx.JAVA_INT))) } } else { Seq.empty @@ -316,11 +316,11 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - VariableValue(value, ExprType(ctx, dt))) + ExprCode(code, VariableValue(isNull, ctx.JAVA_BOOLEAN), + VariableValue(value, javaType)) } else { ExprCode(s"$javaType $value = $getter;", FalseLiteral, - VariableValue(value, ExprType(ctx, dt))) + VariableValue(value, javaType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index b3582f4075efc..94e27dc6d08a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", FalseLiteral, VariableValue(row, ExprType("UnsafeRow"))) + ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -129,7 +129,7 @@ trait CodegenSupport extends SparkPlan { ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", FalseLiteral, VariableValue("unsafeRow", ExprType("UnsafeRow"))) + ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) } } } @@ -245,11 +245,11 @@ trait CodegenSupport extends SparkPlan { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)) + VariableValue(isNull, ctx.JAVA_BOOLEAN) } paramVars += ExprCode("", paramIsNull, - VariableValue(paramName, ExprType(ctx, attributes(i).dataType))) + VariableValue(paramName, ctx.javaType(attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 276e0b4e04e1c..42fb03da4d2b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,8 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - GlobalValue(value, ExprType(ctx, e.dataType))) + ExprCode(ev.code + initVars, GlobalValue(isNull, ctx.JAVA_BOOLEAN), + GlobalValue(value, ctx.javaType(e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index fd1eb98c506f8..3b9e01141c366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, GlobalValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,8 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - GlobalValue(value, ExprType(ctx, e.dataType))) + ExprCode(ev.code + initVars, GlobalValue(isNull, ctx.JAVA_BOOLEAN), + GlobalValue(value, ctx.javaType(e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 097a00674ddf9..bf01f64be1592 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", FalseLiteral, VariableValue(value, ExprType(ctx.JAVA_LONG))) + val ev = ExprCode("", FalseLiteral, VariableValue(value, ctx.JAVA_LONG)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 22bc5b820f719..84b2a4243dabe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -191,8 +191,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - VariableValue(value, ExprType(ctx, a.dataType))) + ExprCode(code, VariableValue(isNull, ctx.JAVA_BOOLEAN), + VariableValue(value, ctx.javaType(a.dataType))) } } } @@ -488,7 +488,7 @@ case class BroadcastHashJoinExec( } val resultVar = input ++ Seq(ExprCode("", FalseLiteral, - VariableValue(existsVar, ExprType(ctx.JAVA_BOOLEAN)))) + VariableValue(existsVar, ctx.JAVA_BOOLEAN))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 3780d1d64207e..57f6f5df4a71e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExprType, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -531,13 +531,13 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN)), - VariableValue(value, ExprType(ctx, a.dataType))), leftVarsDecl) + (ExprCode(code, VariableValue(isNull, ctx.JAVA_BOOLEAN), + VariableValue(value, ctx.javaType(a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, - VariableValue(value, ExprType(ctx, a.dataType))), leftVarsDecl) + VariableValue(value, ctx.javaType(a.dataType))), leftVarsDecl) } }.unzip }