Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
case e: Attribute if groupingExprs.find(_ == e).isEmpty =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
case e if groupingExprs.find(_ == e).isDefined => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,14 @@
package org.apache.spark.sql.catalyst.expressions


/**
 * Wrapper around an [[Attribute]] that redefines equality and hashing for use as a set element.
 *
 * Two wrapped [[AttributeReference]]s are equal when their `exprId`s are equal, and they hash by
 * `exprId`. Any other combination falls back to the wrapped attributes' own `equals` / `hashCode`.
 *
 * @param a the attribute being wrapped
 */
protected class AttributeEquals(val a: Attribute) {
  /** Hashes an [[AttributeReference]] by its `exprId`; any other attribute by its own hash. */
  override def hashCode(): Int = a match {
    case ar: AttributeReference => ar.exprId.hashCode()
    case a => a.hashCode()
  }

  /**
   * Compares the wrapped attributes. Returns false (rather than throwing) when `other` is null
   * or is not an [[AttributeEquals]], as required by the general `equals` contract.
   */
  override def equals(other: Any): Boolean = other match {
    case that: AttributeEquals =>
      (a, that.a) match {
        case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
        case (a1, a2) => a1 == a2
      }
    // A type pattern never matches null, so null and foreign types both land here.
    case _ => false
  }
}

object AttributeSet {
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(a))

/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
def apply(baseSet: Iterable[Expression]): AttributeSet = {
new AttributeSet(
baseSet
.flatMap(_.references)
.map(new AttributeEquals(_)).toSet)
.flatMap(_.references).toSet)
}
}

Expand All @@ -53,30 +40,30 @@ object AttributeSet {
* and also makes doing transformations hard (we always try keep older trees instead of new ones
* when the transformation was a no-op).
*/
class AttributeSet private (val baseSet: Set[AttributeEquals])
class AttributeSet private (val baseSet: Set[Attribute])
extends Traversable[Attribute] with Serializable {

/** Returns true if the members of this AttributeSet and other are the same. */
override def equals(other: Any): Boolean = other match {
case otherSet: AttributeSet =>
otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains)
otherSet.size == baseSet.size && baseSet.forall(otherSet.contains)
case _ => false
}

/** Returns true if this set contains an Attribute with the same expression id as `elem` */
def contains(elem: NamedExpression): Boolean =
baseSet.contains(new AttributeEquals(elem.toAttribute))
baseSet.contains(elem.toAttribute)

/** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */
def +(elem: Attribute): AttributeSet = // scalastyle:ignore
new AttributeSet(baseSet + new AttributeEquals(elem))
new AttributeSet(baseSet + elem)

/** Returns a new [[AttributeSet]] that does not contain `elem`. */
def -(elem: Attribute): AttributeSet =
new AttributeSet(baseSet - new AttributeEquals(elem))
new AttributeSet(baseSet - elem)

/** Returns an iterator containing all of the attributes in the set. */
def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator
def iterator: Iterator[Attribute] = baseSet.iterator

/**
* Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
Expand All @@ -89,7 +76,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
* in `other`.
*/
def --(other: Traversable[NamedExpression]): AttributeSet =
new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
new AttributeSet(baseSet -- other.map(_.toAttribute))

/**
* Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
Expand All @@ -102,7 +89,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
* true.
*/
override def filter(f: Attribute => Boolean): AttributeSet =
new AttributeSet(baseSet.filter(ae => f(ae.a)))
new AttributeSet(baseSet.filter(f))

/**
* Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
Expand All @@ -111,13 +98,13 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
def intersect(other: AttributeSet): AttributeSet =
new AttributeSet(baseSet.intersect(other.baseSet))

override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
override def foreach[U](f: (Attribute) => U): Unit = baseSet.foreach(f)

// We must force toSeq to be strict (by going through an Array); otherwise we end up with a
// lazy [[Stream]] that captures all sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
override def toSeq: Seq[Attribute] = baseSet.toArray.toSeq

override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"
override def toString: String = "{" + baseSet.mkString(", ") + "}"

override def isEmpty: Boolean = baseSet.isEmpty
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,24 +141,6 @@ abstract class Expression extends TreeNode[Expression] {
}.toString
}

/**
 * Returns true when two expressions will always compute the same result, even if they differ
 * cosmetically (i.e. capitalization of names in attributes may be different).
 *
 * The two expressions must be of exactly the same class. Their product elements are then
 * compared position by position: nested expressions (bare or wrapped in `Some`) are compared
 * with `semanticEquals`, traversables are compared element-wise by the same rules, and
 * everything else is compared with `==`.
 */
def semanticEquals(other: Expression): Boolean = {
  // Pairwise comparison of two argument lists; a length mismatch is never equal.
  def sameElements(left: Seq[Any], right: Seq[Any]): Boolean = {
    if (left.length != right.length) {
      false
    } else {
      left.zip(right).forall {
        case (l: Expression, r: Expression) => l.semanticEquals(r)
        case (Some(l: Expression), Some(r: Expression)) => l.semanticEquals(r)
        case (ls: Traversable[_], rs: Traversable[_]) => sameElements(ls.toSeq, rs.toSeq)
        case (l, r) => l == r
      }
    }
  }

  if (this.getClass != other.getClass) {
    false
  } else {
    val mine = this.productIterator.toSeq
    val theirs = other.asInstanceOf[Product].productIterator.toSeq
    sameElements(mine, theirs)
  }
}

/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,8 @@ case class AttributeReference(
def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId

override def equals(other: Any): Boolean = other match {
case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
case _ => false
}

override def semanticEquals(other: Expression): Boolean = other match {
case ar: AttributeReference => sameRef(ar)
case ar: AttributeReference =>
exprId == ar.exprId && dataType == ar.dataType && metadata == ar.metadata
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ object PartialAggregation {

// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown {
case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
partialEvaluations(new TreeNodeRef(e)).finalEvaluation

case e: Expression =>
namedGroupingExpressions.collectFirst {
case (expr, ne) if expr semanticEquals e => ne.toAttribute
case (expr, ne) if expr == e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (newChild fastEquals arg) {
if (newChild eq arg) {
arg
} else {
changed = true
Expand All @@ -181,7 +181,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
if (newChild eq oldChild) {
oldChild
} else {
changed = true
Expand All @@ -193,7 +193,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
if (newChild eq oldChild) {
oldChild
} else {
changed = true
Expand Down Expand Up @@ -228,7 +228,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}

// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
if (this eq afterRule) {
transformChildrenDown(rule)
} else {
afterRule.transformChildrenDown(rule)
Expand All @@ -245,15 +245,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
if (newChild ne arg) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
if (newChild ne arg) {
changed = true
Some(newChild)
} else {
Expand All @@ -264,7 +264,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
if (newChild ne arg) {
changed = true
newChild
} else {
Expand All @@ -286,7 +286,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildrenUp(rule)
if (this fastEquals afterRuleOnChildren) {
if (this eq afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
Expand All @@ -302,15 +302,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
if (newChild ne arg) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
if (newChild ne arg) {
changed = true
Some(newChild)
} else {
Expand All @@ -321,7 +321,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
if (newChild ne arg) {
changed = true
newChild
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class AttributeSetSuite extends SparkFunSuite {
val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

test("sanity check") {
assert(aUpper != aLower)
assert(bUpper != bLower)
assert(aUpper == aLower)
assert(bUpper == bLower)
}

test("checks by id not name") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ case class GeneratedAggregate(
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
case e: Expression =>
namedGroups.collectFirst {
case (expr, attr) if expr semanticEquals e => attr
case (expr, attr) if expr == e => attr
}.getOrElse(e)
})

Expand Down