Skip to content

Commit 64fe30b

Browse files
committed
Improve SparkSQL Aggregates
* Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum.
1 parent 3308722 commit 64fe30b

File tree

4 files changed

+96
-10
lines changed

4 files changed

+96
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
114114
protected val JOIN = Keyword("JOIN")
115115
protected val LEFT = Keyword("LEFT")
116116
protected val LIMIT = Keyword("LIMIT")
117+
protected val MAX = Keyword("MAX")
118+
protected val MIN = Keyword("MIN")
117119
protected val NOT = Keyword("NOT")
118120
protected val NULL = Keyword("NULL")
119121
protected val ON = Keyword("ON")
@@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
318320
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
319321
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
320322
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
323+
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
324+
MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
321325
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
322326
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
323327
} |

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,67 @@ abstract class AggregateFunction
8686
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
8787
}
8888

89+
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
90+
override def references = child.references
91+
override def nullable = child.nullable
92+
override def dataType = child.dataType
93+
override def toString = s"MIN($child)"
94+
95+
override def asPartial: SplitEvaluation = {
96+
val partialMin = Alias(Min(child), "PartialMin")()
97+
SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
98+
}
99+
100+
override def newInstance() = new MinFunction(child, this)
101+
}
102+
103+
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
104+
def this() = this(null, null) // Required for serialization.
105+
106+
var currentMin: Any = _
107+
108+
override def update(input: Row): Unit = {
109+
if (currentMin == null) {
110+
currentMin = expr.eval(input)
111+
} else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) {
112+
currentMin = expr.eval(input)
113+
}
114+
}
115+
116+
override def eval(input: Row): Any = currentMin
117+
}
118+
119+
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
120+
override def references = child.references
121+
override def nullable = child.nullable
122+
override def dataType = child.dataType
123+
override def toString = s"MAX($child)"
124+
125+
override def asPartial: SplitEvaluation = {
126+
val partialMax = Alias(Max(child), "PartialMax")()
127+
SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
128+
}
129+
130+
override def newInstance() = new MaxFunction(child, this)
131+
}
132+
133+
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
134+
def this() = this(null, null) // Required for serialization.
135+
136+
var currentMax: Any = _
137+
138+
override def update(input: Row): Unit = {
139+
if (currentMax == null) {
140+
currentMax = expr.eval(input)
141+
} else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) {
142+
currentMax = expr.eval(input)
143+
}
144+
}
145+
146+
override def eval(input: Row): Any = currentMax
147+
}
148+
149+
89150
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
90151
override def references = child.references
91152
override def nullable = false
@@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
97158
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
98159
}
99160

100-
override def newInstance()= new CountFunction(child, this)
161+
override def newInstance() = new CountFunction(child, this)
101162
}
102163

103164
case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
@@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
106167
override def nullable = false
107168
override def dataType = IntegerType
108169
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
109-
override def newInstance()= new CountDistinctFunction(expressions, this)
170+
override def newInstance() = new CountDistinctFunction(expressions, this)
110171
}
111172

112173
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
126187
partialCount :: partialSum :: Nil)
127188
}
128189

129-
override def newInstance()= new AverageFunction(child, this)
190+
override def newInstance() = new AverageFunction(child, this)
130191
}
131192

132193
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
142203
partialSum :: Nil)
143204
}
144205

145-
override def newInstance()= new SumFunction(child, this)
206+
override def newInstance() = new SumFunction(child, this)
146207
}
147208

148209
case class SumDistinct(child: Expression)
@@ -153,7 +214,7 @@ case class SumDistinct(child: Expression)
153214
override def dataType = child.dataType
154215
override def toString = s"SUM(DISTINCT $child)"
155216

156-
override def newInstance()= new SumDistinctFunction(child, this)
217+
override def newInstance() = new SumDistinctFunction(child, this)
157218
}
158219

159220
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -168,19 +229,21 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
168229
First(partialFirst.toAttribute),
169230
partialFirst :: Nil)
170231
}
171-
override def newInstance()= new FirstFunction(child, this)
232+
override def newInstance() = new FirstFunction(child, this)
172233
}
173234

174235
case class AverageFunction(expr: Expression, base: AggregateExpression)
175236
extends AggregateFunction {
176237

177238
def this() = this(null, null) // Required for serialization.
178239

240+
private val zero = Cast(Literal(0), expr.dataType)
241+
179242
private var count: Long = _
180-
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
243+
private val sum = MutableLiteral(zero.eval(EmptyRow))
181244
private val sumAsDouble = Cast(sum, DoubleType)
182245

183-
private val addFunction = Add(sum, expr)
246+
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
184247

185248
override def eval(input: Row): Any =
186249
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
@@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
209272
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
210273
def this() = this(null, null) // Required for serialization.
211274

212-
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
275+
private val zero = Cast(Literal(0), expr.dataType)
276+
277+
private val sum = MutableLiteral(zero.eval(null))
213278

214-
private val addFunction = Add(sum, expr)
279+
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
215280

216281
override def update(input: Row): Unit = {
217282
sum.update(addFunction, input)

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest {
5050
Seq((1,3),(2,3),(3,3)))
5151
}
5252

53+
test("aggregates with nulls") {
54+
checkAnswer(
55+
sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
56+
(1, 3, 2, 6, 3) :: Nil
57+
)
58+
}
59+
5360
test("select *") {
5461
checkAnswer(
5562
sql("SELECT * FROM testData"),

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,14 @@ object TestData {
8484
List.fill(2)(StringData(null)) ++
8585
List.fill(2)(StringData("test")))
8686
nullableRepeatedData.registerAsTable("nullableRepeatedData")
87+
88+
case class NullInts(a: Integer)
89+
val nullInts =
90+
TestSQLContext.sparkContext.parallelize(
91+
NullInts(1) ::
92+
NullInts(2) ::
93+
NullInts(3) ::
94+
NullInts(null) :: Nil
95+
)
96+
nullInts.registerAsTable("nullInts")
8797
}

0 commit comments

Comments
 (0)