diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 221f5ae73673e..bc6381ea82dbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -139,6 +139,7 @@ abstract class Expression extends TreeNode[Expression] {
     ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState =>
       // This expression is repeated which means that the code to evaluate it has already been added
       // as a function before. In that case, we just re-use it.
+      ctx.collectedSubExprEliminationExprs += subExprState
       ExprCode(
         ctx.registerComment(this.toString),
         subExprState.eval.isNull,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 43ac2abaf7e82..7474fc0f99a5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -427,6 +427,13 @@ class CodegenContext extends Logging {
   private[expressions] var subExprEliminationExprs =
     Map.empty[ExpressionEquals, SubExprEliminationState]
 
+  /**
+   * The purpose of this variable is to keep track of the `SubExprEliminationState`s collected
+   * during expression tree codegen.
+   */
+  private[sql] var collectedSubExprEliminationExprs =
+    ArrayBuffer.empty[SubExprEliminationState]
+
   // The collection of sub-expression result resetting methods that need to be called on each row.
   private val subexprFunctions = mutable.ArrayBuffer.empty[String]
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 7bd4dc7be129b..6ae5ccc5267e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.TimeUnit._
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration
 
@@ -139,16 +140,24 @@ trait GeneratePredicateHelper extends PredicateHelper {
       outputAttrs: Seq[Attribute],
       notNullPreds: Seq[Expression],
       otherPreds: Seq[Expression],
-      nonNullAttrExprIds: Seq[ExprId]): String = {
+      nonNullAttrExprIds: Seq[ExprId],
+      subexpressionEliminationEnabled: Boolean = false): String = {
     /**
      * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
      */
-    def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
+    def genPredicate(
+        c: Expression,
+        required: AttributeSet,
+        in: Seq[ExprCode],
+        attrs: Seq[Attribute],
+        states: Map[ExpressionEquals, SubExprEliminationState] = Map.empty): String = {
       val bound = BindReferences.bindReference(c, attrs)
-      val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)
+      val evaluated = evaluateRequiredVariables(inputAttrs, in, required)
 
       // Generate the code for the predicate.
-      val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
+      val ev = ctx.withSubExprEliminationExprs(states) {
+        Seq(ExpressionCanonicalizer.execute(bound).genCode(ctx))
+      }.head
       val nullCheck = if (bound.nullable) {
         s"${ev.isNull} || "
       } else {
@@ -156,8 +165,12 @@ trait GeneratePredicateHelper extends PredicateHelper {
       }
 
       s"""
+         |// RequiredVariables
          |$evaluated
+         |// end of RequiredVariables
+         |// ev code
          |${ev.code}
+         |// end of ev code
          |if (${nullCheck}!${ev.value}) continue;
        """.stripMargin
     }
@@ -173,41 +186,88 @@ trait GeneratePredicateHelper extends PredicateHelper {
     // TODO: revisit this. We can consider reordering predicates as well.
     val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
     val extraIsNotNullAttrs = mutable.Set[Attribute]()
-    val generated = otherPreds.map { c =>
+    val (generatedNullChecks, predsToGen) = otherPreds.map { c =>
       val nullChecks = c.references.map { r =>
         val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
         if (idx != -1 && !generatedIsNotNullChecks(idx)) {
           generatedIsNotNullChecks(idx) = true
           // Use the child's output. The nullability is what the child produced.
-          genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
+          genPredicate(notNullPreds(idx), notNullPreds(idx).references, inputExprCode, inputAttrs)
         } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
           extraIsNotNullAttrs += r
-          genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
+          genPredicate(IsNotNull(r), r.references, inputExprCode, inputAttrs)
         } else {
           ""
         }
       }.mkString("\n").trim
+      (nullChecks, c)
+    }.unzip
+
+    val (subExprsCode, generatedPreds, localValInputs) = if (subexpressionEliminationEnabled) {
+      // To do subexpression elimination, we need to use bound expressions. Although `genPredicate`
+      // will bind expressions too, for simplicity we don't skip binding in `genPredicate` for this
+      // case.
+      val boundPredsToGen =
+        predsToGen.map(pred => (BindReferences.bindReference(pred, outputAttrs), pred.references))
+      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundPredsToGen.map(_._1))
+
+      // Here we use *this* operator's output with this output's nullability since we already
+      // enforced them with the IsNotNull checks above.
+      val (selectedSubExprs, genPreds, selectedLocalInputs) = boundPredsToGen.map { pred =>
+        ctx.collectedSubExprEliminationExprs = ArrayBuffer.empty
+        val gen = genPredicate(pred._1, pred._2, inputExprCode, outputAttrs, subExprs.states)
+        val localInputs = CodeGenerator.getLocalInputVariableValues(ctx, pred._1)._2
+        // When subexpression elimination is enabled, we collect the common subexpressions of each
+        // predicate and its required local input variables (if subexpressions are in split functions).
+        (ctx.evaluateSubExprEliminationState(ctx.collectedSubExprEliminationExprs.toIterable), gen,
+          evaluateVariables(localInputs.toSeq))
+      }.unzip3
+
+      (selectedSubExprs, genPreds, selectedLocalInputs)
+    } else {
 
       // Here we use *this* operator's output with this output's nullability since we already
       // enforced them with the IsNotNull checks above.
+      (Seq.empty,
+        predsToGen.map(pred => genPredicate(pred, pred.references, inputExprCode, outputAttrs)),
+        Seq.empty)
+    }
+
+    val generated = generatedNullChecks.zipWithIndex.map { case (nullChecks, index) =>
+      val localInputs = if (index < localValInputs.length) {
+        localValInputs(index)
+      } else {
+        ""
+      }
+      val subExpr = if (index < subExprsCode.length) {
+        subExprsCode(index)
+      } else {
+        ""
+      }
       s"""
          |$nullChecks
-         |${genPredicate(c, inputExprCode, outputAttrs)}
+         |// common subexpressions
+         |$localInputs
+         |$subExpr
+         |// end of common subexpressions
+         |${generatedPreds(index)}
        """.stripMargin.trim
     }.mkString("\n")
 
     val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
       if (!generatedIsNotNullChecks(idx)) {
-        genPredicate(c, inputExprCode, inputAttrs)
+        genPredicate(c, c.references, inputExprCode, inputAttrs)
       } else {
         ""
       }
     }.mkString("\n")
 
-    s"""
+    val predicateCode = s"""
       |$generated
       |$nullChecks
     """.stripMargin
+
+    predicateCode
   }
 }
 
@@ -245,7 +305,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     val predicateCode = generatePredicateCode(
-      ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
+      ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes,
+      conf.subexpressionEliminationEnabled)
 
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
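
For context, a hedged sketch of the kind of query this patch targets; it is not part of the change, and the dataset, column names and thresholds below are made up. Both filter predicates share the subexpression sqrt(a * a + b * b), which the whole-stage generated code for FilterExec should now be able to evaluate once per row when spark.sql.subexpressionElimination.enabled is true (the default).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sqrt

// Hypothetical illustration only; names and values are assumptions, not taken from the patch.
object FilterSubexprEliminationExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("filter-subexpr-elimination")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((3.0, 4.0), (0.3, 0.4), (30.0, 40.0)).toDF("a", "b")

    // Both predicates repeat sqrt(a * a + b * b). With subexpression elimination enabled in
    // FilterExec, the generated code should evaluate that shared subexpression once per row.
    val filtered = df.filter(
      sqrt($"a" * $"a" + $"b" * $"b") > 1 &&
        sqrt($"a" * $"a" + $"b" * $"b") < 10)

    // Print the whole-stage generated Java code to inspect how the common subexpression is handled.
    filtered.explain("codegen")
    filtered.show()

    spark.stop()
  }
}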