Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -76,7 +76,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode("", isNull, value))
val eval = doGenCode(ctx, ExprCode("",
VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
VariableValue(value, CodeGenerator.javaType(dataType))))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
Expand All @@ -118,10 +120,10 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN)
s"$globalIsNull = $localIsNull;"
} else {
""
Expand All @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] {
|}
""".stripMargin)

eval.value = newValue
eval.value = VariableValue(newValue, javaType)
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
Expand Down Expand Up @@ -446,7 +448,7 @@ abstract class UnaryExpression extends Expression {
boolean ${ev.isNull} = false;
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
$resultCode""", isNull = FalseLiteral)
}
}
}
Expand Down Expand Up @@ -546,7 +548,7 @@ abstract class BinaryExpression extends Expression {
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
$resultCode""", isNull = FalseLiteral)
}
}
}
Expand Down Expand Up @@ -690,7 +692,7 @@ abstract class TernaryExpression extends Expression {
${midGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
$resultCode""", isNull = FalseLiteral)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.{DataType, LongType}

/**
Expand Down Expand Up @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {

ev.copy(code = s"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = "false")
$countTerm++;""", isNull = FalseLiteral)
}

override def prettyName: String = "monotonically_increasing_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
Expand Down Expand Up @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = "false")
isNull = FalseLiteral)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,8 @@ case class Least(children: Seq[Expression]) extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
CodeGenerator.JAVA_BOOLEAN)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
Expand Down Expand Up @@ -680,7 +681,8 @@ case class Greatest(children: Seq[Expression]) extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
CodeGenerator.JAVA_BOOLEAN)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,17 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class ExprCode(var code: String, var isNull: String, var value: String)
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)

object ExprCode {
def forNullValue(dataType: DataType): ExprCode = {
val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
ExprCode(code = "", isNull = TrueLiteral,
value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType)))
}

def forNonNullValue(value: String): ExprCode = {
ExprCode(code = "", isNull = "false", value = value)
def forNonNullValue(value: ExprValue): ExprCode = {
ExprCode(code = "", isNull = FalseLiteral, value = value)
}
}

Expand All @@ -77,7 +78,7 @@ object ExprCode {
* @param value A term for a value of a common sub-expression. Not valid if `isNull`
* is set to `true`.
*/
case class SubExprEliminationState(isNull: String, value: String)
case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)

/**
* Codes and common subexpressions mapping used for subexpression elimination.
Expand Down Expand Up @@ -330,7 +331,7 @@ class CodegenContext {
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
case _ => s"$value = $initCode;"
}
ExprCode(code, "false", value)
ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType)))
}

def declareMutableStates(): String = {
Expand Down Expand Up @@ -1003,7 +1004,8 @@ class CodegenContext {
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN),
GlobalValue(value, javaType(expr.dataType)))
subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ trait CodegenFallback extends Expression {
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
""", isNull = "false")
""", isNull = FalseLiteral)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen

import scala.language.implicitConversions

import org.apache.spark.sql.types.DataType

// An abstraction that represents the evaluation result of [[ExprCode]].
abstract class ExprValue {

val javaType: String

// Whether we can directly access the evaluation value anywhere.
// For example, a variable created outside a method can not be accessed inside the method.
// For such cases, we may need to pass the evaluation as parameter.
val canDirectAccess: Boolean

def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType)
}

object ExprValue {
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
}

// A literal evaluation of [[ExprCode]].
class LiteralValue(val value: String, val javaType: String) extends ExprValue {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not a case class?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently have case objects for TrueLiteral and FalseLiteral which extends LiteralValue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

override def toString: String = value
override val canDirectAccess: Boolean = true
}

object LiteralValue {
def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType)
def unapply(literal: LiteralValue): Option[(String, String)] =
Some((literal.value, literal.javaType))
}

// A variable evaluation of [[ExprCode]].
case class VariableValue(
val variableName: String,
val javaType: String) extends ExprValue {
override def toString: String = variableName
override val canDirectAccess: Boolean = false
}

// A statement evaluation of [[ExprCode]].
case class StatementValue(
val statement: String,
val javaType: String,
val canDirectAccess: Boolean = false) extends ExprValue {
override def toString: String = statement
}

// A global variable evaluation of [[ExprCode]].
case class GlobalValue(val value: String, val javaType: String) extends ExprValue {
override def toString: String = value
override val canDirectAccess: Boolean = true
}

case object TrueLiteral extends LiteralValue("true", "boolean")
case object FalseLiteral extends LiteralValue("false", "boolean")
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)

// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
Expand All @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
|${ev.code}
|$isNull = ${ev.isNull};
|$value = ${ev.value};
""".stripMargin, isNull, value, i)
""".stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i)
} else {
(s"""
|${ev.code}
Expand All @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP

val updates = validExpr.zip(projectionCodes).map {
case (e, (_, isNull, value, i)) =>
val ev = ExprCode("", isNull, value)
val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType)))
CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt)
val converter = convertToSafe(ctx,
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
CodeGenerator.javaType(dt)), dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
Expand All @@ -74,7 +76,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|final InternalRow $output = new $rowClass($values);
""".stripMargin

ExprCode(code, "false", output)
ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow"))
}

private def createCodeForArray(
Expand All @@ -89,8 +91,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val index = ctx.freshName("index")
val arrayClass = classOf[GenericArrayData].getName

val elementConverter = convertToSafe(
ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType)
val elementConverter = convertToSafe(ctx,
StatementValue(CodeGenerator.getValue(tmpInput, elementType, index),
CodeGenerator.javaType(elementType)), elementType)
val code = s"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
Expand All @@ -104,7 +107,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final ArrayData $output = new $arrayClass($values);
"""

ExprCode(code, "false", output)
ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData"))
}

private def createCodeForMap(
Expand All @@ -125,19 +128,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
"""

ExprCode(code, "false", output)
ExprCode(code, FalseLiteral, VariableValue(output, "MapData"))
}

@tailrec
private def convertToSafe(
ctx: CodegenContext,
input: String,
input: ExprValue,
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", "false", input)
case _ => ExprCode("", FalseLiteral, input)
}

protected def create(expressions: Seq[Expression]): Projection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN),
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
CodeGenerator.javaType(dt)))
}

val rowWriterClass = classOf[UnsafeRowWriter].getName
Expand Down Expand Up @@ -334,7 +336,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$evalSubexpr
$writeExpressions
"""
ExprCode(code, "false", s"$rowWriter.getRow()")
// `rowWriter` is declared as a class field, so we can access it directly in methods.
ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow",
canDirectAccess = true))
}

protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.util.Comparator

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
boolean ${ev.isNull} = false;
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
(${childGen.value}).numElements();""", isNull = "false")
(${childGen.value}).numElements();""", isNull = FalseLiteral)
}
}

Expand Down
Loading