Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
"""
if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
"""
} else {
ev.isNull = "false"
s"""
$javaType ${ev.value} = $value;
"""
}
}
}

Expand All @@ -92,7 +99,7 @@ object BindReferences extends Logging {
sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
}
} else {
BoundReference(ordinal, a.dataType, a.nullable)
BoundReference(ordinal, a.dataType, input(ordinal).nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,22 @@ object Cast {
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to

private def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (DoubleType, TimestampType) => true
case (FloatType, TimestampType) => true
case (StringType, DateType) => true
case (_: NumericType, DateType) => true
case (BooleanType, DateType) => true
case (DateType, _: NumericType) => true
case (DateType, BooleanType) => true
case (DoubleType, _: DecimalType) => true
case (FloatType, _: DecimalType) => true
case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
case (NullType, _) => true
case (_, _) if from == to => false

case (StringType, BinaryType) => false
case (StringType, _) => true
case (_, StringType) => false

case (FloatType | DoubleType, TimestampType) => true
case (TimestampType, DateType) => false
case (_, DateType) => true
case (DateType, TimestampType) => false
case (DateType, _) => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should DateType to StringType be false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, will fix it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the current ccde, there is no case for (DateType, StringType).
I can send a PR if this is to be added

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's covered by (_, StringType)

case (_, CalendarIntervalType) => true

case (_, _: DecimalType) => true // overflow
case (_: FractionalType, _: IntegralType) => true // NaN, infinity
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,14 +340,21 @@ abstract class UnaryExpression extends Expression {
ev: GeneratedExpressionCode,
f: String => String): String = {
val eval = child.gen(ctx)
val resultCode = f(eval.value)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
$resultCode
}
"""
if (nullable) {
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${eval.isNull}) {
${f(eval.value)}
}
"""
} else {
ev.isNull = "false"
eval.code + s"""
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${f(eval.value)}
"""
}
}
}

Expand Down Expand Up @@ -424,19 +431,30 @@ abstract class BinaryExpression extends Expression {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.value, eval2.value)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
$resultCode
} else {
${ev.isNull} = true;
if (nullable) {
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
$resultCode
} else {
${ev.isNull} = true;
}
}
}
"""
"""

} else {
ev.isNull = "false"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this branch necessary? (not suggesting you change it) but does the nullable path collapse correctly if left and right are non nullable? What I mean is:

if eval1.isNull and eval2.isNull is always just false, do we get the same behavior as this special casing from the compiler optimizations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not necessary (in terms of performance). Compiler can do all these, but not sure how far Janino had achieved on constant folding.

We don't need to do this for every expression, but since UnaryExpression/BinaryExpression/TernaryExpression are used by many, this changes may worth it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to Janino the JIT might also do more constant folding etc, which makes it hard to tell unfortunately.

s"""
${eval1.code}
${eval2.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we assuming if a BinaryExpression is not nullable, its children are also not nullable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should forbid non-nullable BinaryExpression to call nullSafeCodeGen as it doesn't make sense(passing a f that supposed to only apply to not-null children, but actually it isn't.), and they should take care of null children themselves, i.e. override genCode directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can add an assert: assert(nullable || (children.forall(!_.nullable)))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even left or right is nullable, the new code is still correct, if the old code is correct.

"""
}
}
}

Expand Down Expand Up @@ -548,20 +566,31 @@ abstract class TernaryExpression extends Expression {
f: (String, String, String) => String): String = {
val evals = children.map(_.gen(ctx))
val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
s"""
${evals(0).code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${evals(0).isNull}) {
${evals(1).code}
if (!${evals(1).isNull}) {
${evals(2).code}
if (!${evals(2).isNull}) {
${ev.isNull} = false; // resultCode could change nullability
$resultCode
if (nullable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a TernaryExpression is nullable, currently we will always generate 3 nested if branches. But we still have chance to remove some if branches if some children are non-nullable, how about doing this optimization based on children's nullability?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but it have too much combinations.

s"""
${evals(0).code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${evals(0).isNull}) {
${evals(1).code}
if (!${evals(1).isNull}) {
${evals(2).code}
if (!${evals(2).isNull}) {
${ev.isNull} = false; // resultCode could change nullability
$resultCode
}
}
}
}
"""
"""
} else {
ev.isNull = "false"
s"""
${evals(0).code}
${evals(1).code}
${evals(2).code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w

override def children: Seq[Expression] = Seq(child)

override def nullable: Boolean = false
override def nullable: Boolean = true

override def dataType: DataType = DoubleType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
Expand All @@ -42,7 +41,7 @@ case class Corr(

override def children: Seq[Expression] = Seq(left, right)

override def nullable: Boolean = false
override def nullable: Boolean = true

override def dataType: DataType = DoubleType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,32 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)

private lazy val count = AttributeReference("count", LongType)()
private lazy val count = AttributeReference("count", LongType, nullable = false)()

override lazy val aggBufferAttributes = count :: Nil

override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)

override lazy val updateExpressions = Seq(
/* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
)
override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}

override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)

override lazy val evaluateExpression = Cast(count, LongType)
override lazy val evaluateExpression = count

override def defaultResult: Option[Literal] = Option(Literal(0L))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
// TODO: Remove this line once we remove the NullType from inputTypes.
case NullType => IntegerType
case _ => child.dataType
}

Expand All @@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
/* sum = */ Literal.create(null, sumDataType)
)

override lazy val updateExpressions: Seq[Expression] = Seq(
/* sum = */
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
Seq(
/* sum = */
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
} else {
Seq(
/* sum = */
Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType))
)
}
}

override lazy val mergeExpressions: Seq[Expression] = {
val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
/* sum = */
Coalesce(Seq(add, sum.left))
Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left))
)
}

override lazy val evaluateExpression: Expression = Cast(sum, resultType)
override lazy val evaluateExpression: Expression = sum
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ trait CodegenFallback extends Expression {

ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this.toCommentSafeString} */
java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
}
"""
if (nullable) {
s"""
/* expression: ${this.toCommentSafeString} */
Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
}
"""
} else {
ev.isNull = "false"
s"""
/* expression: ${this.toCommentSafeString} */
Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
"""
}
}
}
Loading