Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,34 @@ class EquivalentExpressions {
/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
* If ignoreLeaf is true, leaf nodes are ignored.
*/
def addExprTree(
root: Expression,
ignoreLeaf: Boolean = true,
skipReferenceToExpressions: Boolean = true): Unit = {
val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
def addExprTree(expr: Expression): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the code change, I don't see any place other than tests using ignoreLeaf = false. Curious why we have it.

root.find(_.isInstanceOf[LambdaVariable]).isDefined
// There are some special expressions that we should not recurse into children.
expr.find(_.isInstanceOf[LambdaVariable]).isDefined

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
val shouldRecurse = root match {
// TODO: some expressions implements `CodegenFallback` but can still do codegen,
// e.g. `CaseWhen`, we should support them.
case _: CodegenFallback => false
case _: ReferenceToExpressions if skipReferenceToExpressions => false
case _ => true
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this's cool.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just found that not all the children of AtLeastNNonNulls get accessed during evaluation too. Do we need to add it here too?

// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CaseWhen could be very deep.

CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END
When expr1 = true, returns expr2; when expr3 = true, return expr4; else return expr5.

Compared with the previous impl, will we miss some expression elimination chances?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, CaseWhen implements CodegenFallback. Thus, the previous impl skips it.

// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
// children, because others may not get accessed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although Coalesce might miss some expression elimination chances, I think it is very rare when users use the same expressions in Coalesce.

Could you update the comments?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coalesce may be just a small part of the whole expression tree, and the children of Coalesce may be same with other expressions inside the expression tree.

def childrenToRecurse: Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case i: If => i.predicate :: Nil
// `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
case c: CaseWhenCodegen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
case other => other.children
}
if (!skip && !addExpr(root) && shouldRecurse) {
root.children.foreach(addExprTree(_, ignoreLeaf))

if (!skip && !addExpr(expr)) {
childrenToRecurse.foreach(addExprTree)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -726,18 +726,18 @@ class CodegenContext {
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
expressions.foreach(equivalentExpressions.addExprTree)

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val codes = commonExprs.map { e =>
val expr = e.head
// Generate the code for this expression tree.
val code = expr.genCode(this)
val state = SubExprEliminationState(code.isNull, code.value)
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(subExprEliminationExprs.put(_, state))
code.code.trim
eval.code.trim
}
SubExprCodes(codes, subExprEliminationExprs.toMap)
}
Expand All @@ -747,7 +747,7 @@ class CodegenContext {
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]) = {
private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

Expand All @@ -761,13 +761,13 @@ class CodegenContext {
val value = s"${fnName}Value"

// Generate the code for this expression tree and wrap it in a function.
val code = expr.genCode(this)
val eval = expr.genCode(this)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${code.code.trim}
| $isNull = ${code.isNull};
| $value = ${code.value};
| ${eval.code.trim}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
""".stripMargin

Expand All @@ -780,9 +780,6 @@ class CodegenContext {
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we removed 2. and 3.. We do not need 1., right?

Copy link
Contributor Author

@cloud-fan cloud-fan Jan 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we do have an extra function call to evaluate common subexpression at the beginning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

: ) Just removed 1.. Not the whole sentence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should still keep it, to make the indent consistent between the "cost" part and the "benefit" part. It also makes it more obvious that we only have one "cost".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am fine to keep it.

// inline or not.
// 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
// very often. The reason it is not loaded is because of a prior branch.
// 3. Extra store into isLoaded.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-17160: field names are properly escaped by AssertTrue") {
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
}

test("should not apply common subexpression elimination on conditional expressions") {
val row = InternalRow(null)
val bound = BoundReference(0, IntegerType, true)
val assertNotNull = AssertNotNull(bound, Nil)
val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
val projection = GenerateUnsafeProjection.generate(
Seq(expr), subexpressionEliminationEnabled = true)
// should not throw exception
projection(row)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add2 = Add(add, add)

var equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
equivalence.addExprTree(abs, true)
equivalence.addExprTree(add2, true)
equivalence.addExprTree(add)
equivalence.addExprTree(abs)
equivalence.addExprTree(add2)

// Should only have one equivalence for `one + two`
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
Expand All @@ -115,41 +115,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val mul2 = Multiply(mul, mul)
val sqrt = Sqrt(mul2)
val sum = Add(mul2, sqrt)
equivalence.addExprTree(mul, true)
equivalence.addExprTree(mul2, true)
equivalence.addExprTree(sqrt, true)
equivalence.addExprTree(sum, true)
equivalence.addExprTree(mul)
equivalence.addExprTree(mul2)
equivalence.addExprTree(sqrt)
equivalence.addExprTree(sum)

// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
assert(equivalence.getEquivalentExprs(mul).size == 3)
assert(equivalence.getEquivalentExprs(mul2).size == 3)
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
assert(equivalence.getEquivalentExprs(sum).size == 1)

// Some expressions inspired by TPCH-Q1
// sum(l_quantity) as sum_qty,
// sum(l_extendedprice) as sum_base_price,
// sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
// sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
// avg(l_extendedprice) as avg_price,
// avg(l_discount) as avg_disc
equivalence = new EquivalentExpressions
val quantity = Literal(1)
val price = Literal(1.1)
val discount = Literal(.24)
val tax = Literal(0.1)
equivalence.addExprTree(quantity, false)
equivalence.addExprTree(price, false)
equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
equivalence.addExprTree(
Multiply(
Multiply(price, Subtract(Literal(1), discount)),
Add(Literal(1), tax)), false)
equivalence.addExprTree(price, false)
equivalence.addExprTree(discount, false)
// quantity, price, discount and (price * (1 - discount))
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4)
}

test("Expression equivalence - non deterministic") {
Expand All @@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add = Add(two, fallback)

val equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}

test("Children of conditional expressions") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To other reviewers: the new addExprTree always ignores the leaf nodes. Thus, these test cases are not needed.

val condition = And(Literal(true), Literal(false))
val add = Add(Literal(1), Literal(2))
val ifExpr = If(condition, add, add)

val equivalence = new EquivalentExpressions
equivalence.addExprTree(ifExpr)
// the `add` inside `If` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
// only ifExpr and its predicate expression
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression(
override lazy val aggBufferAttributes: Seq[AttributeReference] =
bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])

private def serializeToBuffer(expr: Expression): Seq[Expression] = {
bufferSerializer.map(_.transform {
case _: BoundReference => expr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why bufferSerializer now replaced with bufferDeserializer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, typo...

})
}

override lazy val initialValues: Seq[Expression] = {
val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
serializeToBuffer(zero)
}

override lazy val updateExpressions: Seq[Expression] = {
Expand All @@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression(
"reduce",
bufferExternalType,
bufferDeserializer :: inputDeserializer.get :: Nil)

bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
serializeToBuffer(reduced)
}

override lazy val mergeExpressions: Seq[Expression] = {
Expand All @@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression(
"merge",
bufferExternalType,
leftBuffer :: rightBuffer :: Nil)

bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
serializeToBuffer(merged)
}

override lazy val evaluateExpression: Expression = {
Expand All @@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression(
outputExternalType,
bufferDeserializer :: Nil)

val outputSerializeExprs = outputSerializer.map(_.transform {
case _: BoundReference => resultObj
})

dataType match {
case s: StructType =>
case _: StructType =>
val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
val struct = If(
IsNull(objRef),
Literal.create(null, dataType),
CreateStruct(outputSerializer))
ReferenceToExpressions(struct, resultObj :: Nil)
If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs))
case _ =>
assert(outputSerializer.length == 1)
outputSerializer.head transform {
case b: BoundReference => resultObj
}
assert(outputSerializeExprs.length == 1)
outputSerializeExprs.head
}
}

Expand Down