Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val JOIN = Keyword("JOIN")
protected val LEFT = Keyword("LEFT")
protected val LIMIT = Keyword("LIMIT")
protected val MAX = Keyword("MAX")
protected val MIN = Keyword("MIN")
protected val NOT = Keyword("NOT")
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
Expand Down Expand Up @@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
} |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,67 @@ abstract class AggregateFunction
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}

/**
 * MIN aggregate expression over `child`.
 *
 * References, nullability and result type are all delegated to the child expression.
 */
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  override def toString = "MIN(" + child + ")"
  override def dataType = child.dataType
  override def nullable = child.nullable
  override def references = child.references

  // Builds the mutable per-group function that evaluates this aggregate.
  override def newInstance() = new MinFunction(child, this)

  /**
   * Splits the aggregate in two: a partial MIN of `child`, aliased "PartialMin", and a final
   * MIN computed over the attribute produced by that partial aggregate.
   */
  override def asPartial: SplitEvaluation = {
    val partial = Alias(Min(child), "PartialMin")()
    val partialExpressions = List(partial)
    SplitEvaluation(Min(partial.toAttribute), partialExpressions)
  }
}

case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to this PR, but I just realized that the way we are storing the aggregation buffer in Spark SQL uses much more memory than needed, because there are two extra pointers to expr/base, which are identical for every tuple.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, though this is not an issue in the code gen version.
On May 7, 2014 2:28 PM, "Reynold Xin" [email protected] wrote:

In
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala:

@@ -86,6 +86,67 @@ abstract class AggregateFunction
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}

+case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {

  • override def references = child.references
  • override def nullable = child.nullable
  • override def dataType = child.dataType
  • override def toString = s"MIN($child)"
  • override def asPartial: SplitEvaluation = {
  • val partialMin = Alias(Min(child), "PartialMin")()
  • SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
  • }
  • override def newInstance() = new MinFunction(child, this)
    +}

+case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {

this is unrelated to this pr - but I just realized the way we are storing
the aggregation buffer in Spark SQL uses much more memory than needed,
because there are two extra pointers to expr/base, which is identical for
every tuple.


Reply to this email directly or view it on GitHub: https://github.com//pull/683/files#r12404003
.

// No-argument constructor; both fields are left null.
def this() = this(null, null) // Required for serialization.

// Running result of the aggregation. Stays null until `update` sees a row for which `expr`
// evaluates to a non-null value.
var currentMin: Any = _

/**
 * Folds one input row into the running minimum.
 *
 * `expr` is evaluated exactly once per row and the resulting value is reused both for the
 * comparison and for the assignment. Evaluating it a second time is wasted work and could
 * store a value different from the one that was compared if `expr` is not deterministic.
 *
 * @param input the row to evaluate `expr` against
 */
override def update(input: Row): Unit = {
  val candidate = expr.eval(input)
  if (currentMin == null) {
    // Nothing stored yet: take the value as-is (it may itself be null, leaving the buffer empty).
    currentMin = candidate
  } else {
    // Same comparison as before, but against the already-computed value wrapped in a Literal
    // of the same data type, so `expr` is not evaluated again.
    val isSmaller =
      GreaterThan(Literal(currentMin, expr.dataType), Literal(candidate, expr.dataType))
        .eval(input) == true
    if (isSmaller) {
      currentMin = candidate
    }
  }
}

// Returns the value accumulated by `update`; the `input` row is ignored.
override def eval(input: Row): Any = currentMin
}

/**
 * MAX aggregate expression over `child`.
 *
 * References, nullability and result type are all delegated to the child expression.
 */
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  override def toString = "MAX(" + child + ")"
  override def dataType = child.dataType
  override def nullable = child.nullable
  override def references = child.references

  // Builds the mutable per-group function that evaluates this aggregate.
  override def newInstance() = new MaxFunction(child, this)

  /**
   * Splits the aggregate in two: a partial MAX of `child`, aliased "PartialMax", and a final
   * MAX computed over the attribute produced by that partial aggregate.
   */
  override def asPartial: SplitEvaluation = {
    val partial = Alias(Max(child), "PartialMax")()
    val partialExpressions = List(partial)
    SplitEvaluation(Max(partial.toAttribute), partialExpressions)
  }
}

