Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -53,7 +53,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
ev.copy(code = oev.code)
} else {
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
val javaType = CodeGenerator.javaType(dataType)
val javaType = inline"${CodeGenerator.javaType(dataType)}"
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ abstract class Expression extends TreeNode[Expression] {
""
}

val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")
val javaType = inline"${CodeGenerator.javaType(dataType)}"
val newValue = JavaCode.variable(ctx.freshName("value"), dataType)

val funcName = ctx.freshName(nodeName)
val funcFullName = ctx.addNewFunction(funcName,
Expand All @@ -143,8 +143,8 @@ abstract class Expression extends TreeNode[Expression] {
|}
""".stripMargin)

eval.value = JavaCode.variable(newValue, dataType)
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
eval.value = newValue
eval.code = code"$javaType $newValue = ${inline"$funcFullName"}(${ctx.INPUT_ROW});"
}
}

Expand Down Expand Up @@ -416,9 +416,9 @@ abstract class UnaryExpression extends Expression {
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: String => String): ExprCode = {
f: ExprValue => Block): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
s"${ev.value} = ${f(eval)};"
code"${ev.value} = ${f(eval)};"
})
}

Expand All @@ -432,22 +432,23 @@ abstract class UnaryExpression extends Expression {
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: String => String): ExprCode = {
f: ExprValue => Block): ExprCode = {
val childGen = child.genCode(ctx)
val resultCode = f(childGen.value)
val javaType = inline"${CodeGenerator.javaType(dataType)}"

if (nullable) {
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = code"""
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
Expand Down Expand Up @@ -504,9 +505,9 @@ abstract class BinaryExpression extends Expression {
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): ExprCode = {
f: (ExprValue, ExprValue) => Block): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"${ev.value} = ${f(eval1, eval2)};"
code"${ev.value} = ${f(eval1, eval2)};"
})
}

Expand All @@ -521,16 +522,17 @@ abstract class BinaryExpression extends Expression {
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): ExprCode = {
f: (ExprValue, ExprValue) => Block): ExprCode = {
val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val resultCode = f(leftGen.value, rightGen.value)
val javaType = inline"${CodeGenerator.javaType(dataType)}"

if (nullable) {
val nullSafeEval =
leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
s"""
code"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
Expand All @@ -539,14 +541,14 @@ abstract class BinaryExpression extends Expression {

ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
Expand All @@ -568,9 +570,9 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
*/
def inputType: AbstractDataType

def symbol: String
def symbol: JavaCode

def sqlOperator: String = symbol
def sqlOperator: String = symbol.code

override def toString: String = s"($left $symbol $right)"

Expand Down Expand Up @@ -644,9 +646,9 @@ abstract class TernaryExpression extends Expression {
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
f: (ExprValue, ExprValue, ExprValue) => Block): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => {
s"${ev.value} = ${f(eval1, eval2, eval3)};"
code"${ev.value} = ${f(eval1, eval2, eval3)};"
})
}

Expand All @@ -661,18 +663,19 @@ abstract class TernaryExpression extends Expression {
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
f: (ExprValue, ExprValue, ExprValue) => Block): ExprCode = {
val leftGen = children(0).genCode(ctx)
val midGen = children(1).genCode(ctx)
val rightGen = children(2).genCode(ctx)
val resultCode = f(leftGen.value, midGen.value, rightGen.value)
val javaType = inline"${CodeGenerator.javaType(dataType)}"

if (nullable) {
val nullSafeEval =
leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
s"""
code"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
Expand All @@ -682,14 +685,14 @@ abstract class TernaryExpression extends Expression {

ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}

Expand Down Expand Up @@ -67,14 +67,16 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
val countTerm = JavaCode.variable(
ctx.addMutableState(CodeGenerator.JAVA_LONG, "count"), LongType)
val partitionMaskTerm = JavaCode.variable("partitionMask", LongType)
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm.code)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
val javaType = inline"${CodeGenerator.javaType(dataType)}"

ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
final $javaType ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = FalseLiteral)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,12 @@ case class ScalaUDF(
val converters: Array[Any => Any] = children.map { c =>
CatalystTypeConverters.createToScalaConverter(c.dataType)
}.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType)
val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]")
val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage)
val resultTerm = ctx.freshName("result")
val convertersTerm = JavaCode.global(
ctx.addReferenceObj("converters", converters, s"$converterClassName[]"),
classOf[Array[Object]])
val errorMsgTerm = JavaCode.global(ctx.addReferenceObj("errMsg", udfErrorMessage),
classOf[String])
val resultTerm = JavaCode.variable(ctx.freshName("result"), dataType)

