Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState =>
// This expression is repeated which means that the code to evaluate it has already been added
// as a function before. In that case, we just re-use it.
ctx.collectedSubExprEliminationExprs += subExprState
ExprCode(
ctx.registerComment(this.toString),
subExprState.eval.isNull,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ class CodegenContext extends Logging {
private[expressions] var subExprEliminationExprs =
Map.empty[ExpressionEquals, SubExprEliminationState]

/**
* This purpose of this variable is used to keep some `SubExprEliminationState`s collected
* during expression tree codegen.
*/
private[sql] var collectedSubExprEliminationExprs =
ArrayBuffer.empty[SubExprEliminationState]

// The collection of sub-expression result resetting methods that need to be called on each row.
private val subexprFunctions = mutable.ArrayBuffer.empty[String]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.concurrent.{Future => JFuture}
import java.util.concurrent.TimeUnit._

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration

Expand Down Expand Up @@ -139,25 +140,37 @@ trait GeneratePredicateHelper extends PredicateHelper {
outputAttrs: Seq[Attribute],
notNullPreds: Seq[Expression],
otherPreds: Seq[Expression],
nonNullAttrExprIds: Seq[ExprId]): String = {
nonNullAttrExprIds: Seq[ExprId],
subexpressionEliminationEnabled: Boolean = false): String = {
/**
* Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
*/
def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
def genPredicate(
c: Expression,
required: AttributeSet,
in: Seq[ExprCode],
attrs: Seq[Attribute],
states: Map[ExpressionEquals, SubExprEliminationState] = Map.empty): String = {
val bound = BindReferences.bindReference(c, attrs)
val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)
val evaluated = evaluateRequiredVariables(inputAttrs, in, required)

// Generate the code for the predicate.
val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
val ev = ctx.withSubExprEliminationExprs(states) {
Seq(ExpressionCanonicalizer.execute(bound).genCode(ctx))
}.head
val nullCheck = if (bound.nullable) {
s"${ev.isNull} || "
} else {
s""
}

s"""
|// RequiredVariables
|$evaluated
|// end of RequiredVariables
|// ev code
|${ev.code}
|// end of ev code
|if (${nullCheck}!${ev.value}) continue;
""".stripMargin
}
Expand All @@ -173,41 +186,88 @@ trait GeneratePredicateHelper extends PredicateHelper {
// TODO: revisit this. We can consider reordering predicates as well.
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
val extraIsNotNullAttrs = mutable.Set[Attribute]()
val generated = otherPreds.map { c =>
val (generatedNullChecks, predsToGen) = otherPreds.map { c =>
val nullChecks = c.references.map { r =>
val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
if (idx != -1 && !generatedIsNotNullChecks(idx)) {
generatedIsNotNullChecks(idx) = true
// Use the child's output. The nullability is what the child produced.
genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
genPredicate(notNullPreds(idx), notNullPreds(idx).references, inputExprCode, inputAttrs)
} else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
extraIsNotNullAttrs += r
genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
genPredicate(IsNotNull(r), r.references, inputExprCode, inputAttrs)
} else {
""
}
}.mkString("\n").trim

(nullChecks, c)
}.unzip

val (subExprsCode, generatedPreds, localValInputs) = if (subexpressionEliminationEnabled) {
// To do subexpression elimination, we need to use bound expressions. Although `genPredicate`
// will bind expressions too, for simplicity we don't skip binding in `genPredicate` for this
// case.
val boundPredsToGen =
predsToGen.map(pred => (BindReferences.bindReference(pred, outputAttrs), pred.references))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundPredsToGen.map(_._1))

// Here we use *this* operator's output with this output's nullability since we already
// enforced them with the IsNotNull checks above.
val (selectedSubExprs, genPreds, selectedLocalInputs) = boundPredsToGen.map { pred =>
ctx.collectedSubExprEliminationExprs = ArrayBuffer.empty
val gen = genPredicate(pred._1, pred._2, inputExprCode, outputAttrs, subExprs.states)
val localInputs = CodeGenerator.getLocalInputVariableValues(ctx, pred._1)._2
// For subexpression elimination enabled case, we collect common subexpressions for each
// predicate, and its required local input variables (if subexpressions in split functions).
(ctx.evaluateSubExprEliminationState(ctx.collectedSubExprEliminationExprs.toIterable), gen,
evaluateVariables(localInputs.toSeq))
}.unzip3

(selectedSubExprs, genPreds, selectedLocalInputs)
} else {
// Here we use *this* operator's output with this output's nullability since we already
// enforced them with the IsNotNull checks above.
(Seq.empty,
predsToGen.map(pred => genPredicate(pred, pred.references, inputExprCode, outputAttrs)),
Seq.empty)
}

val generated = generatedNullChecks.zipWithIndex.map { case (nullChecks, index) =>
val localInputs = if (index < localValInputs.length) {
localValInputs(index)
} else {
""
}
val subExpr = if (index < subExprsCode.length) {
subExprsCode(index)
} else {
""
}
s"""
|$nullChecks
|${genPredicate(c, inputExprCode, outputAttrs)}
|// common subexpressions
|$localInputs
|$subExpr
|// end of common subexpressions
|${generatedPreds(index)}
""".stripMargin.trim
}.mkString("\n")

val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
if (!generatedIsNotNullChecks(idx)) {
genPredicate(c, inputExprCode, inputAttrs)
genPredicate(c, c.references, inputExprCode, inputAttrs)
} else {
""
}
}.mkString("\n")

s"""
val predicateCode = s"""
|$generated
|$nullChecks
""".stripMargin

predicateCode
}
}

Expand Down Expand Up @@ -245,7 +305,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
val numOutput = metricTerm(ctx, "numOutputRows")

val predicateCode = generatePredicateCode(
ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes,
conf.subexpressionEliminationEnabled)

// Reset the isNull to false for the not-null columns, then the followed operators could
// generate better code (remove dead branches).
Expand Down