/**
 * Mutable aggregation function backing [[Max]].
 *
 * Keeps the largest value of `expr` seen so far in `currentMax`.
 *
 * @param expr the expression whose maximum is being computed
 * @param base the aggregate expression this function was created for
 */
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  // Running result of the aggregation. Stays null until `update` sees a row for which `expr`
  // evaluates to a non-null value.
  var currentMax: Any = _

  /**
   * Folds one input row into the running maximum.
   *
   * `expr` is evaluated exactly once per row and the resulting value is reused both for the
   * comparison and for the assignment. Evaluating it a second time is wasted work and could
   * store a value different from the one that was compared if `expr` is not deterministic.
   *
   * @param input the row to evaluate `expr` against
   */
  override def update(input: Row): Unit = {
    val candidate = expr.eval(input)
    if (currentMax == null) {
      // Nothing stored yet: take the value as-is (it may itself be null, leaving the buffer empty).
      currentMax = candidate
    } else {
      // Same comparison as before, but against the already-computed value wrapped in a Literal
      // of the same data type, so `expr` is not evaluated again.
      val isLarger =
        LessThan(Literal(currentMax, expr.dataType), Literal(candidate, expr.dataType))
          .eval(input) == true
      if (isLarger) {
        currentMax = candidate
      }
    }
  }

  // Returns the value accumulated by `update`; the `input` row is ignored.
  override def eval(input: Row): Any = currentMax
}


case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
Expand All @@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
}

override def newInstance()= new CountFunction(child, this)
override def newInstance() = new CountFunction(child, this)
}

case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
Expand All @@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
override def nullable = false
override def dataType = IntegerType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
override def newInstance()= new CountDistinctFunction(expressions, this)
override def newInstance() = new CountDistinctFunction(expressions, this)
}

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
Expand All @@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
}

override def newInstance()= new AverageFunction(child, this)
override def newInstance() = new AverageFunction(child, this)
}

case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
Expand All @@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
partialSum :: Nil)
}

override def newInstance()= new SumFunction(child, this)
override def newInstance() = new SumFunction(child, this)
}

case class SumDistinct(child: Expression)
Expand All @@ -153,7 +214,7 @@ case class SumDistinct(child: Expression)
override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"

override def newInstance()= new SumDistinctFunction(child, this)
override def newInstance() = new SumDistinctFunction(child, this)
}

case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
Expand All @@ -168,19 +229,21 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
override def newInstance()= new FirstFunction(child, this)
override def newInstance() = new FirstFunction(child, this)
}

case class AverageFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {

def this() = this(null, null) // Required for serialization.

private val zero = Cast(Literal(0), expr.dataType)

private var count: Long = _
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
private val sum = MutableLiteral(zero.eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)

private val addFunction = Add(sum, expr)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))

override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
Expand Down Expand Up @@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.

private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
private val zero = Cast(Literal(0), expr.dataType)

private val sum = MutableLiteral(zero.eval(null))

private val addFunction = Add(sum, expr)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))

override def update(input: Row): Unit = {
sum.update(addFunction, input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest {
Seq((1,3),(2,3),(3,3)))
}

// Runs MIN, MAX, AVG, SUM and COUNT over column `a` of the `nullInts` table in one query and
// checks the single result row.
// NOTE(review): the expected row presumes `nullInts` holds 1, 2, 3 plus one NULL and that
// every aggregate skips the NULL — confirm against the TestData fixture.
test("aggregates with nulls") {
checkAnswer(
sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
(1, 3, 2, 6, 3) :: Nil
)
}

test("select *") {
checkAnswer(
sql("SELECT * FROM testData"),
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,14 @@ object TestData {
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
nullableRepeatedData.registerAsTable("nullableRepeatedData")

// Single-column row type; `a` is a boxed Integer so that it can hold null.
case class NullInts(a: Integer)
// Four rows: 1, 2, 3 and one null value.
val nullInts =
TestSQLContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
)
// Makes the data available to SQL queries under the table name "nullInts".
nullInts.registerAsTable("nullInts")
}