// codegen for children expressions
val evals = children.map(_.genCode(ctx))
Expand All @@ -1008,20 +1011,28 @@ case class ScalaUDF(
// We need to get the boxedType of dataType's javaType here. Because for the dataType
// such as IntegerType, its javaType is `int` and the returned type of user-defined
// function is Object. Trying to convert an Object to `int` will cause casting exception.
val evalCode = evals.map(_.code).mkString("\n")
val evalCode = Blocks(evals.map(_.code))
val (funcArgs, initArgs) = evals.zipWithIndex.map { case (eval, i) =>
val argTerm = ctx.freshName("arg")
val convert = s"$convertersTerm[$i].apply(${eval.value})"
val initArg = s"Object $argTerm = ${eval.isNull} ? null : $convert;"
val argTerm = JavaCode.variable(ctx.freshName("arg"), classOf[Object])
val convert = code"$convertersTerm[$i].apply(${eval.value})"
val initArg = code"Object $argTerm = ${eval.isNull} ? null : $convert;"
(argTerm, initArg)
}.unzip

val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
val resultConverter = s"$convertersTerm[${children.length}]"
val boxedType = CodeGenerator.boxedType(dataType)
val udf = JavaCode.global(
ctx.addReferenceObj("udf", function, s"scala.Function${children.length}"), classOf[Object])
val funcArgBlock = funcArgs.foldLeft[Block](EmptyBlock) { (block, arg) =>
if (block.length == 0) {
code"$arg"
} else {
code"$block, $arg"
}
}
val getFuncResult = code"$udf.apply($funcArgBlock)"
val resultConverter = code"$convertersTerm[${children.length}]"
val boxedType = inline"${CodeGenerator.boxedType(dataType)}"
val callFunc =
s"""
code"""
|$boxedType $resultTerm = null;
|try {
| $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
Expand All @@ -1030,14 +1041,15 @@ case class ScalaUDF(
|}
""".stripMargin

val javaType = inline"${CodeGenerator.javaType(dataType)}"
ev.copy(code =
code"""
|$evalCode
|${initArgs.mkString("\n")}
|${Blocks(initArgs)}
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -188,32 +188,32 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childCode = child.child.genCode(ctx)
val input = childCode.value
val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
val StringPrefixCmp = classOf[StringPrefixComparator].getName
val BinaryPrefixCmp = inline"${classOf[BinaryPrefixComparator].getName}"
val DoublePrefixCmp = inline"${classOf[DoublePrefixComparator].getName}"
val StringPrefixCmp = inline"${classOf[StringPrefixComparator].getName}"
val prefixCode = child.child.dataType match {
case BooleanType =>
s"$input ? 1L : 0L"
code"$input ? 1L : 0L"
case _: IntegralType =>
s"(long) $input"
code"(long) $input"
case DateType | TimestampType =>
s"(long) $input"
code"(long) $input"
case FloatType | DoubleType =>
s"$DoublePrefixCmp.computePrefix((double)$input)"
case StringType => s"$StringPrefixCmp.computePrefix($input)"
case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)"
code"$DoublePrefixCmp.computePrefix((double)$input)"
case StringType => code"$StringPrefixCmp.computePrefix($input)"
case BinaryType => code"$BinaryPrefixCmp.computePrefix($input)"
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
s"$input.toUnscaledLong()"
code"$input.toUnscaledLong()"
} else {
// reduce the scale to fit in a long
val p = Decimal.MAX_LONG_DIGITS
val s = p - (dt.precision - dt.scale)
s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
code"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
}
case dt: DecimalType =>
s"$DoublePrefixCmp.computePrefix($input.toDouble())"
case _ => "0L"
code"$DoublePrefixCmp.computePrefix($input.toDouble())"
case _ => code"0L"
}

ev.copy(code = childCode.code +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}

Expand All @@ -44,10 +44,11 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override protected def evalInternal(input: InternalRow): Int = partitionId

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
val idTerm = JavaCode.variable("partitionId", dataType)
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm.code)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
val javaType = inline"${CodeGenerator.javaType(dataType)}"
ev.copy(code = code"final $javaType ${ev.value} = $idTerm;",
isNull = FalseLiteral)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
Expand Down Expand Up @@ -164,9 +164,10 @@ case class PreciseTimestampConversion(
override def dataType: DataType = toType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val javaType = inline"${CodeGenerator.javaType(dataType)}"
ev.copy(code = eval.code +
code"""boolean ${ev.isNull} = ${eval.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
|$javaType ${ev.value} = ${eval.value};
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input
Expand Down
Loading