Skip to content

Commit dded1c5

Browse files
committed
wip
1 parent 39e4e7e commit dded1c5

File tree

11 files changed

+631
-4
lines changed

11 files changed

+631
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
21+
import org.apache.spark.sql.catalyst.expressions.aggregate2.{AggregateExpression2, AggregateFunction2}
2122
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.plans.logical._
@@ -482,7 +483,11 @@ class Analyzer(
482483
q transformExpressions {
483484
case u @ UnresolvedFunction(name, children) =>
484485
withPosition(u) {
485-
registry.lookupFunction(name, children)
486+
registry.lookupFunction(name, children) match {
487+
case agg2: AggregateFunction2 =>
488+
AggregateExpression2(agg2, aggregate2.Complete, false)
489+
case other => other
490+
}
486491
}
487492
}
488493
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateExpression2
2223
import org.apache.spark.sql.catalyst.plans.logical._
2324
import org.apache.spark.sql.types._
2425

@@ -85,6 +86,7 @@ trait CheckAnalysis {
8586
case Aggregate(groupingExprs, aggregateExprs, child) =>
8687
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
8788
case _: AggregateExpression => // OK
89+
case _: AggregateExpression2 => // OK
8890
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
8991
failAnalysis(
9092
s"expression '${e.prettyString}' is neither present in the group by, " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ object FunctionRegistry {
148148

149149
// aggregate functions
150150
expression[Average]("avg"),
151+
expression[aggregate2.Average]("avg2"),
151152
expression[Count]("count"),
152153
expression[First]("first"),
153154
expression[Last]("last"),
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate2
19+
20+
import org.apache.spark.sql.catalyst.errors.TreeNodeException
21+
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.InternalRow
23+
import org.apache.spark.sql.catalyst.trees.{LeafNode, UnaryNode}
24+
import org.apache.spark.sql.types._
25+
26+
private[sql] sealed trait AggregateMode
27+
28+
private[sql] case object Partial extends AggregateMode
29+
30+
private[sql] case object PartialMerge extends AggregateMode
31+
32+
private[sql] case object Final extends AggregateMode
33+
34+
private[sql] case object Complete extends AggregateMode
35+
36+
/**
37+
* A container of a Aggregate Function, Aggregate Mode, and a field (`isDistinct`) indicating
38+
* if DISTINCT keyword is specified for this function.
39+
* @param aggregateFunction
40+
* @param mode
41+
* @param isDistinct
42+
*/
43+
private[sql] case class AggregateExpression2(
44+
aggregateFunction: AggregateFunction2,
45+
mode: AggregateMode,
46+
isDistinct: Boolean) extends Expression {
47+
48+
override def children: Seq[Expression] = aggregateFunction :: Nil
49+
50+
override def dataType: DataType = aggregateFunction.dataType
51+
override def foldable: Boolean = aggregateFunction.foldable
52+
override def nullable: Boolean = aggregateFunction.nullable
53+
54+
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
55+
56+
override def eval(input: InternalRow = null): Any =
57+
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
58+
}
59+
60+
abstract class AggregateFunction2
61+
extends Expression {
62+
63+
self: Product =>
64+
65+
var bufferOffset: Int = 0
66+
67+
def withBufferOffset(newBufferOffset: Int): AggregateFunction2 = {
68+
bufferOffset = newBufferOffset
69+
this
70+
}
71+
72+
def bufferValueDataTypes: StructType
73+
74+
def initialBufferValues: Array[Any]
75+
76+
def initialize(buffer: MutableRow): Unit
77+
78+
def updateBuffer(buffer: MutableRow, bufferValues: Array[Any]): Unit = {
79+
var i = 0
80+
println("bufferOffset in average2 " + bufferOffset)
81+
while (i < bufferValues.length) {
82+
buffer.update(bufferOffset + i, bufferValues(i))
83+
i += 1
84+
}
85+
}
86+
87+
def update(buffer: MutableRow, input: InternalRow): Unit
88+
89+
def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
90+
91+
override def eval(buffer: InternalRow = null): Any
92+
}
93+
94+
case class Average(child: Expression)
95+
extends AggregateFunction2 with UnaryNode[Expression] {
96+
97+
override def nullable: Boolean = child.nullable
98+
99+
override def bufferValueDataTypes: StructType = child match {
100+
case e @ DecimalType() =>
101+
StructType(
102+
StructField("Sum", DecimalType.Unlimited) ::
103+
StructField("Count", LongType) :: Nil)
104+
case _ =>
105+
StructType(
106+
StructField("Sum", DoubleType) ::
107+
StructField("Count", LongType) :: Nil)
108+
}
109+
110+
override def dataType: DataType = child.dataType match {
111+
case DecimalType.Fixed(precision, scale) =>
112+
DecimalType(precision + 4, scale + 4)
113+
case DecimalType.Unlimited => DecimalType.Unlimited
114+
case _ => DoubleType
115+
}
116+
117+
override def initialBufferValues: Array[Any] = {
118+
Array(
119+
Cast(Literal(0), bufferValueDataTypes("Sum").dataType).eval(null), // Sum
120+
0L) // Count
121+
}
122+
123+
override def initialize(buffer: MutableRow): Unit =
124+
updateBuffer(buffer, initialBufferValues)
125+
126+
private val inputLiteral =
127+
MutableLiteral(null, child.dataType)
128+
private val bufferedSum =
129+
MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
130+
private val bufferedCount = MutableLiteral(null, LongType)
131+
private val updateSum =
132+
Add(Cast(inputLiteral, bufferValueDataTypes("Sum").dataType), bufferedSum)
133+
private val inputBufferedSum =
134+
MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
135+
private val mergeSum = Add(inputBufferedSum, bufferedSum)
136+
private val evaluateAvg =
137+
Cast(Divide(bufferedSum, Cast(bufferedCount, bufferValueDataTypes("Sum").dataType)), dataType)
138+
139+
override def update(buffer: MutableRow, input: InternalRow): Unit = {
140+
val newInput = child.eval(input)
141+
println("newInput " + newInput)
142+
if (newInput != null) {
143+
inputLiteral.value = newInput
144+
bufferedSum.value = buffer(bufferOffset)
145+
val newSum = updateSum.eval(null)
146+
val newCount = buffer.getLong(bufferOffset + 1) + 1
147+
buffer.update(bufferOffset, newSum)
148+
buffer.update(bufferOffset + 1, newCount)
149+
}
150+
}
151+
152+
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
153+
if (buffer2(bufferOffset + 1) != 0L) {
154+
inputBufferedSum.value = buffer2(bufferOffset)
155+
bufferedSum.value = buffer1(bufferOffset)
156+
val newSum = mergeSum.eval(null)
157+
val newCount =
158+
buffer1.getLong(bufferOffset + 1) + buffer2.getLong(bufferOffset + 1)
159+
buffer1.update(bufferOffset, newSum)
160+
buffer1.update(bufferOffset + 1, newCount)
161+
}
162+
}
163+
164+
override def eval(buffer: InternalRow): Any = {
165+
if (buffer(bufferOffset + 1) == 0L) {
166+
null
167+
} else {
168+
bufferedSum.value = buffer(bufferOffset)
169+
bufferedCount.value = buffer.getLong(bufferOffset + 1)
170+
evaluateAvg.eval(null)
171+
}
172+
}
173+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate2
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, BoundReference, Literal, ExpressionEvalHelper}
22+
import org.apache.spark.sql.types._
23+
24+
class AggregateExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
25+
26+
test("Average") {
27+
val inputValues = Array(Int.MaxValue, null, 1000, Int.MinValue, 2)
28+
val avg = Average(child = BoundReference(0, IntegerType, true)).withBufferOffset(2)
29+
val inputRow = new GenericMutableRow(1)
30+
val buffer = new GenericMutableRow(4)
31+
avg.initialize(buffer)
32+
33+
// We there is no input data, average should return null.
34+
assert(avg.eval(buffer) === null)
35+
// When input values are all nulls, average should return null.
36+
var i = 0
37+
while (i < 10) {
38+
inputRow.update(0, null)
39+
avg.update(inputRow, buffer)
40+
i += 1
41+
}
42+
assert(avg.eval(buffer) === null)
43+
44+
// Add some values.
45+
i = 0
46+
while (i < inputValues.length) {
47+
inputRow.update(0, inputValues(i))
48+
avg.update(buffer, inputRow)
49+
i += 1
50+
}
51+
assert(avg.eval(buffer) === 1001 / 4.0)
52+
53+
// eval should not reset the buffer
54+
assert(buffer(2) === 1001L)
55+
assert(buffer(3) === 4L)
56+
assert(avg.eval(buffer) === 1001 / 4.0)
57+
58+
// Merge with a just initialized buffer.
59+
val inputBuffer = new GenericMutableRow(4)
60+
avg.initialize(inputBuffer)
61+
avg.merge(buffer, inputBuffer)
62+
assert(buffer(2) === 1001L)
63+
assert(buffer(3) === 4L)
64+
assert(avg.eval(buffer) === 1001 / 4.0)
65+
66+
// Merge with a buffer containing partial results.
67+
inputBuffer.update(2, 2000.0)
68+
inputBuffer.update(3, 10L)
69+
avg.merge(buffer, inputBuffer)
70+
assert(buffer(2) === 3001L)
71+
assert(buffer(3) === 14L)
72+
assert(avg.eval(buffer) === 3001 / 14.0)
73+
}
74+
}

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ private[spark] object SQLConf {
380380
val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2",
381381
defaultValue = Some(true), doc = "<TODO>")
382382

383+
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
384+
defaultValue = Some(false), doc = "<TODO>")
385+
383386
val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI",
384387
defaultValue = Some(true), doc = "<TODO>")
385388

@@ -479,6 +482,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
479482

480483
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
481484

485+
private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
486+
482487
/**
483488
* Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
484489
*/

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
864864
DDLStrategy ::
865865
TakeOrderedAndProject ::
866866
HashAggregation ::
867+
AggregateOperator2 ::
867868
LeftSemiJoin ::
868869
HashJoin ::
869870
InMemoryScans ::

0 commit comments

Comments
 (0)