From 368236fd73a21dfdc52c2819e7db26427eea523d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2020 22:34:32 -0800 Subject: [PATCH 1/5] Subexpression elimination for whole-stage codegen in Filter. --- .../execution/basicPhysicalOperators.scala | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 006fa0fba4138..d99c3d82ea6da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -159,12 +159,18 @@ case class FilterExec(condition: Expression, child: SparkPlan) /** * Generates code for `c`, using `in` for input attributes and `attrs` for nullability. */ - def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = { + def genPredicate( + c: Expression, + in: Seq[ExprCode], + attrs: Seq[Attribute], + states: Map[Expression, SubExprEliminationState] = Map.empty): String = { val bound = BindReferences.bindReference(c, attrs) val evaluated = evaluateRequiredVariables(child.output, in, c.references) // Generate the code for the predicate. - val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx) + val ev = ctx.withSubExprEliminationExprs(states) { + Seq(ExpressionCanonicalizer.execute(bound).genCode(ctx)) + }.head val nullCheck = if (bound.nullable) { s"${ev.isNull} || " } else { @@ -189,7 +195,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // TODO: revisit this. We can consider reordering predicates as well. 
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val extraIsNotNullAttrs = mutable.Set[Attribute]() - val generated = otherPreds.map { c => + val (generatedNullChecks, predsToGen) = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} if (idx != -1 && !generatedIsNotNullChecks(idx)) { @@ -204,11 +210,30 @@ case class FilterExec(condition: Expression, child: SparkPlan) } }.mkString("\n").trim + (nullChecks, c) + }.unzip + + val (subExprsCode, generatedPreds, localValInputs) = if (conf.subexpressionEliminationEnabled) { + // To do subexpression elimination, we need to use bound expressions. Although `genPredicate` + // will bind expressions too, for simplicity we don't skip binding in `genPredicate` for this + // case. + val boundPredsToGen = bindReferences[Expression](predsToGen, child.output) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundPredsToGen) // Here we use *this* operator's output with this output's nullability since we already // enforced them with the IsNotNull checks above. + (subExprs.codes.mkString("\n"), + boundPredsToGen.map(genPredicate(_, input, output, subExprs.states)), + subExprs.exprCodesNeedEvaluate) + } else { + // Here we use *this* operator's output with this output's nullability since we already + // enforced them with the IsNotNull checks above. 
+ ("", predsToGen.map(genPredicate(_, input, output)), Seq.empty) + } + + val generated = generatedNullChecks.zip(generatedPreds).map { case (nullChecks, genPred) => s""" |$nullChecks - |${genPredicate(c, input, output)} + |$genPred """.stripMargin.trim }.mkString("\n") @@ -231,6 +256,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;" s""" + |// common sub-expressions + |${evaluateVariables(localValInputs)} + |$subExprsCode |do { | $generated | $nullChecks From 60319b89fbdfafdb17ed27c7ec1d1a99fff16c33 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Dec 2020 10:32:24 -0800 Subject: [PATCH 2/5] fix. --- .../sql/catalyst/expressions/EquivalentExpressions.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 1dfff412d9a8e..d2ca212890ad7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -44,6 +44,12 @@ class EquivalentExpressions { // For each expression, the set of equivalent expressions. private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] + // Stores the mapping between expression and subexpressions. Note that we only store the top-level + // expression relation. For example, for two expressions ExprA(Expr(Expr1, ...), ...), + // ExprB(Expr(Expr1, ...), ...), extracted subexpression is Expr1. Here we will store + // ExprA -> Expr1, ExprB -> Expr1. + private val exprToSubExprMap = mutable.HashMap.empty[Expr, Expr] + /** * Adds each expression to this data structure, grouping them with existing equivalent * expressions. Non-recursive. 
From 766cb8db3c689e651865e098e87fb54be6e40916 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Jun 2021 09:08:53 -0700 Subject: [PATCH 3/5] rebased with latest change. --- .../execution/basicPhysicalOperators.scala | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index b537040fe71df..42669c21af05b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -128,7 +128,7 @@ trait GeneratePredicateHelper extends PredicateHelper { val outputAttrs = outputWithNullability(inputAttrs, nonNullAttrExprIds) generatePredicateCode( ctx, inputAttrs, inputExprCode, outputAttrs, notNullPreds, otherPreds, - nonNullAttrExprIds) + nonNullAttrExprIds)._3 } protected def generatePredicateCode( @@ -138,16 +138,23 @@ trait GeneratePredicateHelper extends PredicateHelper { outputAttrs: Seq[Attribute], notNullPreds: Seq[Expression], otherPreds: Seq[Expression], - nonNullAttrExprIds: Seq[ExprId]): String = { + nonNullAttrExprIds: Seq[ExprId], + subexpressionEliminationEnabled: Boolean = false): (String, String, String) = { /** * Generates code for `c`, using `in` for input attributes and `attrs` for nullability. */ - def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = { + def genPredicate( + c: Expression, + in: Seq[ExprCode], + attrs: Seq[Attribute], + states: Map[Expression, SubExprEliminationState] = Map.empty): String = { val bound = BindReferences.bindReference(c, attrs) val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references) // Generate the code for the predicate. 
- val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx) + val ev = ctx.withSubExprEliminationExprs(states) { + Seq(ExpressionCanonicalizer.execute(bound).genCode(ctx)) + }.head val nullCheck = if (bound.nullable) { s"${ev.isNull} || " } else { @@ -172,7 +179,7 @@ trait GeneratePredicateHelper extends PredicateHelper { // TODO: revisit this. We can consider reordering predicates as well. val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val extraIsNotNullAttrs = mutable.Set[Attribute]() - val generated = otherPreds.map { c => + val (generatedNullChecks, predsToGen) = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} if (idx != -1 && !generatedIsNotNullChecks(idx)) { @@ -187,11 +194,30 @@ trait GeneratePredicateHelper extends PredicateHelper { } }.mkString("\n").trim + (nullChecks, c) + }.unzip + + val (subExprsCode, generatedPreds, localValInputs) = if (subexpressionEliminationEnabled) { + // To do subexpression elimination, we need to use bound expressions. Although `genPredicate` + // will bind expressions too, for simplicity we don't skip binding in `genPredicate` for this + // case. + val boundPredsToGen = bindReferences[Expression](predsToGen, output) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundPredsToGen) // Here we use *this* operator's output with this output's nullability since we already // enforced them with the IsNotNull checks above. + (subExprs.codes.mkString("\n"), + boundPredsToGen.map(genPredicate(_, inputExprCode, outputAttrs, subExprs.states)), + subExprs.exprCodesNeedEvaluate) + } else { + // Here we use *this* operator's output with this output's nullability since we already + // enforced them with the IsNotNull checks above. 
+ ("", predsToGen.map(genPredicate(_, inputExprCode, outputAttrs)), Seq.empty) + } + + val generated = generatedNullChecks.zip(generatedPreds).map { case (nullChecks, genPred) => s""" |$nullChecks - |${genPredicate(c, inputExprCode, outputAttrs)} + |$genPred """.stripMargin.trim }.mkString("\n") @@ -203,10 +229,13 @@ trait GeneratePredicateHelper extends PredicateHelper { } }.mkString("\n") - s""" + val localInputs = evaluateVariables(localValInputs) + val predicateCode = s""" |$generated |$nullChecks """.stripMargin + + (localInputs, subExprsCode, predicateCode) } } @@ -243,8 +272,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val predicateCode = generatePredicateCode( - ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes) + val (localValInputs, subExprsCode, predicateCode) = generatePredicateCode( + ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes, + false) // Reset the isNull to false for the not-null columns, then the followed operators could // generate better code (remove dead branches). @@ -257,6 +287,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;" s""" + |// common sub-expressions 123 + |$localValInputs + |$subExprsCode |do { | $predicateCode | $numOutput.add(1); From 6a462b6bbb08301d7e161a217d1734deae8f4581 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Jun 2021 00:28:40 -0700 Subject: [PATCH 4/5] fix. 
--- .../execution/basicPhysicalOperators.scala | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 42669c21af05b..9d571fcd15a51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -145,11 +145,12 @@ trait GeneratePredicateHelper extends PredicateHelper { */ def genPredicate( c: Expression, + required: AttributeSet, in: Seq[ExprCode], attrs: Seq[Attribute], states: Map[Expression, SubExprEliminationState] = Map.empty): String = { val bound = BindReferences.bindReference(c, attrs) - val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references) + val evaluated = evaluateRequiredVariables(inputAttrs, in, required) // Generate the code for the predicate. 
val ev = ctx.withSubExprEliminationExprs(states) { @@ -185,10 +186,10 @@ if (idx != -1 && !generatedIsNotNullChecks(idx)) { generatedIsNotNullChecks(idx) = true // Use the child's output. The nullability is what the child produced. - genPredicate(notNullPreds(idx), inputExprCode, inputAttrs) + genPredicate(notNullPreds(idx), notNullPreds(idx).references, inputExprCode, inputAttrs) } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) { extraIsNotNullAttrs += r - genPredicate(IsNotNull(r), inputExprCode, inputAttrs) + genPredicate(IsNotNull(r), r.references, inputExprCode, inputAttrs) } else { "" } @@ -201,17 +202,22 @@ // To do subexpression elimination, we need to use bound expressions. Although `genPredicate` // will bind expressions too, for simplicity we don't skip binding in `genPredicate` for this // case. - val boundPredsToGen = bindReferences[Expression](predsToGen, output) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundPredsToGen) + val boundPredsToGen = + predsToGen.map(pred => (BindReferences.bindReference(pred, outputAttrs), pred.references)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundPredsToGen.map(_._1)) + val localInputs = evaluateVariables(subExprs.exprCodesNeedEvaluate) // Here we use *this* operator's output with this output's nullability since we already // enforced them with the IsNotNull checks above. 
(subExprs.codes.mkString("\n"), - boundPredsToGen.map(genPredicate(_, inputExprCode, outputAttrs, subExprs.states)), - subExprs.exprCodesNeedEvaluate) + boundPredsToGen.map(pred => + genPredicate(pred._1, pred._2, inputExprCode, outputAttrs, subExprs.states)), + localInputs) } else { // Here we use *this* operator's output with this output's nullability since we already // enforced them with the IsNotNull checks above. - ("", predsToGen.map(genPredicate(_, inputExprCode, outputAttrs)), Seq.empty) + ("", + predsToGen.map(pred => genPredicate(pred, pred.references, inputExprCode, outputAttrs)), + "") } val generated = generatedNullChecks.zip(generatedPreds).map { case (nullChecks, genPred) => @@ -223,19 +234,18 @@ trait GeneratePredicateHelper extends PredicateHelper { val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) => if (!generatedIsNotNullChecks(idx)) { - genPredicate(c, inputExprCode, inputAttrs) + genPredicate(c, c.references, inputExprCode, inputAttrs) } else { "" } }.mkString("\n") - val localInputs = evaluateVariables(localValInputs) val predicateCode = s""" |$generated |$nullChecks """.stripMargin - (localInputs, subExprsCode, predicateCode) + (localValInputs, subExprsCode, predicateCode) } } @@ -274,7 +284,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) val (localValInputs, subExprsCode, predicateCode) = generatePredicateCode( ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes, - false) + true) // Reset the isNull to false for the not-null columns, then the followed operators could // generate better code (remove dead branches). 
@@ -287,8 +297,8 @@ case class FilterExec(condition: Expression, child: SparkPlan) // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;" s""" - |// common sub-expressions 123 + |// common sub-expressions |$localValInputs |$subExprsCode |do { | $predicateCode From ba4172076f3f8030510632978a0e47d5b720617a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 3 Jul 2021 23:37:05 -0700 Subject: [PATCH 5/5] Use config. --- .../org/apache/spark/sql/execution/basicPhysicalOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index c96daa8d0ec34..6ae5ccc5267e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -306,7 +306,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) val predicateCode = generatePredicateCode( ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes, - true) + conf.subexpressionEliminationEnabled) // Reset the isNull to false for the not-null columns, then the followed operators could // generate better code (remove dead branches).