diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 89ffbb0016916..5021a567592e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types._
 
 /**
@@ -76,7 +76,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
            |  ${CodeGenerator.defaultValue(dataType)} : ($value);
          """.stripMargin)
     } else {
-      ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
+      ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 38caf67d465d8..7a5e49cb5206b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] {
     }.getOrElse {
       val isNull = ctx.freshName("isNull")
       val value = ctx.freshName("value")
-      val eval = doGenCode(ctx, ExprCode("", isNull, value))
+      val eval = doGenCode(ctx, ExprCode("",
+        VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+        VariableValue(value, CodeGenerator.javaType(dataType))))
       reduceCodeSize(ctx, eval)
       if (eval.code.nonEmpty) {
         // Add `this` in the comment.
@@ -118,10 +120,10 @@ abstract class Expression extends TreeNode[Expression] {
   private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
     // TODO: support whole stage codegen too
     if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
-      val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
+      val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
         val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
         val localIsNull = eval.isNull
-        eval.isNull = globalIsNull
+        eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN)
         s"$globalIsNull = $localIsNull;"
       } else {
         ""
@@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] {
         |}
       """.stripMargin)
 
-    eval.value = newValue
+    eval.value = VariableValue(newValue, javaType)
     eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
   }
 }
@@ -446,7 +448,7 @@ abstract class UnaryExpression extends Expression {
         boolean ${ev.isNull} = false;
         ${childGen.code}
         ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        $resultCode""", isNull = "false")
+        $resultCode""", isNull = FalseLiteral)
     }
   }
 }
@@ -546,7 +548,7 @@ abstract class BinaryExpression extends Expression {
         ${leftGen.code}
         ${rightGen.code}
         ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        $resultCode""", isNull = "false")
+        $resultCode""", isNull = FalseLiteral)
     }
   }
 }
@@ -690,7 +692,7 @@ abstract class TernaryExpression extends Expression {
         ${midGen.code}
         ${rightGen.code}
         ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        $resultCode""", isNull = "false")
+        $resultCode""", isNull = FalseLiteral)
     }
   }
 }
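For context on the `reduceCodeSize` hunk above: the old code decided whether `isNull` had to be hoisted into a global field by comparing the raw string against "false"/"true", while the typed wrappers turn that decision into a type test. A minimal standalone sketch of the two styles, using simplified stand-ins rather than the real Spark classes:

    sealed trait ExprValue { def term: String }
    final case class LiteralValue(term: String) extends ExprValue
    final case class VariableValue(term: String) extends ExprValue

    // New style: any literal isNull needs no global flag, regardless of spelling.
    def needsGlobalIsNull(isNull: ExprValue): Boolean = isNull match {
      case _: LiteralValue => false
      case _ => true
    }

    // Old style, kept for comparison: only the two exact spellings are
    // recognized, so any other constant expression is hoisted unnecessarily.
    def needsGlobalIsNullOld(isNull: String): Boolean =
      isNull != "false" && isNull != "true"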
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index dd523d312e3b4..ad1e7bdb31987 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types.{DataType, LongType}
 
 /**
@@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
     ev.copy(code = s"""
       final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
-      $countTerm++;""", isNull = "false")
+      $countTerm++;""", isNull = FalseLiteral)
   }
 
   override def prettyName: String = "monotonically_increasing_id"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index cc6a769d032d3..787bcaf5e81de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /**
@@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
     ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
     ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
-      isNull = "false")
+      isNull = FalseLiteral)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 508bdd5050b54..478ff3a7c1011 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -601,7 +601,8 @@ case class Least(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
+    ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
+      CodeGenerator.JAVA_BOOLEAN)
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
@@ -680,7 +681,8 @@ case class Greatest(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
+    ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
+      CodeGenerator.JAVA_BOOLEAN)
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 84b1e3fbda876..c9c60ef1be640 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -56,16 +56,17 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
  * @param value A term for a (possibly primitive) value of the result of the evaluation. Not
  *              valid if `isNull` is set to `true`.
  */
-case class ExprCode(var code: String, var isNull: String, var value: String)
+case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
 
 object ExprCode {
   def forNullValue(dataType: DataType): ExprCode = {
     val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
-    ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
+    ExprCode(code = "", isNull = TrueLiteral,
+      value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType)))
   }
 
-  def forNonNullValue(value: String): ExprCode = {
-    ExprCode(code = "", isNull = "false", value = value)
+  def forNonNullValue(value: ExprValue): ExprCode = {
+    ExprCode(code = "", isNull = FalseLiteral, value = value)
   }
 }
 
@@ -77,7 +78,7 @@ object ExprCode {
  * @param value A term for a value of a common sub-expression. Not valid if `isNull`
  *              is set to `true`.
  */
-case class SubExprEliminationState(isNull: String, value: String)
+case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
 
 /**
  * Codes and common subexpressions mapping used for subexpression elimination.
@@ -330,7 +331,7 @@ class CodegenContext {
       case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
       case _ => s"$value = $initCode;"
     }
-    ExprCode(code, "false", value)
+    ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType)))
   }
 
   def declareMutableStates(): String = {
@@ -1003,7 +1004,8 @@ class CodegenContext {
       // at least two nodes) as the cost of doing it is expected to be low.
       subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
 
-      val state = SubExprEliminationState(isNull, value)
+      val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN),
+        GlobalValue(value, javaType(expr.dataType)))
       subExprEliminationExprs ++= e.map(_ -> state).toMap
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index e12420bb5dfdd..a91989e129664 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -59,7 +59,7 @@ trait CodegenFallback extends Expression {
         $placeHolder
         Object $objectTerm = ((Expression) references[$idx]).eval($input);
         $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
-        """, isNull = "false")
+        """, isNull = FalseLiteral)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala
new file mode 100644
index 0000000000000..df5f1c58b1b2d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.language.implicitConversions
+
+import org.apache.spark.sql.types.DataType
+
+// An abstraction that represents the evaluation result of [[ExprCode]].
+abstract class ExprValue {
+
+  val javaType: String
+
+  // Whether we can directly access the evaluation value anywhere.
+  // For example, a variable created outside a method cannot be accessed inside the method.
+  // For such cases, we may need to pass the evaluation as a parameter.
+  val canDirectAccess: Boolean
+
+  def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType)
+}
+
+object ExprValue {
+  implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
+}
+
+// A literal evaluation of [[ExprCode]].
+class LiteralValue(val value: String, val javaType: String) extends ExprValue {
+  override def toString: String = value
+  override val canDirectAccess: Boolean = true
+}
+
+object LiteralValue {
+  def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType)
+  def unapply(literal: LiteralValue): Option[(String, String)] =
+    Some((literal.value, literal.javaType))
+}
+
+// A variable evaluation of [[ExprCode]].
+case class VariableValue(
+    val variableName: String,
+    val javaType: String) extends ExprValue {
+  override def toString: String = variableName
+  override val canDirectAccess: Boolean = false
+}
+
+// A statement evaluation of [[ExprCode]].
+case class StatementValue(
+    val statement: String,
+    val javaType: String,
+    val canDirectAccess: Boolean = false) extends ExprValue {
+  override def toString: String = statement
+}
+
+// A global variable evaluation of [[ExprCode]].
+case class GlobalValue(val value: String, val javaType: String) extends ExprValue {
+  override def toString: String = value
+  override val canDirectAccess: Boolean = true
+}
+
+case object TrueLiteral extends LiteralValue("true", "boolean")
+case object FalseLiteral extends LiteralValue("false", "boolean")
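The implicit `exprValueToString` above is what keeps this migration incremental: every existing template that interpolates an `isNull`/`value` term keeps compiling, while the Java type and accessibility now travel with the term. A small usage sketch against the classes in this new file (the variable names are made up for illustration):

    import org.apache.spark.sql.catalyst.expressions.codegen._

    val isNull = VariableValue("isNull_0", "boolean")
    val value = GlobalValue("mutableStateValue_0", "int")

    // ExprValue still drops into string templates via the implicit conversion.
    val assign: String = s"if (!$isNull) { int tmp = $value; }"

    // ...and callers can now reason about the term instead of parsing strings.
    assert(value.isPrimitive)      // "int" is a primitive Java type
    assert(value.canDirectAccess)  // global state is visible in split-out methods
    assert(!isNull.canDirectAccess) // a local variable is not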
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d35fd8ecb4d63..3ae0b54c754cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
 
     // 4-tuples: (code for projection, isNull variable name, value variable name, column index)
-    val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
+    val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map {
       case (ev, i) =>
         val e = expressions(i)
         val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
@@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
              |${ev.code}
              |$isNull = ${ev.isNull};
              |$value = ${ev.value};
-           """.stripMargin, isNull, value, i)
+           """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i)
         } else {
           (s"""
              |${ev.code}
@@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
 
     val updates = validExpr.zip(projectionCodes).map {
       case (e, (_, isNull, value, i)) =>
-        val ev = ExprCode("", isNull, value)
+        val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType)))
         CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f92f70ee71fef..a30a0b22cd305 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -53,7 +53,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     val rowClass = classOf[GenericInternalRow].getName
 
     val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
-      val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt)
+      val converter = convertToSafe(ctx,
+        StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
+          CodeGenerator.javaType(dt)), dt)
       s"""
         if (!$tmpInput.isNullAt($i)) {
           ${converter.code}
@@ -74,7 +76,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
       |final InternalRow $output = new $rowClass($values);
     """.stripMargin
 
-    ExprCode(code, "false", output)
+    ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow"))
   }
 
   private def createCodeForArray(
@@ -89,8 +91,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     val index = ctx.freshName("index")
     val arrayClass = classOf[GenericArrayData].getName
 
-    val elementConverter = convertToSafe(
-      ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType)
+    val elementConverter = convertToSafe(ctx,
+      StatementValue(CodeGenerator.getValue(tmpInput, elementType, index),
+        CodeGenerator.javaType(elementType)), elementType)
     val code = s"""
       final ArrayData $tmpInput = $input;
       final int $numElements = $tmpInput.numElements();
@@ -104,7 +107,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
       final ArrayData $output = new $arrayClass($values);
     """
 
-    ExprCode(code, "false", output)
+    ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData"))
   }
 
   private def createCodeForMap(
@@ -125,19 +128,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
       final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
     """
 
-    ExprCode(code, "false", output)
+    ExprCode(code, FalseLiteral, VariableValue(output, "MapData"))
   }
 
   @tailrec
   private def convertToSafe(
       ctx: CodegenContext,
-      input: String,
+      input: ExprValue,
       dataType: DataType): ExprCode = dataType match {
     case s: StructType => createCodeForStruct(ctx, input, s)
     case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
     case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
     case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
-    case _ => ExprCode("", "false", input)
+    case _ => ExprCode("", FalseLiteral, input)
   }
 
   protected def create(expressions: Seq[Expression]): Projection = {
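`StatementValue` is the natural wrapper in the hunks above because `CodeGenerator.getValue(...)` returns a Java expression such as `tmpInput.getInt(0)`, not a variable name. A short illustration with hypothetical terms, assuming only the ExprValue classes added in this patch:

    import org.apache.spark.sql.catalyst.expressions.codegen._

    val fieldIsNull = StatementValue("tmpInput.isNullAt(0)", "boolean")
    val fieldValue = StatementValue("tmpInput.getInt(0)", "int")

    // A statement interpolates into generated Java like any other term...
    val code = s"boolean anyNull = $fieldIsNull; int v = $fieldValue;"

    // ...but it is re-evaluated wherever it appears and cannot be handed to a
    // split-out method by name, so canDirectAccess defaults to false.
    assert(!fieldValue.canDirectAccess)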
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index ab2254cd9f70a..4a4d76313a543 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -52,7 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
     val tmpInput = ctx.freshName("tmpInput")
     val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
-      ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
+      ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN),
+        StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
+          CodeGenerator.javaType(dt)))
     }
 
     val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -334,7 +336,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       $evalSubexpr
       $writeExpressions
     """
-    ExprCode(code, "false", s"$rowWriter.getRow()")
+    // `rowWriter` is declared as a class field, so we can access it directly in methods.
+    ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow",
+      canDirectAccess = true))
   }
 
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index beb84694c44e8..91188da8b0bd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -20,7 +20,7 @@ import java.util.Comparator
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 
@@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
       boolean ${ev.isNull} = false;
       ${childGen.code}
       ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ?
         -1 :
-        (${childGen.value}).numElements();""", isNull = "false")
+        (${childGen.value}).numElements();""", isNull = FalseLiteral)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 85facdad43db7..49a8d12057188 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
       GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
     ev.copy(
       code = preprocess + assigns + postprocess,
-      value = arrayData,
-      isNull = "false")
+      value = VariableValue(arrayData, CodeGenerator.javaType(dataType)),
+      isNull = FalseLiteral)
   }
 
   override def prettyName: String = "array"
@@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
       |$valuesCode
       |final InternalRow ${ev.value} = new $rowClass($values);
       |$values = null;
-    """.stripMargin, isNull = "false")
+    """.stripMargin, isNull = FalseLiteral)
   }
 
   override def prettyName: String = "named_struct"
@@ -394,7 +394,7 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateName
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
-    ExprCode(code = eval.code, isNull = "false", value = eval.value)
+    ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value)
   }
 
   override def prettyName: String = "named_struct_unsafe"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index f4e9619bac59d..409c0b6b79b81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -191,7 +191,8 @@ case class CaseWhen(
     // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
     // We won't go on anymore on the computation.
     val resultState = ctx.freshName("caseWhenResultState")
-    ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value)
+    ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value),
+      CodeGenerator.javaType(dataType))
 
     // these blocks are meant to be inside a
     // do {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 1ae4e5a2f716b..49dd988b4b53c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -813,7 +813,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
     val df = classOf[DateFormat].getName
     if (format.foldable) {
       if (formatter == null) {
-        ExprCode("", "true", "(UTF8String) null")
+        ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null",
+          CodeGenerator.javaType(dataType)))
       } else {
         val formatterName = ctx.addReferenceObj("formatter", formatter, df)
         val t = left.genCode(ctx)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 4f4d49166e88c..3af4bfebad45e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 
@@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
       s"""
          |$code
          |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
-      """.stripMargin, isNull = "false")
+      """.stripMargin, isNull = FalseLiteral)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index b702422ed7a1d..71a7ce805d1ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -269,7 +269,7 @@ abstract class HashExpression[E] extends Expression {
   protected def computeHash(value: Any, dataType: DataType, seed: E): E
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.isNull = "false"
+    ev.isNull = FalseLiteral
 
     val childrenHash = children.map { child =>
       val childGen = child.genCode(ctx)
@@ -633,7 +633,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.isNull = "false"
+    ev.isNull = FalseLiteral
 
     val childHash = ctx.freshName("childHash")
     val childrenHash = children.map { child =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
index 07785e7448586..2a3cc580273ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.rdd.InputFileBlockHolder
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types.{DataType, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
     ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
-      s"$className.getInputFilePath();", isNull = "false")
+      s"$className.getInputFilePath();", isNull = FalseLiteral)
   }
 }
 
@@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
     ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
-      s"$className.getStartOffset();", isNull = "false")
+      s"$className.getStartOffset();", isNull = FalseLiteral)
   }
 }
 
@@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
     ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
-      s"$className.getLength();", isNull = "false")
+      s"$className.getLength();", isNull = FalseLiteral)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 7395609a04ba5..742a650eb445d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -283,36 +283,36 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
     } else {
       dataType match {
         case BooleanType | IntegerType | DateType =>
-          ExprCode.forNonNullValue(value.toString)
+          ExprCode.forNonNullValue(LiteralValue(value.toString, javaType))
         case FloatType =>
           value.asInstanceOf[Float] match {
             case v if v.isNaN =>
-              ExprCode.forNonNullValue("Float.NaN")
+              ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType))
             case Float.PositiveInfinity =>
-              ExprCode.forNonNullValue("Float.POSITIVE_INFINITY")
+              ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType))
             case Float.NegativeInfinity =>
-              ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY")
+              ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType))
            case _ =>
-              ExprCode.forNonNullValue(s"${value}F")
+              ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType))
          }
        case DoubleType =>
          value.asInstanceOf[Double] match {
            case v if v.isNaN =>
-              ExprCode.forNonNullValue("Double.NaN")
ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}D") + ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(s"($javaType)$value") + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) case TimestampType | LongType => - ExprCode.forNonNullValue(s"${value}L") + ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(constRef) + ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a390f8ef7fd9a..7081a5e096d56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -91,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = TrueLiteral, + value = LiteralValue("null", CodeGenerator.javaType(dataType))) } override def sql: String = s"assert_true(${child.sql})" @@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Uuid = Uuid(randomSeed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b35fa72e95d1e..55b6e346be82a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) // all the evals are meant to be in a do { ... 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a390f8ef7fd9a..7081a5e096d56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -91,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
     ExprCode(code = s"""${eval.code}
        |if (${eval.isNull} || !${eval.value}) {
        |  throw new RuntimeException($errMsgField);
-       |}""".stripMargin, isNull = "true", value = "null")
+       |}""".stripMargin, isNull = TrueLiteral,
+      value = LiteralValue("null", CodeGenerator.javaType(dataType)))
   }
 
   override def sql: String = s"assert_true(${child.sql})"
@@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
       "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
         s"${randomSeed.get}L + partitionIndex);")
     ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
-      isNull = "false")
+      isNull = FalseLiteral)
   }
 
   override def freshCopy(): Uuid = Uuid(randomSeed)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index b35fa72e95d1e..55b6e346be82a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
@@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
+    ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
+      CodeGenerator.JAVA_BOOLEAN)
 
     // all the evals are meant to be in a do { ... } while (false); loop
     val evals = children.map { e =>
@@ -235,7 +236,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
         ev.copy(code = s"""
          ${eval.code}
          ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-          ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false")
+          ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
     }
   }
 }
@@ -320,7 +321,12 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    ExprCode(code = eval.code, isNull = "false", value = eval.isNull)
+    val value = if (eval.isNull.isInstanceOf[LiteralValue]) {
+      LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN)
+    } else {
+      VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN)
+    }
+    ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
   }
 
   override def sql: String = s"(${child.sql} IS NULL)"
@@ -346,7 +352,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))")
+    val value = if (eval.isNull == TrueLiteral) {
+      FalseLiteral
+    } else if (eval.isNull == FalseLiteral) {
+      TrueLiteral
+    } else {
+      StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType))
+    }
+    ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
  }
 
   override def sql: String = s"(${child.sql} IS NOT NULL)"
@@ -441,6 +454,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
        |  $codes
        |} while (false);
        |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
-     """.stripMargin, isNull = "false")
+     """.stripMargin, isNull = FalseLiteral)
   }
 }
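The `IsNull`/`IsNotNull` rewrites are the first place the wrappers enable constant folding at codegen time: a child whose `isNull` is already a literal produces a literal result, which downstream consumers (e.g. the `FilterExec` change later in this patch) can exploit to remove dead branches. The branch logic restated standalone, assuming only the ExprValue classes:

    import org.apache.spark.sql.catalyst.expressions.codegen._

    def isNotNullValue(childIsNull: ExprValue): ExprValue = childIsNull match {
      case TrueLiteral => FalseLiteral                 // child is always null
      case FalseLiteral => TrueLiteral                 // child is never null
      case v => StatementValue(s"(!($v))", "boolean")  // known only at runtime
    }

    assert(isNotNullValue(FalseLiteral) == TrueLiteral)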
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0e9d357c19c63..ce2d265982df3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
@@ -60,13 +60,13 @@ trait InvokeLike extends Expression with NonSQLExpression {
    * @param ctx a [[CodegenContext]]
    * @return (code to prepare arguments, argument string, result of argument null check)
    */
-  def prepareArguments(ctx: CodegenContext): (String, String, String) = {
+  def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = {
 
     val resultIsNull = if (needNullCheck) {
       val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull")
-      resultIsNull
+      GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN)
     } else {
-      "false"
+      FalseLiteral
     }
     val argValues = arguments.map { e =>
       val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue")
@@ -202,7 +202,7 @@ case class StaticInvoke(
     val prepareIsNull = if (nullable) {
       s"boolean ${ev.isNull} = $resultIsNull;"
     } else {
-      ev.isNull = "false"
+      ev.isNull = FalseLiteral
       ""
     }
 
@@ -490,7 +490,7 @@ case class WrapOption(child: Expression, optType: DataType)
         ${inputObject.isNull} ?
         scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
     """
-    ev.copy(code = code, isNull = "false")
+    ev.copy(code = code, isNull = FalseLiteral)
   }
 }
 
@@ -512,7 +512,13 @@ case class LambdaVariable(
   }
 
   override def genCode(ctx: CodegenContext): ExprCode = {
-    ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
+    val isNullValue = if (nullable) {
+      VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN)
+    } else {
+      FalseLiteral
+    }
+    ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)),
+      isNull = isNullValue)
   }
 
   // This won't be called as `genCode` is overrided, just overriding it to make
@@ -784,7 +790,7 @@ case class MapObjects private(
     // Make a copy of the data if it's unsafe-backed
     def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
       s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
-    val genFunctionValue = lambdaFunction.dataType match {
+    val genFunctionValue: String = lambdaFunction.dataType match {
       case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
       case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
       case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
@@ -1287,7 +1293,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
       |$childrenCode
       |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
     """.stripMargin
-    ev.copy(code = code, isNull = "false")
+    ev.copy(code = code, isNull = FalseLiteral)
   }
 }
 
@@ -1443,7 +1449,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
         throw new NullPointerException($errMsgField);
       }
     """
-    ev.copy(code = code, isNull = "false", value = childGen.value)
+    ev.copy(code = code, isNull = FalseLiteral, value = childGen.value)
   }
 }
 
@@ -1494,7 +1500,7 @@ case class GetExternalRowField(
       final Object ${ev.value} = ${row.value}.get($index);
     """
 
-    ev.copy(code = code, isNull = "false")
+    ev.copy(code = code, isNull = FalseLiteral)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 4b85d9adbe311..e195ec17f3bcf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -405,7 +405,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
         if (${eval1.value}) {
           ${eval2.code}
           ${ev.value} = ${eval2.value};
-        }""", isNull = "false")
+        }""", isNull = FalseLiteral)
     } else {
       ev.copy(code = s"""
         ${eval1.code}
@@ -461,7 +461,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 
     // The result should be `true`, if any of them is `true` whenever the other is null or not.
     if (!left.nullable && !right.nullable) {
-      ev.isNull = "false"
+      ev.isNull = FalseLiteral
       ev.copy(code = s"""
         ${eval1.code}
         boolean ${ev.value} = true;
@@ -469,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
         if (!${eval1.value}) {
           ${eval2.code}
           ${ev.value} = ${eval2.value};
-        }""", isNull = "false")
+        }""", isNull = FalseLiteral)
     } else {
       ev.copy(code = s"""
         ${eval1.code}
@@ -615,7 +615,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
     val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
     ev.copy(code = eval1.code + eval2.code + s"""
         boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
-           (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false")
+           (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index f36633867316e..70186053617f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -83,7 +83,7 @@ case class Rand(child: Expression) extends RDG {
       s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
-      isNull = "false")
+      isNull = FalseLiteral)
   }
 
   override def freshCopy(): Rand = Rand(child)
@@ -120,7 +120,7 @@ case class Randn(child: Expression) extends RDG {
       s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
-      isNull = "false")
+      isNull = FalseLiteral)
   }
 
   override def freshCopy(): Randn = Randn(child)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 398b6767654fa..8e83b35c3809c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -448,6 +448,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val ref = BoundReference(0, IntegerType, true)
     val add1 = Add(ref, ref)
     val add2 = Add(add1, add1)
+    val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"),
+      VariableValue("dummy", "boolean"))
 
     // raw testing of basic functionality
     {
@@ -457,7 +459,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
       assert(ctx.subExprEliminationExprs.contains(ref))
       // call withSubExprEliminationExprs
-      ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) {
+      ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
         assert(ctx.subExprEliminationExprs.contains(add1))
         assert(!ctx.subExprEliminationExprs.contains(ref))
         Seq.empty
@@ -475,7 +477,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
       assert(ctx.subExprEliminationExprs.contains(add1))
       // call withSubExprEliminationExprs
-      ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) {
+      ctx.withSubExprEliminationExprs(Map(ref -> dummy)) {
         assert(ctx.subExprEliminationExprs.contains(ref))
         assert(!ctx.subExprEliminationExprs.contains(add1))
         Seq.empty
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala
new file mode 100644
index 0000000000000..c8f4cff7db48d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+
+class ExprValueSuite extends SparkFunSuite {
+
+  test("TrueLiteral and FalseLiteral should be LiteralValue") {
+    val trueLit = TrueLiteral
+    val falseLit = FalseLiteral
+
+    assert(trueLit.value == "true")
+    assert(falseLit.value == "false")
+
+    assert(trueLit.isPrimitive)
+    assert(falseLit.isPrimitive)
+
+    trueLit match {
+      case LiteralValue(value, javaType) =>
+        assert(value == "true" && javaType == "boolean")
+      case _ => fail()
+    }
+
+    falseLit match {
+      case LiteralValue(value, javaType) =>
+        assert(value == "false" && javaType == "boolean")
+      case _ => fail()
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 392906a022903..434214a10e1e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -51,7 +51,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
       nullable: Boolean): ExprCode = {
     val javaType = CodeGenerator.javaType(dataType)
     val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
-    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
+    val isNullVar = if (nullable) {
+      VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN)
+    } else {
+      FalseLiteral
+    }
     val valueVar = ctx.freshName("value")
     val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
     val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
@@ -62,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     } else {
       s"$javaType $valueVar = $value;"
     }).trim
-    ExprCode(code, isNullVar, valueVar)
+    ExprCode(code, isNullVar, VariableValue(valueVar, javaType))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 12ae1ea4a7c13..0d9a62cace62a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -157,7 +157,8 @@ case class ExpandExec(
           |${CodeGenerator.javaType(firstExpr.dataType)} $value =
           |  ${CodeGenerator.defaultValue(firstExpr.dataType)};
         """.stripMargin
-        ExprCode(code, isNull, value)
+        ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+          VariableValue(value, CodeGenerator.javaType(firstExpr.dataType)))
       }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 384f0398a1ec0..85c5ebfdaa689 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
@@ -170,9 +170,10 @@ case class GenerateExec(
     // Add position
     val position = if (e.position) {
       if (outer) {
-        Seq(ExprCode("", s"$index == -1", index))
+        Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN),
+          VariableValue(index, CodeGenerator.JAVA_INT)))
       } else {
-        Seq(ExprCode("", "false", index))
+        Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT)))
       }
     } else {
       Seq.empty
@@ -315,9 +316,11 @@ case class GenerateExec(
         |boolean $isNull = ${checks.mkString(" || ")};
         |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
       """.stripMargin
-      ExprCode(code, isNull, value)
+      ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+        VariableValue(value, javaType))
     } else {
-      ExprCode(s"$javaType $value = $getter;", "false", value)
+      ExprCode(s"$javaType $value = $getter;", FalseLiteral,
+        VariableValue(value, javaType))
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6ddaacfee1a40..805ff3cf001ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan {
   private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
     if (row != null) {
-      ExprCode("", "false", row)
+      ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow"))
     } else {
       if (colVars.nonEmpty) {
         val colExprs = output.zipWithIndex.map { case (attr, i) =>
@@ -126,10 +126,10 @@ trait CodegenSupport extends SparkPlan {
           |$evaluateInputs
           |${ev.code.trim}
         """.stripMargin.trim
-        ExprCode(code, "false", ev.value)
+        ExprCode(code, FalseLiteral, ev.value)
       } else {
         // There is no columns
-        ExprCode("", "false", "unsafeRow")
+        ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow"))
       }
     }
   }
@@ -241,15 +241,16 @@ trait CodegenSupport extends SparkPlan {
       parameters += s"$paramType $paramName"
       val paramIsNull = if (!attributes(i).nullable) {
         // Use constant `false` without passing `isNull` for non-nullable variable.
- "false" + FalseLiteral } else { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - isNull + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) } - paramVars += ExprCode("", paramIsNull, paramName) + paramVars += ExprCode("", paramIsNull, + VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 1926e9373bc55..8f7f10243d4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 6b60b414ffe5f..4978954271311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 4707022f74547..cab7081400ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 1926e9373bc55..8f7f10243d4cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -194,7 +194,8 @@ case class HashAggregateExec(
            | $isNull = ${ev.isNull};
            | $value = ${ev.value};
          """.stripMargin
-        ExprCode(ev.code + initVars, isNull, value)
+        ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+          GlobalValue(value, CodeGenerator.javaType(e.dataType)))
       }
       val initBufVar = evaluateVariables(bufVars)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index 6b60b414ffe5f..4978954271311 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue}
 import org.apache.spark.sql.types._
 
 /**
@@ -54,7 +54,8 @@ abstract class HashMapGenerator(
            | $isNull = ${ev.isNull};
            | $value = ${ev.value};
          """.stripMargin
-        ExprCode(ev.code + initVars, isNull, value)
+        ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+          GlobalValue(value, CodeGenerator.javaType(e.dataType)))
       }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 4707022f74547..cab7081400ce9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
@@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
       if (notNullAttributes.contains(child.output(i).exprId)) {
-        ev.isNull = "false"
+        ev.isNull = FalseLiteral
       }
       ev
     }
@@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
 
     val value = ctx.freshName("value")
-    val ev = ExprCode("", "false", value)
+    val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG))
     val BigInt = classOf[java.math.BigInteger].getName
 
     // Inline mutable state since not many Range operations in a task
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 487d6a2383318..fa62a32d51f3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -192,7 +192,8 @@ case class BroadcastHashJoinExec(
            |  $value = ${ev.value};
            |}
          """.stripMargin
-        ExprCode(code, isNull, value)
+        ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+          VariableValue(value, CodeGenerator.javaType(a.dataType)))
       }
     }
   }
@@ -487,7 +488,8 @@ case class BroadcastHashJoinExec(
       s"$existsVar = true;"
     }
 
-    val resultVar = input ++ Seq(ExprCode("", "false", existsVar))
+    val resultVar = input ++ Seq(ExprCode("", FalseLiteral,
+      VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN)))
 
     if (broadcastRelation.value.keyIsUnique) {
       s"""
        |// generate join key for stream side
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 5a511b30e4fd9..b61acb8d4fda9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
@@ -531,11 +531,13 @@ case class SortMergeJoinExec(
           |boolean $isNull = false;
           |$javaType $value = $defaultValue;
         """.stripMargin
-        (ExprCode(code, isNull, value), leftVarsDecl)
+        (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
+          VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl)
       } else {
         val code = s"$value = $valueCode;"
         val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
-        (ExprCode(code, "false", value), leftVarsDecl)
+        (ExprCode(code, FalseLiteral,
+          VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl)
       }
     }.unzip