Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
Expand Down Expand Up @@ -72,7 +73,10 @@ class EquivalentExpressions {
root: Expression,
ignoreLeaf: Boolean = true,
skipReferenceToExpressions: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
root.find(_.isInstanceOf[LambdaVariable]).isDefined
// There are some special expressions that we should not recurse into children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ object NameAgg extends Aggregator[AggData, String, String] {
}


object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] {
def zero: Seq[Int] = Nil
def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
def finish(r: Seq[Int]): Seq[Int] = r
def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i)
override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder()
}


Expand Down Expand Up @@ -281,7 +281,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {

checkDataset(
ds.groupByKey(_.b).agg(SeqAgg.toColumn),
"a" -> Seq(1, 2)
"a" -> Seq(1 -> 1, 2 -> 2)
)
}

Expand Down