Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode("",
VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
VariableValue(value, CodeGenerator.javaType(dataType))))
val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
Expand All @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN)
eval.isNull = JavaCode.isNullGlobal(globalIsNull)
s"$globalIsNull = $localIsNull;"
} else {
""
Expand All @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] {
|}
""".stripMargin)

eval.value = VariableValue(newValue, javaType)
eval.value = JavaCode.variable(newValue, dataType)
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
CodeGenerator.JAVA_BOOLEAN)
ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
Expand Down Expand Up @@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
CodeGenerator.JAVA_BOOLEAN)
ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)

object ExprCode {
/**
 * Creates an [[ExprCode]] from just the null-flag and value handles, leaving the
 * generated `code` as the empty string.
 *
 * @param isNull handle for the expression's null flag
 * @param value handle for the expression's value
 * @return an [[ExprCode]] with empty code and the given handles
 */
def apply(isNull: ExprValue, value: ExprValue): ExprCode =
  new ExprCode("", isNull, value)

def forNullValue(dataType: DataType): ExprCode = {
val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
ExprCode(code = "", isNull = TrueLiteral,
value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType)))
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
}

def forNonNullValue(value: ExprValue): ExprCode = {
Expand Down Expand Up @@ -331,7 +333,7 @@ class CodegenContext {
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
case _ => s"$value = $initCode;"
}
ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType)))
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
}

def declareMutableStates(): String = {
Expand Down Expand Up @@ -1004,8 +1006,9 @@ class CodegenContext {
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN),
GlobalValue(value, javaType(expr.dataType)))
val state = SubExprEliminationState(
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType))
subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}
Expand Down Expand Up @@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging {
case _ => "Object"
}

/**
 * Returns the JVM class associated with the given Catalyst data type.
 *
 * Boolean, integral and floating-point types (plus date/timestamp) map to the
 * corresponding primitive class; a user-defined type is resolved through its
 * `sqlType`; an [[ObjectType]] yields the class it wraps; anything not listed
 * falls back to `Object`.
 *
 * @param dt the data type to look up
 * @return the class chosen for `dt`
 */
def javaClass(dt: DataType): Class[_] = dt match {
  // Primitive classes: `classOf[Boolean]` is the same object as `java.lang.Boolean.TYPE`, etc.
  case BooleanType => classOf[Boolean]
  case ByteType => classOf[Byte]
  case ShortType => classOf[Short]
  case IntegerType | DateType => classOf[Int]
  case LongType | TimestampType => classOf[Long]
  case FloatType => classOf[Float]
  case DoubleType => classOf[Double]
  // Reference-typed values.
  case _: DecimalType => classOf[Decimal]
  case BinaryType => classOf[Array[Byte]]
  case StringType => classOf[UTF8String]
  case CalendarIntervalType => classOf[CalendarInterval]
  case _: StructType => classOf[InternalRow]
  case _: ArrayType => classOf[ArrayData]
  case _: MapType => classOf[MapData]
  // A UDT is looked up via its underlying SQL type.
  case userType: UserDefinedType[_] => javaClass(userType.sqlType)
  case ObjectType(wrappedClass) => wrappedClass
  case _ => classOf[AnyRef]
}

/**
* Returns the boxed type in Java.
*/
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
expressions: Seq[Expression],
useSubexprElimination: Boolean): MutableProjection = {
val ctx = newCodeGenContext()
val (validExpr, index) = expressions.zipWithIndex.filter {
val validExpr = expressions.zipWithIndex.filter {
case (NoOp, _) => false
case _ => true
}.unzip
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
}
val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)

// Pairs: (code for projection, code for updating the column in the mutable row)
val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
if (e.nullable) {
val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
case ((e, i), ev) =>
val value = JavaCode.global(
ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"),
e.dataType)
val (code, isNull) = if (e.nullable) {
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
(s"""
|${ev.code}
|$isNull = ${ev.isNull};
|$value = ${ev.value};
""".stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i)
""".stripMargin, JavaCode.isNullGlobal(isNull))
} else {
(s"""
|${ev.code}
|$value = ${ev.value};
""".stripMargin, ev.isNull, value, i)
""".stripMargin, FalseLiteral)
}
val update = CodeGenerator.updateColumn(
"mutableRow",
e.dataType,
i,
ExprCode(isNull, value),
e.nullable)
(code, update)
}

// Evaluate all the subexpressions.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")

val updates = validExpr.zip(projectionCodes).map {
case (e, (_, isNull, value, i)) =>
val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType)))
CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}

val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2))

val codeBody = s"""
public java.lang.Object generate(Object[] references) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx,
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
CodeGenerator.javaType(dt)), dt)
val converter = convertToSafe(
ctx,
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
Expand All @@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|final InternalRow $output = new $rowClass($values);
""".stripMargin

ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow"))
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow]))
}

private def createCodeForArray(
Expand All @@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val index = ctx.freshName("index")
val arrayClass = classOf[GenericArrayData].getName

val elementConverter = convertToSafe(ctx,
StatementValue(CodeGenerator.getValue(tmpInput, elementType, index),
CodeGenerator.javaType(elementType)), elementType)
val elementConverter = convertToSafe(
ctx,
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
elementType)
val code = s"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
Expand All @@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final ArrayData $output = new $arrayClass($values);
"""

ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData"))
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData]))
}

private def createCodeForMap(
Expand All @@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
"""

ExprCode(code, FalseLiteral, VariableValue(output, "MapData"))
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData]))
}

@tailrec
Expand All @@ -140,7 +143,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", FalseLiteral, input)
case _ => ExprCode(FalseLiteral, input)
}

protected def create(expressions: Seq[Expression]): Projection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN),
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
CodeGenerator.javaType(dt)))
ExprCode(
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
}

val rowWriterClass = classOf[UnsafeRowWriter].getName
Expand Down Expand Up @@ -109,7 +109,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}

val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
if (input.isNull == "false") {
if (input.isNull == FalseLiteral) {
s"""
|${input.code}
|${writeField.trim}
Expand Down Expand Up @@ -292,8 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|$writeExpressions
""".stripMargin
// `rowWriter` is declared as a class field, so we can access it directly in methods.
ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow",
canDirectAccess = true))
ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
}

protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
Expand Down
Loading