Skip to content

Commit 594cdf5

Browse files
committed
Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
1 parent 380880f commit 594cdf5

File tree

12 files changed

+83
-84
lines changed

12 files changed

+83
-84
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ class Analyzer(
541541
def containsAggregates(exprs: Seq[Expression]): Boolean = {
542542
exprs.foreach(_.foreach {
543543
case agg: AggregateExpression => return true
544-
case agg2: AggregateExpression2 => return true
545544
case _ =>
546545
})
547546
false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ trait CheckAnalysis {
8686
case Aggregate(groupingExprs, aggregateExprs, child) =>
8787
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
8888
case _: AggregateExpression => // OK
89-
case _: AggregateExpression2 => // OK
9089
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
9190
failAnalysis(
9291
s"expression '${e.prettyString}' is neither present in the group by, " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ abstract class Expression extends TreeNode[Expression] {
8989
val primitive = ctx.freshName("primitive")
9090
val ve = GeneratedExpressionCode("", isNull, primitive)
9191
ve.code = genCode(ctx, ve)
92-
ve.copy(s"/* $this */\n" + ve.code)
92+
// We may want to print out $this in the comment of generated code for debugging.
93+
// ve.copy(s"/* $this */\n" + ve.code)
94+
ve
9395
}
9496

9597
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,30 @@ package org.apache.spark.sql.catalyst.expressions.aggregate2
2020
import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
24+
import org.apache.spark.sql.catalyst.InternalRow
2425
import org.apache.spark.sql.types._
25-
import org.apache.spark.sql.Row
2626

27-
/** The mode of an [[AggregateFunction]]. */
27+
/** The mode of an [[AggregateFunction1]]. */
2828
private[sql] sealed trait AggregateMode
2929

3030
/**
31-
* An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
31+
* An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation.
3232
* This function updates the given aggregation buffer with the original input of this
3333
* function. When it has processed all input rows, the aggregation buffer is returned.
3434
*/
3535
private[sql] case object Partial extends AggregateMode
3636

3737
/**
38-
* An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
38+
* An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
3939
* containing intermediate results for this function.
4040
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
4141
* When it has processed all input rows, the aggregation buffer is returned.
4242
*/
4343
private[sql] case object PartialMerge extends AggregateMode
4444

4545
/**
46-
* An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
46+
* An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
4747
* containing intermediate results for this function and the generate final result.
4848
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
4949
* When it has processed all input rows, the final result of this function is returned.
@@ -58,7 +58,7 @@ private[sql] case object Final extends AggregateMode
5858
*/
5959
private[sql] case object Complete extends AggregateMode
6060

61-
private[sql] case object NoOp extends Expression {
61+
private[sql] case object NoOp extends Expression with Unevaluable {
6262
override def nullable: Boolean = true
6363
override def eval(input: InternalRow): Any = {
6464
throw new TreeNodeException(
@@ -78,19 +78,14 @@ private[sql] case object NoOp extends Expression {
7878
private[sql] case class AggregateExpression2(
7979
aggregateFunction: AggregateFunction2,
8080
mode: AggregateMode,
81-
isDistinct: Boolean) extends Expression {
81+
isDistinct: Boolean) extends Expression with Unevaluable {
8282

8383
override def children: Seq[Expression] = aggregateFunction :: Nil
8484
override def dataType: DataType = aggregateFunction.dataType
8585
override def foldable: Boolean = false
8686
override def nullable: Boolean = aggregateFunction.nullable
8787

8888
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
89-
90-
override def eval(input: InternalRow = null): Any = {
91-
throw new TreeNodeException(
92-
this, s"No function to evaluate expression. type: ${this.nodeName}")
93-
}
9489
}
9590

9691
abstract class AggregateFunction2
@@ -136,6 +131,9 @@ abstract class AggregateFunction2
136131
* and `buffer2`.
137132
*/
138133
def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
134+
135+
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
136+
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
139137
}
140138

141139
/**

0 commit comments

Comments
 (0)