From 440b689ce11825964404334916dac7424c19d6dc Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Sat, 11 Apr 2015 14:08:23 +0800
Subject: [PATCH 01/20] migrate to support both versions of UDAF

---
 .../expressions/aggregate2/aggregates.scala   | 526 ++++++++++++++++++
 .../sql/catalyst/planning/patterns.scala      | 100 +++-
 .../scala/org/apache/spark/sql/SQLConf.scala  |   6 +
 .../org/apache/spark/sql/SQLContext.scala     |   3 +
 .../spark/sql/execution/SparkStrategies.scala |  32 ++
 .../sql/execution/aggregate2/Aggregate.scala  | 480 ++++++++++++++++
 .../apache/spark/sql/hive/HiveContext.scala   |   3 +
 .../org/apache/spark/sql/hive/hiveUdfs.scala  | 118 ++++
 .../sql/hive/execution/AggregateSuite.scala   | 171 ++++++
 9 files changed, 1438 insertions(+), 1 deletion(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
new file mode 100644
index 000000000000..3c846c67e1ea
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
@@ -0,0 +1,526 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate2
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+
+/**
+ * This is from org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode.
+ * It is just a hint to UDAF developers about which stage is about to be
+ * processed. However, we probably don't want developers to depend on so many
+ * details; it is kept only to stay consistent with Hive (when integrating
+ * with Hive), and we need to figure out a workaround for it soon.
+ */
+@deprecated
+trait Mode

+/**
+ * PARTIAL1: from original data to partial aggregation data: iterate() and
+ * terminatePartial() will be called.
+ */
+@deprecated
+case object PARTIAL1 extends Mode

+/**
+ * PARTIAL2: from partial aggregation data to partial aggregation data:
+ * merge() and terminatePartial() will be called.
+ */
+@deprecated
+case object PARTIAL2 extends Mode
+/**
+ * FINAL: from partial aggregation to full aggregation: merge() and
+ * terminate() will be called.
+ */
+@deprecated
+case object FINAL extends Mode
+/**
+ * COMPLETE: from original data directly to full aggregation: iterate() and
+ * terminate() will be called.
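+ *
+ * As a rough sketch of how the physical operators in this patch drive these
+ * modes (an illustration of the calling protocol, not a normative contract):
+ * {{{
+ *   // map side, mode = PARTIAL1:
+ *   //   reset(buf); iterate(value, buf) per input row; terminatePartial(buf)
+ *   // reduce side, mode = FINAL:
+ *   //   reset(buf); merge(partialRow, buf) per shuffled row; terminate(buf)
+ * }}}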
+ */ +@deprecated +case object COMPLETE extends Mode + + +/** + * Aggregation Function Interface + * All of the function will be called within Spark executors. + */ +trait AggregateFunction2 { + self: Product => + + // Specify the BoundReference for Aggregate Buffer + def initialBoundReference(buffers: Seq[BoundReference]): Unit + + // Initialize (reinitialize) the aggregation buffer + def reset(buf: MutableRow): Unit + + // Expect the aggregate function fills the aggregation buffer when + // fed with each value in the group + def iterate(arguments: Any, buf: MutableRow): Unit + + // Merge 2 aggregation buffer, and write back to the later one + def merge(value: Row, buf: MutableRow): Unit + + // Semantically we probably don't need this, however, we need it when + // integrating with Hive UDAF(GenericUDAF) + @deprecated + def terminatePartial(buf: MutableRow): Unit = {} + + // Output the final result by feeding the aggregation buffer + def terminate(input: Row): Any +} + +trait AggregateExpression2 extends Expression with AggregateFunction2 { + self: Product => + type EvaluatedType = Any + + var mode: Mode = COMPLETE + + def initial(m: Mode): Unit = { + this.mode = m + } + + // Aggregation Buffer data types + def bufferDataType: Seq[DataType] = Nil + // Is it a distinct aggregate expression? + def distinct: Boolean + // Is it a distinct like aggregate expression (e.g. Min/Max is distinctLike, while avg is not) + def distinctLike: Boolean = false + + def nullable = true + + override def eval(input: Row): EvaluatedType = children.map(_.eval(input)) +} + +abstract class UnaryAggregateExpression extends UnaryExpression with AggregateExpression2 { + self: Product => + + override def eval(input: Row): EvaluatedType = child.eval(input) +} + +case class Min( + child: Expression) + extends UnaryAggregateExpression { + + override def distinct: Boolean = false + override def distinctLike: Boolean = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"MIN($child)" + + /* The below code will be called in executors, be sure to make the instance transientable */ + @transient var arg: MutableLiteral = _ + @transient var buffer: MutableLiteral = _ + @transient var cmp: LessThan = _ + @transient var aggr: BoundReference = _ + + /* Initialization on executors */ + override def initialBoundReference(buffers: Seq[BoundReference]): Unit = { + aggr = buffers(0) + arg = MutableLiteral(null, dataType) + buffer = MutableLiteral(null, dataType) + cmp = LessThan(arg, buffer) + } + + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + arg.value = argument + buffer.value = buf(aggr) + if (buf.isNullAt(aggr) || cmp.eval(null) == true) { + buf(aggr) = argument + } + } + } + + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + arg.value = value(aggr) + buffer.value = rowBuf(aggr) + if (rowBuf.isNullAt(aggr) || cmp.eval(null) == true) { + rowBuf(aggr) = arg.value + } + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} + +case class Average(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = false + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive + case DecimalType.Unlimited => + 
DecimalType.Unlimited + case _ => + DoubleType + } + + override def bufferDataType: Seq[DataType] = LongType :: dataType :: Nil + override def toString = s"AVG($child)" + + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var count: BoundReference = _ + @transient var sum: BoundReference = _ + + // for iterate + @transient var arg: MutableLiteral = _ + @transient var cast: Expression = _ + @transient var add: Add = _ + + // for merge + @transient var argInMerge: MutableLiteral = _ + @transient var addInMerge: Add = _ + + // for terminate + @transient var divide: Divide = _ + + /* Initialization on executors */ + override def initialBoundReference(buffers: Seq[BoundReference]): Unit = { + count = buffers(0) + sum = buffers(1) + + arg = MutableLiteral(null, child.dataType) + cast = if (arg.dataType != dataType) Cast(arg, dataType) else arg + add = Add(cast, sum) + + argInMerge = MutableLiteral(null, dataType) + addInMerge = Add(argInMerge, sum) + + divide = Divide(sum, Cast(count, dataType)) + } + + override def reset(buf: MutableRow): Unit = { + buf(count) = 0L + buf(sum) = null + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + arg.value = argument + buf(count) = buf.getLong(count) + 1 + if (buf.isNullAt(sum)) { + buf(sum) = cast.eval() + } else { + buf(sum) = add.eval(buf) + } + } + } + + override def merge(value: Row, buf: MutableRow): Unit = { + if (!value.isNullAt(sum)) { + buf(count) = value.getLong(count) + buf.getLong(count) + if (buf.isNullAt(sum)) { + buf(sum) = value(sum) + } else { + argInMerge.value = value(sum) + buf(sum) = addInMerge.eval(buf) + } + } + } + + override def terminate(row: Row): Any = if (count.eval(row) == 0) null else divide.eval(row) +} + +case class Max(child: Expression) + extends UnaryAggregateExpression { + override def distinct: Boolean = false + override def distinctLike: Boolean = true + + override def nullable = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"MAX($child)" + + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + @transient var arg: MutableLiteral = _ + @transient var buffer: MutableLiteral = _ + @transient var cmp: GreaterThan = _ + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + arg = MutableLiteral(null, dataType) + buffer = MutableLiteral(null, dataType) + cmp = GreaterThan(arg, buffer) + } + + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + arg.value = argument + buffer.value = buf(aggr) + if (buf.isNullAt(aggr) || cmp.eval(null) == true) { + buf(aggr) = argument + } + } + } + + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + arg.value = value(aggr) + buffer.value = rowBuf(aggr) + if (rowBuf.isNullAt(aggr) || cmp.eval(null) == true) { + rowBuf(aggr) = arg.value + } + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} + +case class Count(child: Expression) + extends UnaryAggregateExpression { + def distinct: Boolean = false + override def nullable = false + override def dataType = LongType + override def bufferDataType: Seq[DataType] = LongType :: Nil + override def toString = s"COUNT($child)" + + /* The below code will be called in 
executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + } + + override def reset(buf: MutableRow): Unit = { + buf(aggr) = 0L + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + if (buf.isNullAt(aggr)) { + buf(aggr) = 1L + } else { + buf(aggr) = buf.getLong(aggr) + 1L + } + } + } + + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (value.isNullAt(aggr)) { + // do nothing + } else if (rowBuf.isNullAt(aggr)) { + rowBuf(aggr) = value(aggr) + } else { + rowBuf(aggr) = value.getLong(aggr) + rowBuf.getLong(aggr) + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} + +case class CountDistinct(children: Seq[Expression]) + extends AggregateExpression2 { + def distinct: Boolean = true + override def nullable = false + override def dataType = LongType + override def toString = s"COUNT($children)" + override def bufferDataType: Seq[DataType] = LongType :: Nil + + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + } + + override def reset(buf: MutableRow): Unit = { + buf(aggr) = 0L + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (!argument.asInstanceOf[Seq[_]].exists(_ == null)) { + // CountDistinct supports multiple expression, and ONLY IF + // none of its expressions value equals null + if (buf.isNullAt(aggr)) { + buf(aggr) = 1L + } else { + buf(aggr) = buf.getLong(aggr) + 1L + } + } + } + + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (value.isNullAt(aggr)) { + // do nothing + } else if (rowBuf.isNullAt(aggr)) { + rowBuf(aggr) = value(aggr) + } else { + rowBuf(aggr) = value.getLong(aggr) + rowBuf.getLong(aggr) + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} + + /** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class Sum(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = true + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"SUM($child)" + + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + @transient var arg: MutableLiteral = _ + @transient var sum: Add = _ + + lazy val DEFAULT_VALUE = Cast(Literal.create(0, IntegerType), dataType).eval() + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + arg = MutableLiteral(null, dataType) + sum = Add(arg, aggr) + } + + override def 
reset(buf: MutableRow): Unit = { + buf(aggr) = null + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + if (buf.isNullAt(aggr)) { + buf(aggr) = argument + } else { + arg.value = argument + buf(aggr) = sum.eval(buf) + } + } else { + if (buf.isNullAt(aggr)) { + buf(aggr) = DEFAULT_VALUE + } + } + } + + override def merge(value: Row, buf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + arg.value = value(aggr) + if (buf.isNullAt(aggr)) { + buf(aggr) = arg.value + } else { + buf(aggr) = sum.eval(buf) + } + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} + +case class First(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"FIRST($child)" + + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + } + + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (buf.isNullAt(aggr)) { + if (argument != null) { + buf(aggr) = argument + } + } + } + + override def merge(value: Row, buf: MutableRow): Unit = { + if (buf.isNullAt(aggr)) { + if (!value.isNullAt(aggr)) { + buf(aggr) = value(aggr) + } + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} + +case class Last(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"LAST($child)" + + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + } + + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null + } + + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + buf(aggr) = argument + } + } + + override def merge(value: Row, buf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + buf(aggr) = value(aggr) + } + } + + override def terminate(row: Row): Any = aggr.eval(row) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1dd75a884630..9ef8fea8e892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateExpression2 + import scala.annotation.tailrec import org.apache.spark.Logging @@ -106,6 +108,99 @@ object PhysicalOperation extends PredicateHelper { } } +/** + * + * TODO: This is a temporal solution to substitute the expression tree from + * AggregateExpression with aggregate2.AggregateExpression2, and will be + * removed once the aggregate2.AggregateExpression2 is stable enough. 
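+ *
+ * A sketch of the mapping it performs (`col` stands for any resolved child
+ * expression; see the rules in subsitute() below):
+ * {{{
+ *   subsitute(Average(col))      // => aggregate2.Average(col)
+ *   subsitute(SumDistinct(col))  // => aggregate2.Sum(col, true)
+ * }}}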
+ */ +class AggregateExpressionSubsitution { + def subsitute(aggr: AggregateExpression): AggregateExpression2 = aggr match { + case Min(child) => aggregate2.Min(child) + case Max(child) => aggregate2.Max(child) + case Count(child) => aggregate2.Count(child) + case CountDistinct(children) => aggregate2.CountDistinct(children) + // TODO: we don't support approximate in aggregate2 yet. + case ApproxCountDistinct(child, sd) => aggregate2.CountDistinct(child :: Nil) + case Average(child) => aggregate2.Average(child) + case Sum(child) => aggregate2.Sum(child) + case SumDistinct(child) => aggregate2.Sum(child, true) + case First(child) => aggregate2.First(child) + case Last(child) => aggregate2.Last(child) + } +} + +// TODO: Will be removed once aggregate2.AggregateExpression2 is stable enough +object AggregateExpressionSubsitution extends AggregateExpressionSubsitution + +/** + * Matches a logical aggregation that can be performed on distributed data in two steps. The first + * operates on the data in each partition performing partial aggregation for each group. The second + * occurs after the shuffle and completes the aggregation. + * + * This pattern will only match if all aggregate expressions can be computed partially and will + * return the rewritten aggregation expressions for both phases. + * + * The returned values for this match are as follows: + * - Grouping attributes for the final aggregation. + * - Aggregates for the final aggregation. + * - Grouping expressions for the partial aggregation. + * - Partial aggregate expressions. + * - Input to the aggregation. + */ +object PartialAggregation2 { + type ReturnType = + (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + + def unapply(plan: LogicalPlan) + : Option[ReturnType] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions that can be computed partially. + val allAggregates = aggregateExpressions.flatMap(_ collect { + case a: aggregate2.AggregateExpression2 => a + }) + + // Only do partial aggregation if supported by all aggregate expressions. + if (!allAggregates.exists(_.distinct)) { + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + } + val substitutions = namedGroupingExpressions.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if substitutions.contains(e) => + substitutions(e).toAttribute + case e: Expression => + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) 
+ substitutions + .get(e.transform { case Alias(g: GetField, _) => g }) + .map(_.toAttribute) + .getOrElse(e) + }).asInstanceOf[Seq[NamedExpression]] + + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + Some( + (namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + aggregateExpressions, + child)) + } else { + None + } + case _ => None + } +} + /** * Matches a logical aggregation that can be performed on distributed data in two steps. The first * operates on the data in each partition performing partial aggregation for each group. The second @@ -126,7 +221,10 @@ object PartialAggregation { (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) + if (aggregateExpressions.flatMap(_.collect { + case a: aggregate2.AggregateExpression2 => a + })).length == 0 => // Collect all aggregate expressions. val allAggregates = aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 77c6af27d100..fc7515a421c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -90,6 +90,9 @@ private[spark] object SQLConf { val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI" + // Whether enable the aggregate2, which is the refactor one + val AGGREGATE_2 = "spark.sql.aggregate2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -264,6 +267,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean + private[spark] def aggregate2: Boolean = + getConf(AGGREGATE_2, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
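   *
   * For example, the aggregate2 code path added above can be toggled at
   * runtime (a usage sketch; `sqlContext` is any live SQLContext):
   * {{{
   *   sqlContext.setConf("spark.sql.aggregate2", "true")
   * }}}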
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ddb54025baa2..b2d3a866e98d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,6 +21,8 @@ import java.beans.Introspector import java.util.Properties import java.util.concurrent.atomic.AtomicReference +import org.apache.spark.sql.catalyst.planning.AggregateExpressionSubsitution + import scala.collection.JavaConversions._ import scala.collection.immutable import scala.language.implicitConversions @@ -825,6 +827,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def strategies: Seq[Strategy] = experimental.extraStrategies ++ ( + new HashAggregation2(AggregateExpressionSubsitution) :: DataSourceStrategy :: DDLStrategy :: TakeOrdered :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a1331a39151..ce44e3b627e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateExpression2 import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -117,6 +118,37 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + class HashAggregation2(aggrSubsitution: AggregateExpressionSubsitution) extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case aggr @ logical.Aggregate(groupingExpressions, aggregateExpressions, child) + if sqlContext.conf.aggregate2 => + + val subusitutedAggrExpression = aggregateExpressions.map(_.transformUp { + case a: AggregateExpression => aggrSubsitution.subsitute(a) + }.asInstanceOf[NamedExpression]) + + aggr.copy(aggregateExpressions = subusitutedAggrExpression) match { + // Aggregations that can be performed in two phases, before and after the shuffle. + case PartialAggregation2( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + aggregateExpressions, + child) => + aggregate2.AggregatePostShuffle( + namedGroupingAttributes, + rewrittenAggregateExpressions, + aggregate2.AggregatePreShuffle( + groupingExpressions, + aggregateExpressions, + namedGroupingAttributes, + planLater(child))) :: Nil + case _ => Nil + } + case _ => Nil + } + } + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Aggregations that can be performed in two phases, before and after the shuffle. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala new file mode 100644 index 000000000000..ed5ff633746a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate2 + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate2._ +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} + +import org.apache.spark.util.collection.{OpenHashSet, OpenHashMap} + +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.plans.physical._ + +/** + * An aggregate that needs to be computed for each row in a group. + * + * @param aggregate AggregateExpression2, associated with the function + * @param substitution A MutableLiteral used to refer to the result of this aggregate in the final + * output. + */ +sealed case class AggregateFunctionBind( + aggregate: AggregateExpression2, + substitution: MutableLiteral) + +sealed class InputBufferSeens( + var input: Row, // + var buffer: MutableRow, + var seens: Array[OpenHashSet[Any]] = null) { + def this() { + this(new GenericMutableRow(0), null) + } + + def withInput(row: Row): InputBufferSeens = { + this.input = row + this + } + + def withBuffer(row: MutableRow): InputBufferSeens = { + this.buffer = row + this + } + + def withSeens(seens: Array[OpenHashSet[Any]]): InputBufferSeens = { + this.seens = seens + this + } +} + +sealed trait Aggregate { + self: Product => + // HACK: Generators don't correctly preserve their output through serializations so we grab + // out child's output attributes statically here. + val childOutput = child.output + val isGlobalAggregation = groupingExpressions.isEmpty + + def computedAggregates: Array[AggregateExpression2] = { + boundProjection.flatMap { expr => + expr.collect { + case ae: AggregateExpression2 => ae + } + } + }.toArray + + // This is a hack, instead of relying on the BindReferences for the aggregation + // buffer schema in PostShuffle, we have a strong protocols which represented as the + // BoundReferences in PostShuffle for aggregation buffer. + @transient lazy val bufferSchema: Array[AttributeReference] = + computedAggregates.zipWithIndex.flatMap { case (ca, idx) => + ca.bufferDataType.zipWithIndex.map { case (dt, i) => + AttributeReference(s"aggr.${idx}_$i", dt)() } + }.toArray + + // The tuples of aggregate expressions with information + // (AggregateExpression2, Aggregate Function, Placeholder of AggregateExpression2 result) + @transient lazy val aggregateFunctionBinds: Array[AggregateFunctionBind] = { + var pos = 0 + computedAggregates.map { ae => + ae.initial(mode) + + // we connect all of the aggregation buffers in a single Row, + // and "BIND" the attribute references in a Hack way. 
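+      // e.g. (a hypothetical layout) for AVG(a) and MAX(b) the shared buffer
+      // row is [count: Long, sum: Double, max: Double]: AVG binds offsets 0
+      // and 1, MAX binds offset 2, with `pos` advancing past each function's
+      // slots.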
+ val bufferDataTypes = ae.bufferDataType + ae.initialBoundReference(for (i <- 0 until bufferDataTypes.length) yield { + BoundReference(pos + i, bufferDataTypes(i), true) + }) + pos += bufferDataTypes.length + + AggregateFunctionBind(ae, MutableLiteral(null, ae.dataType)) + } + } + + @transient lazy val groupByProjection = if (groupingExpressions.isEmpty) { + InterpretedMutableProjection(Nil) + } else { + new InterpretedMutableProjection(groupingExpressions, childOutput) + } + + // Indicate which stage we are running into + def mode: Mode + // This is provided by SparkPlan + def child: SparkPlan + // Group By Key Expressions + def groupingExpressions: Seq[Expression] + // Bounded Projection + def boundProjection: Seq[NamedExpression] +} + +sealed trait PreShuffle extends Aggregate { + self: Product => + + def boundProjection: Seq[NamedExpression] = projection.map { + case a: Attribute => // Attribute will be converted into BoundReference + Alias( + BindReferences.bindReference(a: Expression, childOutput), a.name)(a.exprId, a.qualifiers) + case a: NamedExpression => BindReferences.bindReference(a, childOutput) + } + + // The expression list for output, this is the unbound expressions + def projection: Seq[NamedExpression] +} + +sealed trait PostShuffle extends Aggregate { + self: Product => + /** + * Substituted version of boundProjection expressions which are used to compute final + * output rows given a group and the result of all aggregate computations. + */ + @transient lazy val finalExpressions = { + val resultMap = aggregateFunctionBinds.map { ae => ae.aggregate -> ae.substitution }.toMap + boundProjection.map { agg => + agg.transform { + case e: AggregateExpression2 if resultMap.contains(e) => resultMap(e) + } + } + }.map(e => {BindReferences.bindReference(e: Expression, childOutput)}) + + @transient lazy val finalProjection = new InterpretedMutableProjection(finalExpressions) + + def aggregateFunctionBinds: Array[AggregateFunctionBind] + + def createIterator( + aggregates: Array[AggregateExpression2], + iterator: Iterator[InputBufferSeens]) = { + val substitutions = aggregateFunctionBinds.map(_.substitution) + + new Iterator[Row] { + override final def hasNext: Boolean = iterator.hasNext + + override final def next(): Row = { + val keybuffer = iterator.next() + + var idx = 0 + while (idx < aggregates.length) { + // substitute the AggregateExpression2 value + substitutions(idx).value = aggregates(idx).terminate(keybuffer.buffer) + idx += 1 + } + + finalProjection(keybuffer.input) + } + } + } +} + +/** + * :: DeveloperApi :: + * Groups input data by `groupingExpressions` and computes the `projection` for each + * group. + * + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param projection expressions that are computed for each group. + * @param namedGroupingAttributes the attributes represent the output of the groupby expressions + * @param child the input data source. + */ +@DeveloperApi +case class AggregatePreShuffle( + groupingExpressions: Seq[Expression], + projection: Seq[NamedExpression], + namedGroupingAttributes: Seq[Attribute], + child: SparkPlan) + extends UnaryNode with PreShuffle { + + override def requiredChildDistribution = UnspecifiedDistribution :: Nil + + override def output = bufferSchema.map(_.toAttribute) ++ namedGroupingAttributes + + override def mode: Mode = PARTIAL1 // iterate & terminalPartial will be called + + /** + * Create Iterator for the in-memory hash map. 
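+   * Each emitted row is the (partially terminated) aggregation buffer joined
+   * with its grouping-key input row, matching this operator's
+   * bufferSchema ++ namedGroupingAttributes output order.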
+ */ + private[this] def createIterator( + functions: Array[AggregateExpression2], + iterator: Iterator[InputBufferSeens]) = { + new Iterator[Row] { + private[this] val joinedRow = new JoinedRow + + override final def hasNext: Boolean = iterator.hasNext + + override final def next(): Row = { + val keybuffer = iterator.next() + var idx = 0 + while (idx < functions.length) { + functions(idx).terminatePartial(keybuffer.buffer) + idx += 1 + } + + joinedRow(keybuffer.buffer, keybuffer.input).copy() + } + } + } + + override def execute() = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val aggregates = aggregateFunctionBinds.map(_.aggregate) + + if (groupingExpressions.isEmpty) { + // without group by keys + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + idx += 1 + } + + while (iter.hasNext) { + val currentRow = iter.next() + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.iterate(ae.eval(currentRow), buffer) + idx += 1 + } + } + + createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer))) + } else { + val results = new OpenHashMap[Row, InputBufferSeens]() + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + // TODO distinctLike? We need to store the "seen" for + // AggregationExpression that distinctLike=true + // This is a trade off between memory & computing + ae.reset(buffer) + ae.iterate(value, buffer) + idx += 1 + } + + val copies = keys.copy() + results(copies) = new InputBufferSeens(copies, buffer) + case inputbuffer => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.iterate(ae.eval(currentRow), inputbuffer.buffer) + idx += 1 + } + + } + } + + createIterator(aggregates, results.iterator.map(_._2)) + } + } + } +} + +case class AggregatePostShuffle( + groupingExpressions: Seq[Expression], + boundProjection: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PostShuffle { + + override def output = boundProjection.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + + override def mode: Mode = FINAL // merge & terminate will be called + + override def execute() = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val aggregates = aggregateFunctionBinds.map(_.aggregate) + if (groupingExpressions.isEmpty) { + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + idx += 1 + } + + while (iter.hasNext) { + val currentRow = iter.next() + + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, buffer) + idx += 1 + } + } + + createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer))) + } else { + val results = new OpenHashMap[Row, InputBufferSeens]() + while (iter.hasNext) { + val currentRow = iter.next() + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < 
aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + ae.merge(currentRow, buffer) + idx += 1 + } + results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer) + case pair => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, pair.buffer) + idx += 1 + } + } + } + + createIterator(aggregates, results.iterator.map(_._2)) + } + } + } +} + +// TODO Currently even if only a single DISTINCT exists in the aggregate expressions, we will +// not do partial aggregation (aggregating before shuffling), all of the data have to be shuffled +// to the reduce side and do aggregation directly, this probably causes the performance regression +// for Aggregation Function like CountDistinct etc. +case class DistinctAggregate( + groupingExpressions: Seq[Expression], + projection: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PreShuffle with PostShuffle { + override def output = boundProjection.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + + override def mode: Mode = COMPLETE // iterate() & terminate() will be called + + override def execute() = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val aggregates = aggregateFunctionBinds.map(_.aggregate) + if (groupingExpressions.isEmpty) { + val buffer = new GenericMutableRow(bufferSchema.length) + // TODO save the memory only for those DISTINCT aggregate expressions + val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length) + + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + ae.reset(buffer) + + if (ae.distinct) { + seens(idx) = new OpenHashSet[Any]() + } + + idx += 1 + } + val ibs = new InputBufferSeens().withBuffer(buffer).withSeens(seens) + + while (iter.hasNext) { + val currentRow = iter.next() + + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + + if (ae.distinct) { + if (value != null && !seens(idx).contains(value)) { + ae.iterate(value, buffer) + seens(idx).add(value) + } + } else { + ae.iterate(value, buffer) + } + idx += 1 + } + } + + createIterator(aggregates, Iterator(ibs)) + } else { + val results = new OpenHashMap[Row, InputBufferSeens]() + + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(bufferSchema.length) + // TODO save the memory only for those DISTINCT aggregate expressions + val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length) + + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + ae.reset(buffer) + ae.iterate(value, buffer) + + if (ae.distinct) { + val seen = new OpenHashSet[Any]() + if (value != null) { + seen.add(value) + } + seens.update(idx, seen) + } + + idx += 1 + } + results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer, seens) + + case inputBufferSeens => + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + + if (ae.distinct) { + if (value != null && !inputBufferSeens.seens(idx).contains(value)) { + ae.iterate(value, inputBufferSeens.buffer) + inputBufferSeens.seens(idx).add(value) + } + } else { + ae.iterate(value, 
inputBufferSeens.buffer) + } + idx += 1 + } + } + } + + createIterator(aggregates, results.iterator.map(_._2)) + } + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b8f294c262af..e06f89c655a2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -25,6 +25,8 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.planning.AggregateExpressionSubsitution + import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.language.implicitConversions @@ -445,6 +447,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveContext = self override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( + new HashAggregation2(HiveAggregateExpressionSubsitution), DataSourceStrategy, HiveCommandStrategy(self), HiveDDLStrategy, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 01f47352b231..2a950bd1dfd2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.aggregate2.{AggregateExpression2, COMPLETE, FINAL, PARTIAL1} +import org.apache.spark.sql.catalyst.planning.AggregateExpressionSubsitution import scala.collection.mutable.ArrayBuffer @@ -468,6 +470,112 @@ private[hive] case class HiveUdaf( def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) } +private[hive] case class HiveGenericUdaf2( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression], + distinct: Boolean, + isUDAF: Boolean) extends AggregateExpression2 with HiveInspectors { + type UDFType = AbstractGenericUDAFResolver + + protected def createEvaluator = resolver.getEvaluator( + new SimpleGenericUDAFParameterInfo(inspectors, false, false)) + + // Hive UDAF evaluator + @transient + lazy val evaluator = createEvaluator + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = if (isUDAF) { + // if it's UDAF, we need the UDAF bridge + new GenericUDAFBridge(funcWrapper.createFunction()) + } else { + funcWrapper.createFunction() + } + + // Output data object inspector + @transient + lazy val objectInspector = createEvaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + + // Aggregation Buffer Inspector + @transient + lazy val bufferObjectInspector = { + createEvaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) + } + + // Input arguments object inspectors + @transient + lazy val inspectors = children.map(toInspector).toArray + + @transient + override val distinctLike: Boolean = { + val annotation = evaluator.getClass().getAnnotation(classOf[HiveUDFType]) + if (annotation == null || !annotation.distinctLike()) false else true + } + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + + // Aggregation 
Buffer Data Type, We assume only 1 element for the Hive Aggregation Buffer + // It will be StructType if more than 1 element (Actually will be StructSettableObjectInspector) + override def bufferDataType: Seq[DataType] = inspectorToDataType(bufferObjectInspector) :: Nil + + // Output data type + override def dataType: DataType = inspectorToDataType(objectInspector) + + /////////////////////////////////////////////////////////////////////////////////////////////// + // The following code will be called within the executors // + /////////////////////////////////////////////////////////////////////////////////////////////// + @transient var bound: BoundReference = _ + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + bound = buffers(0) + mode match { + case FINAL => evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(bufferObjectInspector)) + case COMPLETE => evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + case PARTIAL1 => evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) + } + } + + // Initialize (reinitialize) the aggregation buffer + override def reset(buf: MutableRow): Unit = { + val buffer = evaluator.getNewAggregationBuffer + evaluator.reset(buffer) + // This is a hack, we never use the mutable row as buffer, but define our own buffer, + // which is set as the first element of the buffer + buf(bound.ordinal) = buffer + } + + // Expect the aggregate function fills the aggregation buffer when fed with each value + // in the group + override def iterate(arguments: Any, buf: MutableRow): Unit = { + val args = arguments.asInstanceOf[Seq[AnyRef]].zip(inspectors).map { + case (value, oi) => wrap(value, oi) + }.toArray + + evaluator.iterate( + buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal), + args) + } + + // Merge 2 aggregation buffer, and write back to the later one + override def merge(value: Row, buf: MutableRow): Unit = { + val buffer = buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal) + evaluator.merge(buffer, wrap(value.get(bound.ordinal), bufferObjectInspector)) + } + + @deprecated + override def terminatePartial(buf: MutableRow): Unit = { + val buffer = buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal) + // this is for serialization + buf(bound.ordinal) = unwrap(evaluator.terminatePartial(buffer), bufferObjectInspector) + } + + // Output the final result by feeding the aggregation buffer + override def terminate(input: Row): Any = { + unwrap(evaluator.terminate( + input.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)), + objectInspector) + } +} + /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[catalyst.expressions.Generator Generator]]. 
Note that the semantics of Generators do not allow @@ -587,3 +695,13 @@ private[hive] case class HiveUdafFunction( } } +private[hive] object HiveAggregateExpressionSubsitution extends AggregateExpressionSubsitution { + override def subsitute(aggr: AggregateExpression): AggregateExpression2 = aggr match { + // TODO: we don't support distinct for Hive UDAF(Generic) yet from the user interface + case HiveGenericUdaf(funcWrapper, children) => + HiveGenericUdaf2(funcWrapper, children, distinct = false, isUDAF = false) + case HiveUdaf(funcWrapper, children) => + HiveGenericUdaf2(funcWrapper, children, distinct = false, isUDAF = true) + case _ => super.subsitute(aggr) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala new file mode 100644 index 000000000000..5b685f5ddf3f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive +import org.scalatest.BeforeAndAfter + +class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { + override def beforeAll() { + TestHive.cacheTables = true + TestHive.setConf(SQLConf.AGGREGATE_2, "true") + } + + override def afterAll() { + TestHive.cacheTables = false + TestHive.setConf(SQLConf.AGGREGATE_2, "false") + } + + createQueryTest("aggregation without group by expressions #1", + """ + |SELECT + | count(value), + | max(key), + | min(key) + |FROM src + """.stripMargin, false) + + createQueryTest("aggregation without group by expressions #2", + """ + |SELECT + | count(value), + | max(key), + | min(key), + | sum(key) + |FROM src + """.stripMargin, false) + + createQueryTest("aggregation without group by expressions #3", + """ + |SELECT + | count(distinct value), + | max(key), + | min(key), + | sum(distinct key) + |FROM src + """.stripMargin, false) + + createQueryTest("aggregation without group by expressions #4", + """ + |SELECT + | count(distinct value), + | max(distinct key), + | min(distinct key), + | sum(distinct key) + |FROM src + """.stripMargin, false) + + createQueryTest("aggregation without group by expressions #5", + """ + |SELECT + | count(value) + 3, + | max(key) + 1, + | min(key) + 2, + | sum(key) + 5 + |FROM src + """.stripMargin, false) + + createQueryTest("aggregation without group by expressions #6", + """ + |SELECT + | count(distinct value) + 4, + | max(distinct key) + 2, + | min(distinct key) + 3, + | sum(distinct key) + 4 + |FROM src + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #1", + """ + |SELECT key + 3 as a, count(value), max(key), min(key) + |FROM src group by key, value + |ORDER BY a LIMIT 5 + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #2", + """ + |SELECT + | key + 3 as a, + | count(value), + | max(key), + | min(key), + | sum(key) + |FROM src + |GROUP BY key, value + |ORDER BY a LIMIT 5 + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #3", + """ + |SELECT + | key + 3 as a, + | count(distinct value), + | max(key), min(key), + | sum(distinct key) + |FROM src + |GROUP BY key, value + |ORDER BY a LIMIT 5 + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #4", + """ + |SELECT + | key + 3 as a, + | count(distinct value), + | max(distinct key), + | min(distinct key), + | sum(distinct key) + |FROM src + |GROUP BY key, value + |ORDER BY a LIMIT 5 + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #5", + """ + |SELECT + | (key + 3) * 2 as a, + | (key + 3) + count(distinct value), + | (key + 3) + max(distinct (key + 3)), + | (key + 3) + min(distinct key + 3), + | (key + 3) + sum(distinct (key + 3)) + |FROM src + |GROUP BY key + 3, value + |ORDER BY a LIMIT 5 + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #6", + """ + |SELECT + | stddev_pop(key) as a, + | stddev_samp(key) as b + |FROM src + |GROUP BY key + 3, value + |ORDER BY a, b LIMIT 5 + """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #7", + """ + |SELECT + | stddev_pop(distinct key) as a, + | stddev_samp(distinct key) as b + |FROM src + |GROUP BY key + 3, value + |ORDER BY a, b LIMIT 5 + """.stripMargin, false) +} From 7fb0662df18a2cc588f54d83498ad8bc7a041990 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sat, 11 
Apr 2015 14:33:12 +0800 Subject: [PATCH 02/20] Update the unit test to comment out the not support ones --- .../sql/hive/execution/AggregateSuite.scala | 74 +++++++++---------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala index 5b685f5ddf3f..a543db3bdc46 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala @@ -61,15 +61,16 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |FROM src """.stripMargin, false) - createQueryTest("aggregation without group by expressions #4", - """ - |SELECT - | count(distinct value), - | max(distinct key), - | min(distinct key), - | sum(distinct key) - |FROM src - """.stripMargin, false) +// TODO: NOT support the max(distinct key) for now +// createQueryTest("aggregation without group by expressions #4", +// """ +// |SELECT +// | count(distinct value), +// | max(distinct key), +// | min(distinct key), +// | sum(distinct key) +// |FROM src +// """.stripMargin, false) createQueryTest("aggregation without group by expressions #5", """ @@ -81,15 +82,16 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |FROM src """.stripMargin, false) - createQueryTest("aggregation without group by expressions #6", - """ - |SELECT - | count(distinct value) + 4, - | max(distinct key) + 2, - | min(distinct key) + 3, - | sum(distinct key) + 4 - |FROM src - """.stripMargin, false) +// TODO: NOT support the max(distinct key) for now +// createQueryTest("aggregation without group by expressions #6", +// """ +// |SELECT +// | count(distinct value) + 4, +// | max(distinct key) + 2, +// | min(distinct key) + 3, +// | sum(distinct key) + 4 +// |FROM src +// """.stripMargin, false) createQueryTest("aggregation with group by expressions #1", """ @@ -123,26 +125,13 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |ORDER BY a LIMIT 5 """.stripMargin, false) - createQueryTest("aggregation with group by expressions #4", - """ - |SELECT - | key + 3 as a, - | count(distinct value), - | max(distinct key), - | min(distinct key), - | sum(distinct key) - |FROM src - |GROUP BY key, value - |ORDER BY a LIMIT 5 - """.stripMargin, false) - createQueryTest("aggregation with group by expressions #5", """ |SELECT | (key + 3) * 2 as a, | (key + 3) + count(distinct value), - | (key + 3) + max(distinct (key + 3)), - | (key + 3) + min(distinct key + 3), + | (key + 3) + max(key + 3), + | (key + 3) + min(key + 3), | (key + 3) + sum(distinct (key + 3)) |FROM src |GROUP BY key + 3, value @@ -158,14 +147,23 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |GROUP BY key + 3, value |ORDER BY a, b LIMIT 5 """.stripMargin, false) - - createQueryTest("aggregation with group by expressions #7", +// TODO: NOT support the stddev_pop(distinct key) for now +// createQueryTest("aggregation with group by expressions #7", +// """ +// |SELECT +// | stddev_pop(distinct key) as a, +// | stddev_samp(distinct key) as b +// |FROM src +// |GROUP BY key + 3, value +// |ORDER BY a, b LIMIT 5 +// """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #8", """ |SELECT - | stddev_pop(distinct key) as a, - | stddev_samp(distinct key) as b + | (key + 3) + count(distinct value, key) as a |FROM src |GROUP BY key + 3, value - 
|ORDER BY a, b LIMIT 5 + |ORDER BY a LIMIT 5 """.stripMargin, false) } From f118ffc1a624cfb6c2561431d0e6e02dd6c9264e Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 14 Apr 2015 07:05:37 +0800 Subject: [PATCH 03/20] update the interface name --- .../expressions/aggregate2/aggregates.scala | 36 +++++++++---------- .../sql/execution/aggregate2/Aggregate.scala | 18 +++++----- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 +-- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 3c846c67e1ea..1eb1c1f00978 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -66,14 +66,14 @@ trait AggregateFunction2 { self: Product => // Specify the BoundReference for Aggregate Buffer - def initialBoundReference(buffers: Seq[BoundReference]): Unit + def initialize(buffers: Seq[BoundReference]): Unit // Initialize (reinitialize) the aggregation buffer def reset(buf: MutableRow): Unit // Expect the aggregate function fills the aggregation buffer when // fed with each value in the group - def iterate(arguments: Any, buf: MutableRow): Unit + def update(arguments: Any, buf: MutableRow): Unit // Merge 2 aggregation buffer, and write back to the later one def merge(value: Row, buf: MutableRow): Unit @@ -132,7 +132,7 @@ case class Min( @transient var aggr: BoundReference = _ /* Initialization on executors */ - override def initialBoundReference(buffers: Seq[BoundReference]): Unit = { + override def initialize(buffers: Seq[BoundReference]): Unit = { aggr = buffers(0) arg = MutableLiteral(null, dataType) buffer = MutableLiteral(null, dataType) @@ -143,7 +143,7 @@ case class Min( buf(aggr) = null } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (argument != null) { arg.value = argument buffer.value = buf(aggr) @@ -199,7 +199,7 @@ case class Average(child: Expression, distinct: Boolean = false) @transient var divide: Divide = _ /* Initialization on executors */ - override def initialBoundReference(buffers: Seq[BoundReference]): Unit = { + override def initialize(buffers: Seq[BoundReference]): Unit = { count = buffers(0) sum = buffers(1) @@ -218,7 +218,7 @@ case class Average(child: Expression, distinct: Boolean = false) buf(sum) = null } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (argument != null) { arg.value = argument buf(count) = buf.getLong(count) + 1 @@ -261,7 +261,7 @@ case class Max(child: Expression) @transient var buffer: MutableLiteral = _ @transient var cmp: GreaterThan = _ - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { aggr = buffers(0) arg = MutableLiteral(null, dataType) buffer = MutableLiteral(null, dataType) @@ -272,7 +272,7 @@ case class Max(child: Expression) buf(aggr) = null } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (argument != null) { arg.value = argument buffer.value = buf(aggr) @@ -306,7 +306,7 @@ case class Count(child: Expression) /* The below code will be called in executors, 
be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { aggr = buffers(0) } @@ -314,7 +314,7 @@ case class Count(child: Expression) buf(aggr) = 0L } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (argument != null) { if (buf.isNullAt(aggr)) { buf(aggr) = 1L @@ -347,7 +347,7 @@ case class CountDistinct(children: Seq[Expression]) /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { aggr = buffers(0) } @@ -355,7 +355,7 @@ case class CountDistinct(children: Seq[Expression]) buf(aggr) = 0L } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (!argument.asInstanceOf[Seq[_]].exists(_ == null)) { // CountDistinct supports multiple expression, and ONLY IF // none of its expressions value equals null @@ -416,7 +416,7 @@ case class Sum(child: Expression, distinct: Boolean = false) lazy val DEFAULT_VALUE = Cast(Literal.create(0, IntegerType), dataType).eval() - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { aggr = buffers(0) arg = MutableLiteral(null, dataType) sum = Add(arg, aggr) @@ -426,7 +426,7 @@ case class Sum(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (argument != null) { if (buf.isNullAt(aggr)) { buf(aggr) = argument @@ -465,7 +465,7 @@ case class First(child: Expression, distinct: Boolean = false) /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { aggr = buffers(0) } @@ -473,7 +473,7 @@ case class First(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (buf.isNullAt(aggr)) { if (argument != null) { buf(aggr) = argument @@ -502,7 +502,7 @@ case class Last(child: Expression, distinct: Boolean = false) /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { aggr = buffers(0) } @@ -510,7 +510,7 @@ case class Last(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def iterate(argument: Any, buf: MutableRow): Unit = { + override def update(argument: Any, buf: MutableRow): Unit = { if (argument != null) { buf(aggr) = argument } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index ed5ff633746a..08b68237ff39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -96,7 +96,7 @@ sealed trait Aggregate { // we connect all of the aggregation buffers in a single Row, // and "BIND" the attribute references in a Hack way. val bufferDataTypes = ae.bufferDataType - ae.initialBoundReference(for (i <- 0 until bufferDataTypes.length) yield { + ae.initialize(for (i <- 0 until bufferDataTypes.length) yield { BoundReference(pos + i, bufferDataTypes(i), true) }) pos += bufferDataTypes.length @@ -245,7 +245,7 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - ae.iterate(ae.eval(currentRow), buffer) + ae.update(ae.eval(currentRow), buffer) idx += 1 } } @@ -268,7 +268,7 @@ case class AggregatePreShuffle( // AggregationExpression that distinctLike=true // This is a trade off between memory & computing ae.reset(buffer) - ae.iterate(value, buffer) + ae.update(value, buffer) idx += 1 } @@ -278,7 +278,7 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - ae.iterate(ae.eval(currentRow), inputbuffer.buffer) + ae.update(ae.eval(currentRow), inputbuffer.buffer) idx += 1 } @@ -411,11 +411,11 @@ case class DistinctAggregate( if (ae.distinct) { if (value != null && !seens(idx).contains(value)) { - ae.iterate(value, buffer) + ae.update(value, buffer) seens(idx).add(value) } } else { - ae.iterate(value, buffer) + ae.update(value, buffer) } idx += 1 } @@ -440,7 +440,7 @@ case class DistinctAggregate( val ae = aggregates(idx) val value = ae.eval(currentRow) ae.reset(buffer) - ae.iterate(value, buffer) + ae.update(value, buffer) if (ae.distinct) { val seen = new OpenHashSet[Any]() @@ -462,11 +462,11 @@ case class DistinctAggregate( if (ae.distinct) { if (value != null && !inputBufferSeens.seens(idx).contains(value)) { - ae.iterate(value, inputBufferSeens.buffer) + ae.update(value, inputBufferSeens.buffer) inputBufferSeens.seens(idx).add(value) } } else { - ae.iterate(value, inputBufferSeens.buffer) + ae.update(value, inputBufferSeens.buffer) } idx += 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 2a950bd1dfd2..fe3ea88cffaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -525,7 +525,7 @@ private[hive] case class HiveGenericUdaf2( /////////////////////////////////////////////////////////////////////////////////////////////// @transient var bound: BoundReference = _ - override def initialBoundReference(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]) = { bound = buffers(0) mode match { case FINAL => evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(bufferObjectInspector)) @@ -545,7 +545,7 @@ private[hive] case class HiveGenericUdaf2( // Expect the aggregate function fills the aggregation buffer when fed with each value // in the group - override def iterate(arguments: Any, buf: MutableRow): Unit = { + override def update(arguments: Any, buf: MutableRow): Unit = { val args = arguments.asInstanceOf[Seq[AnyRef]].zip(inspectors).map { case (value, oi) => wrap(value, oi) }.toArray From f0b9ec03f62ff8b024cde941bccae9857f096afa Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Apr 2015 10:43:40 +0800 Subject: [PATCH 04/20] change the update method from Any to Row --- .../expressions/aggregate2/aggregates.scala | 34 ++++++++++++------- 
.../sql/execution/aggregate2/Aggregate.scala | 27 +++++++-------- .../org/apache/spark/sql/hive/hiveUdfs.scala | 5 +-- 3 files changed, 37 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 1eb1c1f00978..f9e3542f59a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -71,11 +71,11 @@ trait AggregateFunction2 { // Initialize (reinitialize) the aggregation buffer def reset(buf: MutableRow): Unit - // Expect the aggregate function fills the aggregation buffer when - // fed with each value in the group - def update(arguments: Any, buf: MutableRow): Unit + // Get the children value from the input row, and then + // merge it with the given aggregate buffer + def update(input: Row, buf: MutableRow): Unit - // Merge 2 aggregation buffer, and write back to the later one + // Merge 2 aggregation buffers, and write back to the later one def merge(value: Row, buf: MutableRow): Unit // Semantically we probably don't need this, however, we need it when @@ -143,7 +143,8 @@ case class Min( buf(aggr) = null } - override def update(argument: Any, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (argument != null) { arg.value = argument buffer.value = buf(aggr) @@ -218,7 +219,8 @@ case class Average(child: Expression, distinct: Boolean = false) buf(sum) = null } - override def update(argument: Any, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (argument != null) { arg.value = argument buf(count) = buf.getLong(count) + 1 @@ -272,7 +274,8 @@ case class Max(child: Expression) buf(aggr) = null } - override def update(argument: Any, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (argument != null) { arg.value = argument buffer.value = buf(aggr) @@ -314,7 +317,8 @@ case class Count(child: Expression) buf(aggr) = 0L } - override def update(argument: Any, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (argument != null) { if (buf.isNullAt(aggr)) { buf(aggr) = 1L @@ -355,8 +359,9 @@ case class CountDistinct(children: Seq[Expression]) buf(aggr) = 0L } - override def update(argument: Any, buf: MutableRow): Unit = { - if (!argument.asInstanceOf[Seq[_]].exists(_ == null)) { + override def update(input: Row, buf: MutableRow): Unit = { + val arguments = children.map(_.eval(input)) + if (!arguments.exists(_ == null)) { // CountDistinct supports multiple expression, and ONLY IF // none of its expressions value equals null if (buf.isNullAt(aggr)) { @@ -426,7 +431,8 @@ case class Sum(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def update(argument: Any, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (argument != null) { if (buf.isNullAt(aggr)) { buf(aggr) = argument @@ -473,7 +479,8 @@ case class First(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def update(argument: Any, buf: MutableRow): Unit = { + override def 
update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (buf.isNullAt(aggr)) { if (argument != null) { buf(aggr) = argument @@ -510,7 +517,8 @@ case class Last(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def update(argument: Any, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow): Unit = { + val argument = child.eval(input) if (argument != null) { buf(aggr) = argument } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 08b68237ff39..70c8607c10a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -245,7 +245,7 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - ae.update(ae.eval(currentRow), buffer) + ae.update(currentRow, buffer) idx += 1 } } @@ -263,12 +263,8 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - val value = ae.eval(currentRow) - // TODO distinctLike? We need to store the "seen" for - // AggregationExpression that distinctLike=true - // This is a trade off between memory & computing ae.reset(buffer) - ae.update(value, buffer) + ae.update(currentRow, buffer) idx += 1 } @@ -278,7 +274,7 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - ae.update(ae.eval(currentRow), inputbuffer.buffer) + ae.update(currentRow, inputbuffer.buffer) idx += 1 } @@ -407,15 +403,17 @@ case class DistinctAggregate( var idx = 0 while (idx < aggregateFunctionBinds.length) { val ae = aggregates(idx) - val value = ae.eval(currentRow) if (ae.distinct) { + val value = ae.eval(currentRow) if (value != null && !seens(idx).contains(value)) { - ae.update(value, buffer) + // TODO how to avoid the children expression evaluation + // within Aggregate Expression? 
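The TODO above is the price of this patch's interface change: `update` now receives the raw input row and evaluates its own child expressions, so in the distinct path both `ae.eval(currentRow)` and `ae.update(currentRow, ...)` end up walking the children. A minimal, Spark-free sketch of the new contract follows; the `Toy*` names and the Array-based row are illustrative stand-ins, not part of the patch:

    // A simplified model of the patch-04 contract: `update` receives the raw
    // input row and the aggregate evaluates its own child expression inside.
    object UpdateContractSketch {
      type ToyRow = Array[Any]

      trait ToyAggregate {
        def update(input: ToyRow, buf: ToyRow): Unit
      }

      // Toy SUM over column `col`, accumulating into buffer slot `slot`.
      class ToySum(col: Int, slot: Int) extends ToyAggregate {
        override def update(input: ToyRow, buf: ToyRow): Unit = {
          val argument = input(col)   // child evaluation happens inside update
          if (argument != null) {
            val prev = buf(slot)
            buf(slot) =
              if (prev == null) argument
              else prev.asInstanceOf[Long] + argument.asInstanceOf[Long]
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val buf: ToyRow = Array[Any](null)
        val sum = new ToySum(col = 0, slot = 0)
        Seq[ToyRow](Array[Any](1L), Array[Any](null), Array[Any](2L)).foreach(sum.update(_, buf))
        assert(buf(0) == 3L)  // nulls are skipped, 1 + 2 = 3
      }
    }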
+ ae.update(currentRow, buffer) seens(idx).add(value) } } else { - ae.update(value, buffer) + ae.update(currentRow, buffer) } idx += 1 } @@ -438,11 +436,12 @@ case class DistinctAggregate( var idx = 0 while (idx < aggregateFunctionBinds.length) { val ae = aggregates(idx) - val value = ae.eval(currentRow) + ae.reset(buffer) - ae.update(value, buffer) + ae.update(currentRow, buffer) if (ae.distinct) { + val value = ae.eval(currentRow) val seen = new OpenHashSet[Any]() if (value != null) { seen.add(value) @@ -462,11 +461,11 @@ case class DistinctAggregate( if (ae.distinct) { if (value != null && !inputBufferSeens.seens(idx).contains(value)) { - ae.update(value, inputBufferSeens.buffer) + ae.update(currentRow, inputBufferSeens.buffer) inputBufferSeens.seens(idx).add(value) } } else { - ae.update(value, inputBufferSeens.buffer) + ae.update(currentRow, inputBufferSeens.buffer) } idx += 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fe3ea88cffaa..bc7669a8d7e2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -545,8 +545,9 @@ private[hive] case class HiveGenericUdaf2( // Expect the aggregate function fills the aggregation buffer when fed with each value // in the group - override def update(arguments: Any, buf: MutableRow): Unit = { - val args = arguments.asInstanceOf[Seq[AnyRef]].zip(inspectors).map { + override def update(input: Row, buf: MutableRow): Unit = { + val arguments = children.map(_.eval(input)) + val args = arguments.zip(inspectors).map { case (value, oi) => wrap(value, oi) }.toArray From bee0f95dca39d5b3bc886daaacf34b1d949bef98 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 16 Apr 2015 02:29:51 +0800 Subject: [PATCH 05/20] move the distinct into the udaf --- .../expressions/aggregate2/aggregates.scala | 81 +++++++++++-------- .../spark/sql/execution/SparkStrategies.scala | 24 +++--- .../sql/execution/aggregate2/Aggregate.scala | 52 ++++-------- .../org/apache/spark/sql/hive/hiveUdfs.scala | 23 +++--- 4 files changed, 89 insertions(+), 91 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index f9e3542f59a7..3f1383eb378c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate2 +import java.util.{Set => JSet} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -72,8 +73,11 @@ trait AggregateFunction2 { def reset(buf: MutableRow): Unit // Get the children value from the input row, and then - // merge it with the given aggregate buffer - def update(input: Row, buf: MutableRow): Unit + // merge it with the given aggregate buffer; + // `seen` is the set of values that have shown up so far, which is + // useful for distinct aggregates, and it will typically be + // null for non-distinct aggregates
+ def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit // Merge 2 aggregation buffers, and write back to the later one def merge(value: Row, buf: MutableRow): Unit // Semantically we probably don't need this, however, we need it when @@ -101,8 +105,6 @@ trait AggregateExpression2 extends Expression with AggregateFunction2 { def bufferDataType: Seq[DataType] = Nil // Is it a distinct aggregate expression? def distinct: Boolean - // Is it a distinct like aggregate expression (e.g. Min/Max is distinctLike, while avg is not) - def distinctLike: Boolean = false def nullable = true @@ -120,7 +122,6 @@ case class Min( extends UnaryAggregateExpression { override def distinct: Boolean = false - override def distinctLike: Boolean = true override def dataType = child.dataType override def bufferDataType: Seq[DataType] = dataType :: Nil override def toString = s"MIN($child)" @@ -143,7 +144,8 @@ case class Min( buf(aggr) = null } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { + // we don't care whether the argument has been seen before val argument = child.eval(input) if (argument != null) { arg.value = argument buffer.value = buf(aggr) @@ -219,15 +221,18 @@ case class Average(child: Expression, distinct: Boolean = false) buf(sum) = null } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { val argument = child.eval(input) if (argument != null) { - arg.value = argument - buf(count) = buf.getLong(count) + 1 - if (buf.isNullAt(sum)) { - buf(sum) = cast.eval() - } else { - buf(sum) = add.eval(buf) + if (!distinct || !seen.contains(argument)) { + arg.value = argument + buf(count) = buf.getLong(count) + 1 + if (buf.isNullAt(sum)) { + buf(sum) = cast.eval() + } else { + buf(sum) = add.eval(buf) + } + if (distinct) seen.add(argument) } } } @@ -250,7 +255,6 @@ case class Average(child: Expression, distinct: Boolean = false) case class Max(child: Expression) extends UnaryAggregateExpression { override def distinct: Boolean = false - override def distinctLike: Boolean = true override def nullable = true override def dataType = child.dataType @@ -274,7 +278,8 @@ case class Max(child: Expression) buf(aggr) = null } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { + // we don't care whether the argument has been seen before val argument = child.eval(input) if (argument != null) { arg.value = argument buffer.value = buf(aggr) @@ -317,7 +322,9 @@ case class Count(child: Expression) buf(aggr) = 0L } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { + // we don't care whether the argument has been seen before; + // here we only handle the non-distinct case val argument = child.eval(input) if (argument != null) { if (buf.isNullAt(aggr)) { buf(aggr) = 1L @@ -359,15 +366,18 @@ case class CountDistinct(children: Seq[Expression]) buf(aggr) = 0L } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { val arguments = children.map(_.eval(input)) if (!arguments.exists(_ == null)) { // CountDistinct supports multiple expression, and ONLY IF // none of its expressions value equals null - if (buf.isNullAt(aggr)) { - buf(aggr) = 1L - } else { - buf(aggr) = buf.getLong(aggr) + 1L + if (!seen.contains(arguments)) { + if 
(buf.isNullAt(aggr)) { + buf(aggr) = 1L + } else { + buf(aggr) = buf.getLong(aggr) + 1L + } + seen.add(arguments) } } } @@ -431,19 +441,22 @@ case class Sum(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { val argument = child.eval(input) - if (argument != null) { - if (buf.isNullAt(aggr)) { - buf(aggr) = argument + if (!distinct || !seen.contains(argument)) { + if (argument != null) { + if (buf.isNullAt(aggr)) { + buf(aggr) = argument + } else { + arg.value = argument + buf(aggr) = sum.eval(buf) + } } else { - arg.value = argument - buf(aggr) = sum.eval(buf) - } - } else { - if (buf.isNullAt(aggr)) { - buf(aggr) = DEFAULT_VALUE + if (buf.isNullAt(aggr)) { + buf(aggr) = DEFAULT_VALUE + } } + if (distinct) seen.add(argument) } } @@ -479,7 +492,8 @@ case class First(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { + // we don't care whether the argument has been seen before val argument = child.eval(input) if (buf.isNullAt(aggr)) { if (argument != null) { buf(aggr) = argument @@ -517,7 +531,8 @@ case class Last(child: Expression, distinct: Boolean = false) buf(aggr) = null } - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { + // we don't care whether the argument has been seen before val argument = child.eval(input) if (argument != null) { buf(aggr) = argument } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce44e3b627e1..0ca9a9233337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -130,19 +130,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggr.copy(aggregateExpressions = subusitutedAggrExpression) match { // Aggregations that can be performed in two phases, before and after the shuffle. 
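Patch 05 threads a per-aggregate `seen` set through `update`, so duplicate elimination for DISTINCT happens inside the aggregate function itself. A self-contained sketch of that discipline, using the same JDK collections the patch imports; `ToyDistinctSum` and its layout are illustrative assumptions, not the patch's API:

    import java.util.{HashSet => JHashSet, Set => JSet}

    // Toy distinct-aware SUM following the patch-05 convention: the caller
    // owns the `seen` set (one per aggregate, per group) and passes null
    // when the aggregate is not distinct.
    class ToyDistinctSum {
      var sum: Long = 0L

      def update(argument: Any, seen: JSet[Any]): Unit = {
        if (argument != null && (seen == null || !seen.contains(argument))) {
          sum += argument.asInstanceOf[Long]
          if (seen != null) seen.add(argument)  // remember the value for dedup
        }
      }
    }

    object SeenSetSketch extends App {
      val agg  = new ToyDistinctSum
      val seen = new JHashSet[Any]()
      Seq[Any](1L, 1L, 2L).foreach(agg.update(_, seen))   // SUM(DISTINCT ...) = 3
      assert(agg.sum == 3L)

      val plain = new ToyDistinctSum
      Seq[Any](1L, 1L, 2L).foreach(plain.update(_, null)) // plain SUM = 4
      assert(plain.sum == 4L)
    }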
case PartialAggregation2( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - aggregateExpressions, - child) => - aggregate2.AggregatePostShuffle( - namedGroupingAttributes, - rewrittenAggregateExpressions, - aggregate2.AggregatePreShuffle( - groupingExpressions, - aggregateExpressions, + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + aggregateExpressions, + child) => + aggregate2.AggregatePostShuffle( namedGroupingAttributes, - planLater(child))) :: Nil + rewrittenAggregateExpressions, + aggregate2.AggregatePreShuffle( + groupingExpressions, + aggregateExpressions, + namedGroupingAttributes, + planLater(child))) :: Nil case _ => Nil } case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 70c8607c10a5..d81c9b54bcc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate2 +import java.util.{HashSet=>JHashSet, Set=>JSet} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate2._ @@ -41,7 +43,7 @@ sealed case class AggregateFunctionBind( sealed class InputBufferSeens( var input: Row, // var buffer: MutableRow, - var seens: Array[OpenHashSet[Any]] = null) { + var seens: Array[JSet[Any]] = null) { def this() { this(new GenericMutableRow(0), null) } @@ -56,7 +58,7 @@ sealed class InputBufferSeens( this } - def withSeens(seens: Array[OpenHashSet[Any]]): InputBufferSeens = { + def withSeens(seens: Array[JSet[Any]]): InputBufferSeens = { this.seens = seens this } @@ -245,7 +247,7 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - ae.update(currentRow, buffer) + ae.update(currentRow, buffer, null) idx += 1 } } @@ -264,7 +266,7 @@ case class AggregatePreShuffle( while (idx < aggregates.length) { val ae = aggregates(idx) ae.reset(buffer) - ae.update(currentRow, buffer) + ae.update(currentRow, buffer, null) idx += 1 } @@ -274,7 +276,7 @@ case class AggregatePreShuffle( var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - ae.update(currentRow, inputbuffer.buffer) + ae.update(currentRow, inputbuffer.buffer, null) idx += 1 } @@ -382,7 +384,7 @@ case class DistinctAggregate( if (groupingExpressions.isEmpty) { val buffer = new GenericMutableRow(bufferSchema.length) // TODO save the memory only for those DISTINCT aggregate expressions - val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length) + val seens = new Array[JSet[Any]](aggregateFunctionBinds.length) var idx = 0 while (idx < aggregateFunctionBinds.length) { @@ -390,7 +392,7 @@ case class DistinctAggregate( ae.reset(buffer) if (ae.distinct) { - seens(idx) = new OpenHashSet[Any]() + seens(idx) = new JHashSet[Any]() } idx += 1 @@ -403,18 +405,8 @@ case class DistinctAggregate( var idx = 0 while (idx < aggregateFunctionBinds.length) { val ae = aggregates(idx) + ae.update(currentRow, buffer, seens(idx)) - if (ae.distinct) { - val value = ae.eval(currentRow) - if (value != null && !seens(idx).contains(value)) { - // TODO how to avoid the children expression evaluation - // within Aggregate Expression? 
- ae.update(currentRow, buffer) - seens(idx).add(value) - } - } else { - ae.update(currentRow, buffer) - } idx += 1 } } @@ -430,25 +422,19 @@ case class DistinctAggregate( results(keys) match { case null => val buffer = new GenericMutableRow(bufferSchema.length) - // TODO save the memory only for those DISTINCT aggregate expressions - val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length) + val seens = new Array[JSet[Any]](aggregateFunctionBinds.length) var idx = 0 while (idx < aggregateFunctionBinds.length) { val ae = aggregates(idx) - ae.reset(buffer) - ae.update(currentRow, buffer) + ae.reset(buffer) if (ae.distinct) { - val value = ae.eval(currentRow) - val seen = new OpenHashSet[Any]() - if (value != null) { - seen.add(value) - } - seens.update(idx, seen) + val seen = new JHashSet[Any]() + seens(idx) = seen } + ae.update(currentRow, buffer, seens(idx)) idx += 1 } results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer, seens) @@ -459,14 +445,8 @@ case class DistinctAggregate( val ae = aggregates(idx) val value = ae.eval(currentRow) - if (ae.distinct) { - if (value != null && !inputBufferSeens.seens(idx).contains(value)) { - ae.update(currentRow, inputBufferSeens.buffer) - inputBufferSeens.seens(idx).add(value) - } - } else { - ae.update(currentRow, inputBufferSeens.buffer) - } + ae.update(currentRow, inputBufferSeens.buffer, inputBufferSeens.seens(idx)) + idx += 1 } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bc7669a8d7e2..d03964d57649 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.{Set=>JSet} + import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.sql.AnalysisException @@ -506,8 +508,7 @@ private[hive] case class HiveGenericUdaf2( @transient lazy val inspectors = children.map(toInspector).toArray - @transient - override val distinctLike: Boolean = { + private val distinctLike: Boolean = { val annotation = evaluator.getClass().getAnnotation(classOf[HiveUDFType]) if (annotation == null || !annotation.distinctLike()) false else true } @@ -545,15 +546,17 @@ private[hive] case class HiveGenericUdaf2( // Expect the aggregate function fills the aggregation buffer when fed with each value // in the group - override def update(input: Row, buf: MutableRow): Unit = { + override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { val arguments = children.map(_.eval(input)) - val args = arguments.zip(inspectors).map { - case (value, oi) => wrap(value, oi) - }.toArray - - evaluator.iterate( - buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal), - args) + if (distinctLike || !distinct || !seen.contains(arguments)) { + val args = arguments.zip(inspectors).map { + case (value, oi) => wrap(value, oi) + }.toArray + + evaluator.iterate( + buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal), args) + if (distinct && !distinctLike) seen.add(arguments) + } } // Merge 2 aggregation buffer, and write back to the later one From 760164e59ef4c4cbfc9157a5e7de394faa298330 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 16 Apr 2015 15:50:31 +0800 Subject: [PATCH 06/20] simplify the aggregate expression by using the Projection --- 
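The subject line is terse, so a sketch of the idea may help: once an aggregate expression's `eval` simply delegates to `terminate` over the aggregation buffer (see the aggregates.scala hunk below), the final output row can be produced by an ordinary projection over a row laid out as aggregation buffer plus grouping key. A simplified, self-contained model, with an assumed slot layout:

    // Toy model of the patch-06 final projection: output columns are computed
    // from a row laid out as [aggregation buffer ++ grouping key].
    object FinalProjectionSketch extends App {
      type ToyRow     = Array[Any]
      type Projection = ToyRow => ToyRow

      // Assumed buffer layout: slot 0 = count, slot 1 = sum, slot 2 = key.
      // Output layout: (key, avg) where avg = sum / count.
      val finalProjection: Projection = row => {
        val count = row(0).asInstanceOf[Long]
        val sum   = row(1).asInstanceOf[Long]
        val key   = row(2)
        Array[Any](key, if (count == 0L) null else sum.toDouble / count)
      }

      val out = finalProjection(Array[Any](2L, 10L, "a"))
      assert(out(0) == "a" && out(1) == 5.0)
    }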
.../main/scala/org/apache/spark/sql/Row.scala | 13 +- .../expressions/aggregate2/aggregates.scala | 6 +- .../spark/sql/catalyst/expressions/rows.scala | 2 + .../sql/catalyst/planning/patterns.scala | 76 ++--- .../spark/sql/execution/SparkStrategies.scala | 24 +- .../sql/execution/aggregate2/Aggregate.scala | 284 +++++++----------- 6 files changed, 180 insertions(+), 225 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0d460b634d9b..960fcf091a1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, GenericRow} import org.apache.spark.sql.types.StructType object Row { @@ -348,6 +348,17 @@ trait Row extends Serializable { */ def copy(): Row + def makeMutable(): MutableRow = { + val totalSize = length + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericMutableRow(copiedValues) + } + /** Returns true if there are any NULL values in this row. */ def anyNull: Boolean = { val len = length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 3f1383eb378c..547a5823e2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -88,7 +88,7 @@ trait AggregateFunction2 { def terminatePartial(buf: MutableRow): Unit = {} // Output the final result by feeding the aggregation buffer - def terminate(input: Row): Any + def terminate(buffer: Row): Any } trait AggregateExpression2 extends Expression with AggregateFunction2 { @@ -108,13 +108,11 @@ trait AggregateExpression2 extends Expression with AggregateFunction2 { def nullable = true - override def eval(input: Row): EvaluatedType = children.map(_.eval(input)) + final override def eval(aggrBuffer: Row): EvaluatedType = terminate(aggrBuffer) } abstract class UnaryAggregateExpression extends UnaryExpression with AggregateExpression2 { self: Product => - - override def eval(input: Row): EvaluatedType = child.eval(input) } case class Min( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5fd892c42e69..a29ed83c9cd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -57,6 +57,7 @@ object EmptyRow extends Row { override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException override def copy(): Row = this + override def makeMutable(): MutableRow = throw new UnsupportedOperationException } /** @@ -174,6 +175,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def copy(): Row = this + override def makeMutable(): MutableRow = new GenericMutableRow(values.clone()) } class 
GenericRowWithSchema(values: Array[Any], override val schema: StructType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9ef8fea8e892..6dc0ccfb4f94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -150,7 +150,7 @@ object AggregateExpressionSubsitution extends AggregateExpressionSubsitution */ object PartialAggregation2 { type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + (Seq[Attribute], Seq[NamedExpression], Seq[aggregate2.AggregateExpression2], Seq[NamedExpression], Seq[NamedExpression], LogicalPlan) def unapply(plan: LogicalPlan) : Option[ReturnType] = plan match { @@ -160,43 +160,45 @@ object PartialAggregation2 { case a: aggregate2.AggregateExpression2 => a }) - // Only do partial aggregation if supported by all aggregate expressions. - if (!allAggregates.exists(_.distinct)) { - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - } - val substitutions = namedGroupingExpressions.toMap - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if substitutions.contains(e) => - substitutions(e).toAttribute - case e: Expression => - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - substitutions - .get(e.transform { case Alias(g: GetField, _) => g }) - .map(_.toAttribute) - .getOrElse(e) - }).asInstanceOf[Seq[NamedExpression]] - - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - Some( - (namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - aggregateExpressions, - child)) - } else { - None + // We need to pass all grouping expressions through so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) } + val substitutions = namedGroupingExpressions.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if substitutions.contains(e) => + substitutions(e).toAttribute + case e: AggregateExpression2 => e.transformChildrenDown { + // replace the child expressions of the aggregate expression with + // Literals, as in PostShuffle we don't need the children expressions any + // more (only the aggregate buffer is required). 
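The comment above describes a small tree rewrite: after the shuffle only the aggregation buffer matters, so an aggregate's child expressions are swapped for typed placeholder literals. A stand-alone analogue of that rewrite; the toy `Expr` ADT below is illustrative, not Catalyst:

    // Toy expression ADT mirroring the placeholder substitution above.
    object PlaceholderRewriteSketch extends App {
      sealed trait Expr
      case class Column(name: String)         extends Expr
      case class Literal(value: Any)          extends Expr
      case class Agg(fn: String, child: Expr) extends Expr

      // Keep the aggregate node (its buffer carries the state), but replace
      // the child with a null literal so it is never evaluated post-shuffle.
      def stripAggChildren(e: Expr): Expr = e match {
        case Agg(fn, _) => Agg(fn, Literal(null))
        case other      => other
      }

      assert(stripAggChildren(Agg("SUM", Column("a"))) == Agg("SUM", Literal(null)))
    }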
+ case expr => MutableLiteral(null, expr.dataType, expr.nullable) + } + case e: Expression => + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + substitutions + .get(e.transform { case Alias(g: GetField, _) => g }) + .map(_.toAttribute) + .getOrElse(e) + }).asInstanceOf[Seq[NamedExpression]] + + val groupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + Some( + (groupingAttributes, + namedGroupingExpressions.map(_._2), + allAggregates, + rewrittenAggregateExpressions, + aggregateExpressions, + child)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0ca9a9233337..c69e3901f67d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -130,19 +130,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggr.copy(aggregateExpressions = subusitutedAggrExpression) match { // Aggregations that can be performed in two phases, before and after the shuffle. case PartialAggregation2( - namedGroupingAttributes, - rewrittenAggregateExpressions, + groupingAttributes, groupingExpressions, + aggregates, + rewrittenAggregateExpressions, aggregateExpressions, child) => - aggregate2.AggregatePostShuffle( - namedGroupingAttributes, - rewrittenAggregateExpressions, - aggregate2.AggregatePreShuffle( + if (aggregates.exists(_.distinct)) { + aggregate2.DistinctAggregate( groupingExpressions, aggregateExpressions, - namedGroupingAttributes, - planLater(child))) :: Nil + rewrittenAggregateExpressions, + planLater(child)) :: Nil + } else { + aggregate2.AggregatePostShuffle( + groupingAttributes, + rewrittenAggregateExpressions, + aggregate2.AggregatePreShuffle( + groupingExpressions, + aggregates, + planLater(child))) :: Nil + } case _ => Nil } case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index d81c9b54bcc2..d7c5213b7190 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -24,41 +24,22 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate2._ import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.util.collection.{OpenHashSet, OpenHashMap} +import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.plans.physical._ -/** - * An aggregate that needs to be computed for each row in a group. - * - * @param aggregate AggregateExpression2, associated with the function - * @param substitution A MutableLiteral used to refer to the result of this aggregate in the final - * output. 
- */ -sealed case class AggregateFunctionBind( - aggregate: AggregateExpression2, - substitution: MutableLiteral) - -sealed class InputBufferSeens( - var input: Row, // - var buffer: MutableRow, - var seens: Array[JSet[Any]] = null) { +sealed class BufferSeens(var buffer: MutableRow, var seens: Array[JSet[Any]] = null) { def this() { this(new GenericMutableRow(0), null) } - def withInput(row: Row): InputBufferSeens = { - this.input = row - this - } - - def withBuffer(row: MutableRow): InputBufferSeens = { + def withBuffer(row: MutableRow): BufferSeens = { this.buffer = row this } - def withSeens(seens: Array[JSet[Any]]): InputBufferSeens = { + def withSeens(seens: Array[JSet[Any]]): BufferSeens = { this.seens = seens this } @@ -69,30 +50,11 @@ sealed trait Aggregate { // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. val childOutput = child.output - val isGlobalAggregation = groupingExpressions.isEmpty - - def computedAggregates: Array[AggregateExpression2] = { - boundProjection.flatMap { expr => - expr.collect { - case ae: AggregateExpression2 => ae - } - } - }.toArray - // This is a hack, instead of relying on the BindReferences for the aggregation - // buffer schema in PostShuffle, we have a strong protocols which represented as the - // BoundReferences in PostShuffle for aggregation buffer. - @transient lazy val bufferSchema: Array[AttributeReference] = - computedAggregates.zipWithIndex.flatMap { case (ca, idx) => - ca.bufferDataType.zipWithIndex.map { case (dt, i) => - AttributeReference(s"aggr.${idx}_$i", dt)() } - }.toArray - - // The tuples of aggregate expressions with information - // (AggregateExpression2, Aggregate Function, Placeholder of AggregateExpression2 result) - @transient lazy val aggregateFunctionBinds: Array[AggregateFunctionBind] = { + def initializedAndGetAggregates(mode: Mode, aggregates: Seq[AggregateExpression2]): Array[AggregateExpression2] = { var pos = 0 - computedAggregates.map { ae => + + aggregates.map { ae => ae.initial(mode) // we connect all of the aggregation buffers in a single Row, @@ -103,81 +65,33 @@ sealed trait Aggregate { }) pos += bufferDataTypes.length - AggregateFunctionBind(ae, MutableLiteral(null, ae.dataType)) - } - } - - @transient lazy val groupByProjection = if (groupingExpressions.isEmpty) { - InterpretedMutableProjection(Nil) - } else { - new InterpretedMutableProjection(groupingExpressions, childOutput) + ae + }.toArray } - // Indicate which stage we are running into - def mode: Mode // This is provided by SparkPlan def child: SparkPlan - // Group By Key Expressions - def groupingExpressions: Seq[Expression] - // Bounded Projection - def boundProjection: Seq[NamedExpression] -} - -sealed trait PreShuffle extends Aggregate { - self: Product => - - def boundProjection: Seq[NamedExpression] = projection.map { - case a: Attribute => // Attribute will be converted into BoundReference - Alias( - BindReferences.bindReference(a: Expression, childOutput), a.name)(a.exprId, a.qualifiers) - case a: NamedExpression => BindReferences.bindReference(a, childOutput) - } - // The expression list for output, this is the unbound expressions - def projection: Seq[NamedExpression] + def bufferSchema(aggregates: Seq[AggregateExpression2]): Seq[Attribute] = + aggregates.zipWithIndex.flatMap { case (ca, idx) => + ca.bufferDataType.zipWithIndex.map { case (dt, i) => + AttributeReference(s"aggr.${idx}_$i", dt)().toAttribute } + } } sealed trait PostShuffle extends 
Aggregate { self: Product => - /** - * Substituted version of boundProjection expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. - */ - @transient lazy val finalExpressions = { - val resultMap = aggregateFunctionBinds.map { ae => ae.aggregate -> ae.substitution }.toMap - boundProjection.map { agg => - agg.transform { - case e: AggregateExpression2 if resultMap.contains(e) => resultMap(e) - } - } - }.map(e => {BindReferences.bindReference(e: Expression, childOutput)}) - - @transient lazy val finalProjection = new InterpretedMutableProjection(finalExpressions) - - def aggregateFunctionBinds: Array[AggregateFunctionBind] - - def createIterator( - aggregates: Array[AggregateExpression2], - iterator: Iterator[InputBufferSeens]) = { - val substitutions = aggregateFunctionBinds.map(_.substitution) - - new Iterator[Row] { - override final def hasNext: Boolean = iterator.hasNext - - override final def next(): Row = { - val keybuffer = iterator.next() - var idx = 0 - while (idx < aggregates.length) { - // substitute the AggregateExpression2 value - substitutions(idx).value = aggregates(idx).terminate(keybuffer.buffer) - idx += 1 - } - - finalProjection(keybuffer.input) + def computedAggregates(aggregateExpressions: Seq[NamedExpression]): Seq[AggregateExpression2] = { + aggregateExpressions.flatMap { expr => + expr.collect { + case ae: AggregateExpression2 => ae } } } + + def createIterator(projection: Projection, aggregates: Array[AggregateExpression2], iterator: Iterator[BufferSeens]) = + iterator.map(it => projection(it.buffer)) } /** @@ -185,34 +99,31 @@ sealed trait PostShuffle extends Aggregate { * Groups input data by `groupingExpressions` and computes the `projection` for each * group. * - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param projection expressions that are computed for each group. - * @param namedGroupingAttributes the attributes represent the output of the groupby expressions + * @param groupingExpressions the attributes represent the output of the groupby expressions + * @param unboundAggregateExpressions Unbound Aggregate Function List. * @param child the input data source. */ @DeveloperApi case class AggregatePreShuffle( - groupingExpressions: Seq[Expression], - projection: Seq[NamedExpression], - namedGroupingAttributes: Seq[Attribute], - child: SparkPlan) - extends UnaryNode with PreShuffle { - - override def requiredChildDistribution = UnspecifiedDistribution :: Nil + groupingExpressions: Seq[NamedExpression], + unboundAggregateExpressions: Seq[AggregateExpression2], + child: SparkPlan) + extends UnaryNode with Aggregate { - override def output = bufferSchema.map(_.toAttribute) ++ namedGroupingAttributes + val aggregateExpressions: Seq[AggregateExpression2] = unboundAggregateExpressions.map { + BindReferences.bindReference(_, childOutput) + } - override def mode: Mode = PARTIAL1 // iterate & terminalPartial will be called + override def requiredChildDistribution = UnspecifiedDistribution :: Nil + override def output = bufferSchema(aggregateExpressions) ++ groupingExpressions.map(_.toAttribute) /** * Create Iterator for the in-memory hash map. 
*/ private[this] def createIterator( - functions: Array[AggregateExpression2], - iterator: Iterator[InputBufferSeens]) = { + functions: Array[AggregateExpression2], + iterator: Iterator[BufferSeens]) = { new Iterator[Row] { - private[this] val joinedRow = new JoinedRow - override final def hasNext: Boolean = iterator.hasNext override final def next(): Row = { @@ -223,18 +134,18 @@ case class AggregatePreShuffle( idx += 1 } - joinedRow(keybuffer.buffer, keybuffer.input).copy() + keybuffer.buffer } } } override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => - val aggregates = aggregateFunctionBinds.map(_.aggregate) + val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) if (groupingExpressions.isEmpty) { // without group by keys - val buffer = new GenericMutableRow(bufferSchema.length) + val buffer = new GenericMutableRow(output.length) var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -252,16 +163,17 @@ case class AggregatePreShuffle( } } - createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer))) + createIterator(aggregates, Iterator(new BufferSeens().withBuffer(buffer))) } else { - val results = new OpenHashMap[Row, InputBufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + val results = new OpenHashMap[Row, BufferSeens]() while (iter.hasNext) { val currentRow = iter.next() val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new GenericMutableRow(bufferSchema.length) + val buffer = new GenericMutableRow(output.length) var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -269,9 +181,14 @@ case class AggregatePreShuffle( ae.update(currentRow, buffer, null) idx += 1 } + var idx2 = 0 + while (idx2 < keys.length) { + buffer(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } - val copies = keys.copy() - results(copies) = new InputBufferSeens(copies, buffer) + results(keys.copy()) = new BufferSeens(buffer, null) case inputbuffer => var idx = 0 while (idx < aggregates.length) { @@ -279,7 +196,6 @@ case class AggregatePreShuffle( ae.update(currentRow, inputbuffer.buffer, null) idx += 1 } - } } @@ -290,11 +206,11 @@ case class AggregatePreShuffle( } case class AggregatePostShuffle( - groupingExpressions: Seq[Expression], - boundProjection: Seq[NamedExpression], - child: SparkPlan) extends UnaryNode with PostShuffle { + groupingExpressions: Seq[Attribute], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PostShuffle { - override def output = boundProjection.map(_.toAttribute) + override def output = aggregateExpressions.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { AllTuples :: Nil @@ -302,13 +218,15 @@ case class AggregatePostShuffle( ClusteredDistribution(groupingExpressions) :: Nil } - override def mode: Mode = FINAL // merge & terminate will be called override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => - val aggregates = aggregateFunctionBinds.map(_.aggregate) + val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(aggregateExpressions)) + + val finalProjection = new InterpretedMutableProjection(aggregateExpressions, childOutput) + if (groupingExpressions.isEmpty) { - val buffer = new GenericMutableRow(bufferSchema.length) + val buffer = new GenericMutableRow(output.length) var idx = 0 while (idx < aggregates.length) 
{ val ae = aggregates(idx) @@ -327,23 +245,19 @@ case class AggregatePostShuffle( } } - createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer))) + createIterator(finalProjection, aggregates, Iterator(new BufferSeens().withBuffer(buffer))) } else { - val results = new OpenHashMap[Row, InputBufferSeens]() + val results = new OpenHashMap[Row, BufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + while (iter.hasNext) { val currentRow = iter.next() val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new GenericMutableRow(bufferSchema.length) - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - ae.merge(currentRow, buffer) - idx += 1 - } - results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer) + // TODO currentRow seems most likely a MutableRow + val buffer = currentRow.makeMutable() + results(keys.copy()) = new BufferSeens(buffer, null) case pair => var idx = 0 while (idx < aggregates.length) { @@ -354,7 +268,7 @@ case class AggregatePostShuffle( } } - createIterator(aggregates, results.iterator.map(_._2)) + createIterator(finalProjection, aggregates, results.iterator.map(_._2)) } } } @@ -365,10 +279,12 @@ case class AggregatePostShuffle( // to the reduce side and do aggregation directly, this probably causes the performance regression // for Aggregation Function like CountDistinct etc. case class DistinctAggregate( - groupingExpressions: Seq[Expression], - projection: Seq[NamedExpression], - child: SparkPlan) extends UnaryNode with PreShuffle with PostShuffle { - override def output = boundProjection.map(_.toAttribute) + groupingExpressions: Seq[NamedExpression], + unboundAggregateExpressions: Seq[NamedExpression], + rewrittenAggregateExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PostShuffle { + + override def output = rewrittenAggregateExpressions.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { AllTuples :: Nil @@ -376,18 +292,31 @@ case class DistinctAggregate( ClusteredDistribution(groupingExpressions) :: Nil } - override def mode: Mode = COMPLETE // iterate() & terminate() will be called + val aggregateExpressions: Seq[NamedExpression] = unboundAggregateExpressions.map { + BindReferences.bindReference(_, childOutput) + } override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => - val aggregates = aggregateFunctionBinds.map(_.aggregate) + val aggregates = initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) + + val outputSchema: Seq[Attribute] = + aggregates.zipWithIndex.flatMap { case (ca, idx) => + ca.bufferDataType.zipWithIndex.map { case (dt, i) => + AttributeReference(s"aggr.${idx}_$i", dt)() + } + } ++ groupingExpressions.map(_.toAttribute) + + initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenAggregateExpressions)) + val finalProjection = new InterpretedMutableProjection( + rewrittenAggregateExpressions, outputSchema) + if (groupingExpressions.isEmpty) { - val buffer = new GenericMutableRow(bufferSchema.length) - // TODO save the memory only for those DISTINCT aggregate expressions - val seens = new Array[JSet[Any]](aggregateFunctionBinds.length) + val buffer = new GenericMutableRow(output.length) + val seens = new Array[JSet[Any]](aggregates.length) var idx = 0 - while (idx < aggregateFunctionBinds.length) { + while (idx < 
aggregates.length) { val ae = aggregates(idx) ae.reset(buffer) if (ae.distinct) { seens(idx) = new JHashSet[Any]() } idx += 1 } - val ibs = new InputBufferSeens().withBuffer(buffer).withSeens(seens) + val ibs = new BufferSeens().withBuffer(buffer).withSeens(seens) while (iter.hasNext) { val currentRow = iter.next() var idx = 0 - while (idx < aggregateFunctionBinds.length) { + while (idx < aggregates.length) { val ae = aggregates(idx) ae.update(currentRow, buffer, seens(idx)) idx += 1 } } - createIterator(aggregates, Iterator(ibs)) + createIterator(finalProjection, aggregates, Iterator(ibs)) } else { - val results = new OpenHashMap[Row, InputBufferSeens]() + val results = new OpenHashMap[Row, BufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) while (iter.hasNext) { val currentRow = iter.next() val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new GenericMutableRow(bufferSchema.length) - val seens = new Array[JSet[Any]](aggregateFunctionBinds.length) + val buffer = new GenericMutableRow(aggregates.length + keys.length) + val seens = new Array[JSet[Any]](aggregates.length) var idx = 0 - while (idx < aggregateFunctionBinds.length) { + while (idx < aggregates.length) { val ae = aggregates(idx) ae.reset(buffer) if (ae.distinct) { val seen = new JHashSet[Any]() seens(idx) = seen } ae.update(currentRow, buffer, seens(idx)) idx += 1 } + var idx2 = 0 + while (idx2 < keys.length) { + buffer(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } results(keys.copy()) = new BufferSeens(buffer, seens) - case inputBufferSeens => + case bufferSeens => var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) - val value = ae.eval(currentRow) - - ae.update(currentRow, inputBufferSeens.buffer, inputBufferSeens.seens(idx)) + ae.update(currentRow, bufferSeens.buffer, bufferSeens.seens(idx)) idx += 1 } } } createIterator(finalProjection, aggregates, results.iterator.map(_._2)) } } } From 0849ca36bcc093c679269f7808e0745df1192121 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 16 Apr 2015 16:23:43 +0800 Subject: [PATCH 07/20] revert the unnecessary changes --- .../expressions/aggregate2/aggregates.scala | 2 + .../sql/execution/aggregate2/Aggregate.scala | 292 ++++++------ 2 files changed, 104 insertions(+), 190 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 547a5823e2b6..f99277717acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -93,6 +93,8 @@ trait AggregateExpression2 extends Expression with AggregateFunction2 { self: Product => + implicit def boundReferenceToIndex(br: BoundReference): Int = br.ordinal + type EvaluatedType = Any var mode: Mode = COMPLETE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index d7c5213b7190..abd1d1cf7189 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -89,9 +89,6 @@ sealed trait PostShuffle extends Aggregate { } } } - - def createIterator(projection: Projection, aggregates: Array[AggregateExpression2], iterator: Iterator[BufferSeens]) = - iterator.map(it => projection(it.buffer)) } /** @@ -105,9 +102,9 @@ sealed trait PostShuffle extends Aggregate { */ @DeveloperApi case class AggregatePreShuffle( - groupingExpressions: Seq[NamedExpression], - unboundAggregateExpressions: Seq[AggregateExpression2], - child: SparkPlan) + groupingExpressions: Seq[NamedExpression], + unboundAggregateExpressions: Seq[AggregateExpression2], + child: SparkPlan) extends UnaryNode with Aggregate { val aggregateExpressions: Seq[AggregateExpression2] = unboundAggregateExpressions.map { @@ -143,72 +140,49 @@ case class AggregatePreShuffle( child.execute().mapPartitions { iter => val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) - if (groupingExpressions.isEmpty) { - // without group by keys - val buffer = new GenericMutableRow(output.length) - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - idx += 1 - } - - while (iter.hasNext) { - val currentRow = iter.next() - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.update(currentRow, buffer, null) - idx += 1 - } - } - - createIterator(aggregates, Iterator(new BufferSeens().withBuffer(buffer))) - } else { - val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) - val results = new OpenHashMap[Row, BufferSeens]() - while (iter.hasNext) { - val currentRow = iter.next() - - val keys = groupByProjection(currentRow) - results(keys) match { - case null => - val buffer = new GenericMutableRow(output.length) - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - ae.update(currentRow, buffer, null) - idx += 1 - } - var idx2 = 0 - while (idx2 < keys.length) { - buffer(idx) = keys(idx2) - idx2 += 1 - idx += 1 - } - - results(keys.copy()) = new BufferSeens(buffer, null) - case inputbuffer => - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.update(currentRow, inputbuffer.buffer, null) - idx += 1 - } - } + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + val results = new OpenHashMap[Row, BufferSeens]() + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(output.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + ae.update(currentRow, buffer, null) + idx += 1 + } + var idx2 = 0 + while (idx2 < keys.length) { + buffer(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } + + results(keys.copy()) = new BufferSeens(buffer, null) + case inputbuffer => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.update(currentRow, inputbuffer.buffer, null) + idx += 1 + } } - - createIterator(aggregates, results.iterator.map(_._2)) } + + createIterator(aggregates, results.iterator.map(_._2)) } } } case class AggregatePostShuffle( - groupingExpressions: Seq[Attribute], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) extends UnaryNode with PostShuffle { + groupingExpressions: Seq[Attribute], + 
aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PostShuffle { override def output = aggregateExpressions.map(_.toAttribute) @@ -218,58 +192,34 @@ case class AggregatePostShuffle( ClusteredDistribution(groupingExpressions) :: Nil } - override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(aggregateExpressions)) val finalProjection = new InterpretedMutableProjection(aggregateExpressions, childOutput) - if (groupingExpressions.isEmpty) { - val buffer = new GenericMutableRow(output.length) - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - idx += 1 - } - - while (iter.hasNext) { - val currentRow = iter.next() - - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.merge(currentRow, buffer) - idx += 1 - } + val results = new OpenHashMap[Row, BufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + + while (iter.hasNext) { + val currentRow = iter.next() + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + // TODO currentRow seems most likely a MutableRow + val buffer = currentRow.makeMutable() + results(keys.copy()) = new BufferSeens(buffer, null) + case pair => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, pair.buffer) + idx += 1 + } } - - createIterator(finalProjection, aggregates, Iterator(new BufferSeens().withBuffer(buffer))) - } else { - val results = new OpenHashMap[Row, BufferSeens]() - val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) - - while (iter.hasNext) { - val currentRow = iter.next() - val keys = groupByProjection(currentRow) - results(keys) match { - case null => - // TODO currentRow seems most likely a MutableRow - val buffer = currentRow.makeMutable() - results(keys.copy()) = new BufferSeens(buffer, null) - case pair => - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.merge(currentRow, pair.buffer) - idx += 1 - } - } - } - - createIterator(finalProjection, aggregates, results.iterator.map(_._2)) } + + results.iterator.map(it => finalProjection(it._2.buffer)) } } } @@ -279,10 +229,10 @@ case class AggregatePostShuffle( // to the reduce side and do aggregation directly, this probably causes the performance regression // for Aggregation Function like CountDistinct etc. 
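Editor's note: the comment above is the crux of why DistinctAggregate exists, so here is a minimal, self-contained sketch (not part of the patch; all names invented) of why distinct aggregates cannot reuse the two-phase pre/post-shuffle path: per-partition partial distinct counts cannot simply be merged, because the same value may occur in several partitions, so the raw values (or the seen-sets) have to reach the reduce side.

// Illustrative sketch only, not part of the patch.
object DistinctMergeSketch {
  def main(args: Array[String]): Unit = {
    val partition1 = Seq("a", "b", "b")
    val partition2 = Seq("b", "c")

    // Naive two-phase merge: sum the per-partition distinct counts.
    val naive = partition1.distinct.size + partition2.distinct.size // 2 + 2 = 4

    // Correct answer: distinct over the union, which is what shipping the raw
    // rows to the reduce side (as DistinctAggregate does) computes.
    val correct = (partition1 ++ partition2).distinct.size // 3

    assert(naive != correct) // the naive merge double-counts "b"
  }
}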
case class DistinctAggregate( - groupingExpressions: Seq[NamedExpression], - unboundAggregateExpressions: Seq[NamedExpression], - rewrittenAggregateExpressions: Seq[NamedExpression], - child: SparkPlan) extends UnaryNode with PostShuffle { + groupingExpressions: Seq[NamedExpression], + unboundAggregateExpressions: Seq[NamedExpression], + rewrittenAggregateExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PostShuffle { override def output = rewrittenAggregateExpressions.map(_.toAttribute) @@ -300,94 +250,56 @@ case class DistinctAggregate( child.execute().mapPartitions { iter => val aggregates = initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) - val outputSchema: Seq[Attribute] = - aggregates.zipWithIndex.flatMap { case (ca, idx) => - ca.bufferDataType.zipWithIndex.map { case (dt, i) => - AttributeReference(s"aggr.${idx}_$i", dt)() - } - } ++ groupingExpressions.map(_.toAttribute) + val outputSchema: Seq[Attribute] = bufferSchema(aggregates) ++ groupingExpressions.map(_.toAttribute) initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenAggregateExpressions)) - val finalProjection = new InterpretedMutableProjection( - rewrittenAggregateExpressions, outputSchema) - - if (groupingExpressions.isEmpty) { - val buffer = new GenericMutableRow(output.length) - val seens = new Array[JSet[Any]](aggregates.length) - - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - - if (ae.distinct) { - seens(idx) = new JHashSet[Any]() - } - - idx += 1 - } - val ibs = new BufferSeens().withBuffer(buffer).withSeens(seens) + val finalProjection = new InterpretedMutableProjection(rewrittenAggregateExpressions, outputSchema) - while (iter.hasNext) { - val currentRow = iter.next() + val results = new OpenHashMap[Row, BufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.update(currentRow, buffer, seens(idx)) + while (iter.hasNext) { + val currentRow = iter.next() - idx += 1 - } - } - - createIterator(finalProjection, aggregates, Iterator(ibs)) - } else { - val results = new OpenHashMap[Row, BufferSeens]() - val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) - - while (iter.hasNext) { - val currentRow = iter.next() - - val keys = groupByProjection(currentRow) - results(keys) match { - case null => - val buffer = new GenericMutableRow(aggregates.length + keys.length) - val seens = new Array[JSet[Any]](aggregates.length) - - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(aggregates.length + keys.length) + val seens = new Array[JSet[Any]](aggregates.length) - if (ae.distinct) { - val seen = new JHashSet[Any]() - seens(idx) = seen - } + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) - ae.update(currentRow, buffer, seens(idx)) - idx += 1 + if (ae.distinct) { + val seen = new JHashSet[Any]() + seens(idx) = seen } - var idx2 = 0 - while (idx2 < keys.length) { - buffer(idx) = keys(idx2) - idx2 += 1 - idx += 1 - } - results(keys.copy()) = new BufferSeens(buffer, seens) - - case bufferSeens => - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.update(currentRow, bufferSeens.buffer, 
bufferSeens.seens(idx)) - idx += 1 - } - } + ae.update(currentRow, buffer, seens(idx)) + idx += 1 + } + var idx2 = 0 + while (idx2 < keys.length) { + buffer(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } + results(keys.copy()) = new BufferSeens(buffer, seens) + + case bufferSeens => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.update(currentRow, bufferSeens.buffer, bufferSeens.seens(idx)) + + idx += 1 + } } - - createIterator(finalProjection, aggregates, results.iterator.map(_._2)) } + + results.iterator.map(it => finalProjection(it._2.buffer)) } } } From 472a44057028b8d97094951f5557bc0b5c900324 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 01:45:12 +0800 Subject: [PATCH 08/20] Add Unit test --- .../Aggregate2CompatibilitySuite.scala | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala new file mode 100644 index 000000000000..31023dffa4ae --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Test the aggregation framework2. 
+ */ +class Aggregate2CompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.AGGREGATE_2, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.AGGREGATE_2, "false") + super.afterAll() + } + + override def whiteList = Seq( + "groupby1", + "groupby11", + "groupby12", + "groupby1_limit", + "groupby_grouping_id1", + "groupby_grouping_id2", + "groupby_grouping_sets1", + "groupby_grouping_sets2", + "groupby_grouping_sets3", + "groupby_grouping_sets4", + "groupby_grouping_sets5", + "groupby1_map", + "groupby1_map_nomap", + "groupby1_map_skew", + "groupby1_noskew", + "groupby2", + "groupby2_limit", + "groupby2_map", + "groupby2_map_skew", + "groupby2_noskew", + "groupby4", + "groupby4_map", + "groupby4_map_skew", + "groupby4_noskew", + "groupby5", + "groupby5_map", + "groupby5_map_skew", + "groupby5_noskew", + "groupby6", + "groupby6_map", + "groupby6_map_skew", + "groupby6_noskew", + "groupby7", + "groupby7_map", + "groupby7_map_multi_single_reducer", + "groupby7_map_skew", + "groupby7_noskew", + "groupby7_noskew_multi_single_reducer", + "groupby8", + "groupby8_map", + "groupby8_map_skew", + "groupby8_noskew", + "groupby9", + "groupby_distinct_samekey", + "groupby_map_ppr", + "groupby_multi_insert_common_distinct", + "groupby_multi_single_reducer2", + "groupby_multi_single_reducer3", + "groupby_mutli_insert_common_distinct", + "groupby_neg_float", + "groupby_ppd", + "groupby_ppr", + "groupby_sort_10", + "groupby_sort_2", + "groupby_sort_3", + "groupby_sort_4", + "groupby_sort_5", + "groupby_sort_6", + "groupby_sort_7", + "groupby_sort_8", + "groupby_sort_9", + "groupby_sort_test_1", + "having", + "udaf_collect_set", + "udaf_corr", + "udaf_covar_pop", + "udaf_covar_samp", + "udaf_histogram_numeric", + "udaf_number_format", + "udf_sum", + "udf_avg", + "udf_count" + ) +} From 241aee13e1ee7aad71964f3eea4e16469e3dc64b Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 04:10:50 +0800 Subject: [PATCH 09/20] Add some doc --- .../sql/catalyst/planning/patterns.scala | 30 +++++----- .../spark/sql/execution/SparkStrategies.scala | 19 +++--- .../sql/execution/aggregate2/Aggregate.scala | 30 +++++----- .../org/apache/spark/sql/hive/hiveUdfs.scala | 15 +++-- .../sql/hive/execution/AggregateSuite.scala | 58 +++++++++---------- 5 files changed, 75 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6dc0ccfb4f94..2fa6548085ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -142,21 +142,21 @@ object AggregateExpressionSubsitution extends AggregateExpressionSubsitution * return the rewritten aggregation expressions for both phases. * * The returned values for this match are as follows: - * - Grouping attributes for the final aggregation. - * - Aggregates for the final aggregation. - * - Grouping expressions for the partial aggregation. - * - Partial aggregate expressions. - * - Input to the aggregation. + * - Grouping expression (transformed to NamedExpression) list. + * - Aggregate expressions extracted from the projection. + * - Rewritten projection for the post-shuffle phase. + * - Original projection. + * - Child logical plan.
*/ object PartialAggregation2 { type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[aggregate2.AggregateExpression2], Seq[NamedExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[aggregate2.AggregateExpression2], Seq[NamedExpression], Seq[NamedExpression], LogicalPlan) def unapply(plan: LogicalPlan) : Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + case logical.Aggregate(groupingExpressions, projection, child) => // Collect all aggregate expressions that can be computed partially. - val allAggregates = aggregateExpressions.flatMap(_ collect { + val allAggregates = projection.flatMap(_ collect { case a: aggregate2.AggregateExpression2 => a }) @@ -171,13 +171,14 @@ object PartialAggregation2 { // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + val rewrittenProjection = projection.map(_.transformUp { case e: Expression if substitutions.contains(e) => substitutions(e).toAttribute case e: AggregateExpression2 => e.transformChildrenDown { // replace the child expression of the aggregate expression, with // Literal, as in PostShuffle, we don't need children expression any - // more(Only aggregate buffer required). + // more (only the aggregate buffer is required); otherwise it will + // cause attribute binding exceptions. case expr => MutableLiteral(null, expr.dataType, expr.nullable) } case e: Expression => @@ -190,14 +191,11 @@ object PartialAggregation2 { .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] - val groupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - Some( - (groupingAttributes, - namedGroupingExpressions.map(_._2), + (namedGroupingExpressions.map(_._2), allAggregates, - rewrittenAggregateExpressions, - aggregateExpressions, + rewrittenProjection, + projection, child)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c69e3901f67d..6de9c3f93543 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -130,24 +130,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggr.copy(aggregateExpressions = subusitutedAggrExpression) match { // Aggregations that can be performed in two phases, before and after the shuffle.
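Editor's note: a minimal sketch (not part of the patch; AvgBuffer and the function names are invented) of the algebra that makes this two-phase split legal for non-distinct aggregates. Each partition folds its rows into a small partial buffer before the shuffle; the post-shuffle side only merges buffers and applies the final projection.

// Illustrative sketch only, not part of the patch.
object TwoPhaseAvgSketch {
  // partial aggregation buffer for AVG: (sum, count)
  final case class AvgBuffer(sum: Double, count: Long)

  // pre-shuffle: fold one input row into the partial buffer
  def update(buf: AvgBuffer, value: Double): AvgBuffer =
    AvgBuffer(buf.sum + value, buf.count + 1)

  // post-shuffle: merge two partial buffers for the same grouping key
  def merge(a: AvgBuffer, b: AvgBuffer): AvgBuffer =
    AvgBuffer(a.sum + b.sum, a.count + b.count)

  // final projection: derive the result from the merged buffer
  def terminate(buf: AvgBuffer): Double = buf.sum / buf.count

  def main(args: Array[String]): Unit = {
    val p1 = Seq(1.0, 2.0).foldLeft(AvgBuffer(0, 0))(update)
    val p2 = Seq(3.0, 4.0, 5.0).foldLeft(AvgBuffer(0, 0))(update)
    assert(terminate(merge(p1, p2)) == 3.0) // same as avg over all five rows
  }
}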
case PartialAggregation2( - groupingAttributes, - groupingExpressions, + namedGroupingExpressions, aggregates, - rewrittenAggregateExpressions, - aggregateExpressions, + rewrittenProjection, + originalProjection, child) => if (aggregates.exists(_.distinct)) { aggregate2.DistinctAggregate( - groupingExpressions, - aggregateExpressions, - rewrittenAggregateExpressions, + namedGroupingExpressions, + originalProjection, + rewrittenProjection, planLater(child)) :: Nil } else { aggregate2.AggregatePostShuffle( - groupingAttributes, - rewrittenAggregateExpressions, + namedGroupingExpressions.map(_.toAttribute), + rewrittenProjection, aggregate2.AggregatePreShuffle( - groupingExpressions, + namedGroupingExpressions, aggregates, planLater(child))) :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index abd1d1cf7189..c2d2064999a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -82,8 +82,8 @@ sealed trait Aggregate { sealed trait PostShuffle extends Aggregate { self: Product => - def computedAggregates(aggregateExpressions: Seq[NamedExpression]): Seq[AggregateExpression2] = { - aggregateExpressions.flatMap { expr => + def computedAggregates(projectionList: Seq[NamedExpression]): Seq[AggregateExpression2] = { + projectionList.flatMap { expr => expr.collect { case ae: AggregateExpression2 => ae } @@ -97,17 +97,17 @@ sealed trait PostShuffle extends Aggregate { * group. * * @param groupingExpressions the attributes represent the output of the groupby expressions - * @param unboundAggregateExpressions Unbound Aggregate Function List. + * @param originalProjection Unbound Aggregate Function List. * @param child the input data source. 
*/ @DeveloperApi case class AggregatePreShuffle( groupingExpressions: Seq[NamedExpression], - unboundAggregateExpressions: Seq[AggregateExpression2], + originalProjection: Seq[AggregateExpression2], child: SparkPlan) extends UnaryNode with Aggregate { - val aggregateExpressions: Seq[AggregateExpression2] = unboundAggregateExpressions.map { + val aggregateExpressions: Seq[AggregateExpression2] = originalProjection.map { BindReferences.bindReference(_, childOutput) } @@ -181,10 +181,10 @@ case class AggregatePreShuffle( case class AggregatePostShuffle( groupingExpressions: Seq[Attribute], - aggregateExpressions: Seq[NamedExpression], + rewrittenProjection: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with PostShuffle { - override def output = aggregateExpressions.map(_.toAttribute) + override def output = rewrittenProjection.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { AllTuples :: Nil @@ -194,9 +194,9 @@ case class AggregatePostShuffle( override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => - val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(aggregateExpressions)) + val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(rewrittenProjection)) - val finalProjection = new InterpretedMutableProjection(aggregateExpressions, childOutput) + val finalProjection = new InterpretedMutableProjection(rewrittenProjection, childOutput) val results = new OpenHashMap[Row, BufferSeens]() val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) @@ -230,11 +230,11 @@ case class AggregatePostShuffle( // for Aggregation Function like CountDistinct etc. case class DistinctAggregate( groupingExpressions: Seq[NamedExpression], - unboundAggregateExpressions: Seq[NamedExpression], - rewrittenAggregateExpressions: Seq[NamedExpression], + originalProjection: Seq[NamedExpression], + rewrittenProjection: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with PostShuffle { - override def output = rewrittenAggregateExpressions.map(_.toAttribute) + override def output = rewrittenProjection.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { AllTuples :: Nil @@ -242,7 +242,7 @@ case class DistinctAggregate( ClusteredDistribution(groupingExpressions) :: Nil } - val aggregateExpressions: Seq[NamedExpression] = unboundAggregateExpressions.map { + val aggregateExpressions: Seq[NamedExpression] = originalProjection.map { BindReferences.bindReference(_, childOutput) } @@ -252,8 +252,8 @@ case class DistinctAggregate( val outputSchema: Seq[Attribute] = bufferSchema(aggregates) ++ groupingExpressions.map(_.toAttribute) - initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenAggregateExpressions)) - val finalProjection = new InterpretedMutableProjection(rewrittenAggregateExpressions, outputSchema) + initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenProjection)) + val finalProjection = new InterpretedMutableProjection(rewrittenProjection, outputSchema) val results = new OpenHashMap[Row, BufferSeens]() val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d03964d57649..d51221e9bd2a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -476,7 +476,7 @@ private[hive] case class HiveGenericUdaf2( funcWrapper: HiveFunctionWrapper, children: Seq[Expression], distinct: Boolean, - isUDAF: Boolean) extends AggregateExpression2 with HiveInspectors { + isSimpleUDAF: Boolean) extends AggregateExpression2 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver protected def createEvaluator = resolver.getEvaluator( @@ -487,8 +487,8 @@ private[hive] case class HiveGenericUdaf2( lazy val evaluator = createEvaluator @transient - protected lazy val resolver: AbstractGenericUDAFResolver = if (isUDAF) { - // if it's UDAF, we need the UDAF bridge + protected lazy val resolver: AbstractGenericUDAFResolver = if (isSimpleUDAF) { + // if it's the Simple UDAF, we need the UDAF bridge new GenericUDAFBridge(funcWrapper.createFunction()) } else { funcWrapper.createFunction() @@ -548,6 +548,9 @@ private[hive] case class HiveGenericUdaf2( // in the group override def update(input: Row, buf: MutableRow, seen: JSet[Any]): Unit = { val arguments = children.map(_.eval(input)) + // We assume memory is much more critical than computation, + // so we prefer recomputation over putting the values into an in-memory Set + // when the UDAF is distinct-like if (distinctLike || !distinct || !seen.contains(arguments)) { val args = arguments.zip(inspectors).map { case (value, oi) => wrap(value, oi) @@ -701,11 +704,11 @@ private[hive] case class HiveUdafFunction( private[hive] object HiveAggregateExpressionSubsitution extends AggregateExpressionSubsitution { override def subsitute(aggr: AggregateExpression): AggregateExpression2 = aggr match { - // TODO: we don't support distinct for Hive UDAF(Generic) yet from the user interface + // TODO: we don't support distinct for Hive UDAF(Generic) from the HiveQL parser yet case HiveGenericUdaf(funcWrapper, children) => - HiveGenericUdaf2(funcWrapper, children, distinct = false, isUDAF = false) + HiveGenericUdaf2(funcWrapper, children, distinct = false, isSimpleUDAF = false) case HiveUdaf(funcWrapper, children) => - HiveGenericUdaf2(funcWrapper, children, distinct = false, isUDAF = true) + HiveGenericUdaf2(funcWrapper, children, distinct = false, isSimpleUDAF = true) case _ => super.subsitute(aggr) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala index a543db3bdc46..0ddfe7e807b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala @@ -61,16 +61,15 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |FROM src """.stripMargin, false) -// TODO: NOT support the max(distinct key) for now -// createQueryTest("aggregation without group by expressions #4", -// """ -// |SELECT -// | count(distinct value), -// | max(distinct key), -// | min(distinct key), -// | sum(distinct key) -// |FROM src -// """.stripMargin, false) + createQueryTest("aggregation without group by expressions #4", + """ + |SELECT + | count(distinct value), + | max(distinct key), + | min(distinct key), + | sum(distinct key) + |FROM src + """.stripMargin, false) createQueryTest("aggregation without group by expressions #5", """ @@ -82,16 +81,15 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |FROM src """.stripMargin, false) -// TODO: NOT support the max(distinct key) for now -//
createQueryTest("aggregation without group by expressions #6", -// """ -// |SELECT -// | count(distinct value) + 4, -// | max(distinct key) + 2, -// | min(distinct key) + 3, -// | sum(distinct key) + 4 -// |FROM src -// """.stripMargin, false) + createQueryTest("aggregation without group by expressions #6", + """ + |SELECT + | count(distinct value) + 4, + | max(distinct key) + 2, + | min(distinct key) + 3, + | sum(distinct key) + 4 + |FROM src + """.stripMargin, false) createQueryTest("aggregation with group by expressions #1", """ @@ -147,16 +145,16 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |GROUP BY key + 3, value |ORDER BY a, b LIMIT 5 """.stripMargin, false) -// TODO: NOT support the stddev_pop(distinct key) for now -// createQueryTest("aggregation with group by expressions #7", -// """ -// |SELECT -// | stddev_pop(distinct key) as a, -// | stddev_samp(distinct key) as b -// |FROM src -// |GROUP BY key + 3, value -// |ORDER BY a, b LIMIT 5 -// """.stripMargin, false) + + createQueryTest("aggregation with group by expressions #7", + """ + |SELECT + | stddev_pop(distinct key) as a, + | stddev_samp(distinct key) as b + |FROM src + |GROUP BY key + 3, value + |ORDER BY a, b LIMIT 5 + """.stripMargin, false) createQueryTest("aggregation with group by expressions #8", """ From de96a136ad3a36e0d5c1a43bfdb8cf436a2ba34a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 05:33:15 +0800 Subject: [PATCH 10/20] style issues --- .../expressions/aggregate2/aggregates.scala | 63 ++++++++--------- .../sql/execution/aggregate2/Aggregate.scala | 69 +++++++++++++++---- 2 files changed, 86 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index f99277717acf..01b7cba9a942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -97,7 +97,7 @@ trait AggregateExpression2 extends Expression with AggregateFunction2 { type EvaluatedType = Any - var mode: Mode = COMPLETE + var mode: Mode = COMPLETE // will only be used by Hive UDAF def initial(m: Mode): Unit = { this.mode = m @@ -108,7 +108,7 @@ trait AggregateExpression2 extends Expression with AggregateFunction2 { // Is it a distinct aggregate expression? 
def distinct: Boolean - def nullable = true + def nullable: Boolean = true final override def eval(aggrBuffer: Row): EvaluatedType = terminate(aggrBuffer) } @@ -117,14 +117,12 @@ abstract class UnaryAggregateExpression extends UnaryExpression with AggregateEx self: Product => } -case class Min( - child: Expression) - extends UnaryAggregateExpression { +case class Min(child: Expression) extends UnaryAggregateExpression { override def distinct: Boolean = false - override def dataType = child.dataType + override def dataType: DataType = child.dataType override def bufferDataType: Seq[DataType] = dataType :: Nil - override def toString = s"MIN($child)" + override def toString: String = s"MIN($child)" /* The below code will be called in executors, be sure to make the instance transientable */ @transient var arg: MutableLiteral = _ @@ -171,9 +169,9 @@ case class Min( case class Average(child: Expression, distinct: Boolean = false) extends UnaryAggregateExpression { - override def nullable = false + override def nullable: Boolean = false - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -183,7 +181,7 @@ case class Average(child: Expression, distinct: Boolean = false) } override def bufferDataType: Seq[DataType] = LongType :: dataType :: Nil - override def toString = s"AVG($child)" + override def toString: String = s"AVG($child)" /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var count: BoundReference = _ @@ -252,14 +250,13 @@ case class Average(child: Expression, distinct: Boolean = false) override def terminate(row: Row): Any = if (count.eval(row) == 0) null else divide.eval(row) } -case class Max(child: Expression) - extends UnaryAggregateExpression { +case class Max(child: Expression) extends UnaryAggregateExpression { override def distinct: Boolean = false - override def nullable = true - override def dataType = child.dataType + override def nullable: Boolean = true + override def dataType: DataType = child.dataType override def bufferDataType: Seq[DataType] = dataType :: Nil - override def toString = s"MAX($child)" + override def toString: String = s"MAX($child)" /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ @@ -267,7 +264,7 @@ case class Max(child: Expression) @transient var buffer: MutableLiteral = _ @transient var cmp: GreaterThan = _ - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]): Unit = { aggr = buffers(0) arg = MutableLiteral(null, dataType) buffer = MutableLiteral(null, dataType) @@ -306,15 +303,15 @@ case class Max(child: Expression) case class Count(child: Expression) extends UnaryAggregateExpression { def distinct: Boolean = false - override def nullable = false - override def dataType = LongType + override def nullable: Boolean = false + override def dataType: DataType = LongType override def bufferDataType: Seq[DataType] = LongType :: Nil - override def toString = s"COUNT($child)" + override def toString: String = s"COUNT($child)" /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: 
Seq[BoundReference]): Unit = { aggr = buffers(0) } @@ -351,14 +348,14 @@ case class Count(child: Expression) case class CountDistinct(children: Seq[Expression]) extends AggregateExpression2 { def distinct: Boolean = true - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT($children)" + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"COUNT($children)" override def bufferDataType: Seq[DataType] = LongType :: Nil /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]): Unit = { aggr = buffers(0) } @@ -411,8 +408,8 @@ case class CountDistinct(children: Seq[Expression]) */ case class Sum(child: Expression, distinct: Boolean = false) extends UnaryAggregateExpression { - override def nullable = true - override def dataType = child.dataType match { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -422,7 +419,7 @@ case class Sum(child: Expression, distinct: Boolean = false) } override def bufferDataType: Seq[DataType] = dataType :: Nil - override def toString = s"SUM($child)" + override def toString: String = s"SUM($child)" /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ @@ -431,7 +428,7 @@ case class Sum(child: Expression, distinct: Boolean = false) lazy val DEFAULT_VALUE = Cast(Literal.create(0, IntegerType), dataType).eval() - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]): Unit = { aggr = buffers(0) arg = MutableLiteral(null, dataType) sum = Add(arg, aggr) @@ -476,15 +473,15 @@ case class Sum(child: Expression, distinct: Boolean = false) case class First(child: Expression, distinct: Boolean = false) extends UnaryAggregateExpression { - override def nullable = true - override def dataType = child.dataType + override def nullable: Boolean = true + override def dataType: DataType = child.dataType override def bufferDataType: Seq[DataType] = dataType :: Nil - override def toString = s"FIRST($child)" + override def toString: String = s"FIRST($child)" /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]): Unit = { aggr = buffers(0) } @@ -523,7 +520,7 @@ case class Last(child: Expression, distinct: Boolean = false) /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]): Unit = { aggr = buffers(0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index c2d2064999a5..05a4e3dd71e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -29,6 +29,7 @@ import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.plans.physical._ +// A pair of the aggregate buffer and the seen-value sets (used for distinct) sealed class BufferSeens(var buffer: MutableRow, var seens: Array[JSet[Any]] = null) { def this() { this(new GenericMutableRow(0), null) @@ -51,14 +52,20 @@ sealed trait Aggregate { // out child's output attributes statically here. val childOutput = child.output - def initializedAndGetAggregates(mode: Mode, aggregates: Seq[AggregateExpression2]): Array[AggregateExpression2] = { + // Initialize the aggregate functions; this will be called at the beginning of every + // partition's data processing + def initializedAndGetAggregates( + mode: Mode, + aggregates: Seq[AggregateExpression2]) + : Array[AggregateExpression2] = { var pos = 0 aggregates.map { ae => ae.initial(mode) - // we connect all of the aggregation buffers in a single Row, - // and "BIND" the attribute references in a Hack way. + // We connect all of the aggregation buffers in a single Row, + // and "BIND" the attribute references in a hack way, as we believe + // the pre/post shuffle Aggregates are actually tightly coupled val bufferDataTypes = ae.bufferDataType ae.initialize(for (i <- 0 until bufferDataTypes.length) yield { BoundReference(pos + i, bufferDataTypes(i), true) @@ -72,9 +79,13 @@ sealed trait Aggregate { // This is provided by SparkPlan def child: SparkPlan + // The schema of the aggregate buffers, as we line those buffers up + // in a single row. def bufferSchema(aggregates: Seq[AggregateExpression2]): Seq[Attribute] = aggregates.zipWithIndex.flatMap { case (ca, idx) => ca.bufferDataType.zipWithIndex.map { case (dt, i) => + // the attribute names are useless here, as we bind the attributes + // in a hack way, see [[initializedAndGetAggregates]] AttributeReference(s"aggr.${idx}_$i", dt)().toAttribute } } } @@ -82,6 +93,7 @@ sealed trait Aggregate { sealed trait PostShuffle extends Aggregate { self: Product => + // extract the aggregate functions from the projection def computedAggregates(projectionList: Seq[NamedExpression]): Seq[AggregateExpression2] = { projectionList.flatMap { expr => expr.collect { @@ -96,7 +108,7 @@ sealed trait PostShuffle extends Aggregate { * Groups input data by `groupingExpressions` and computes the `projection` for each * group. * - * @param groupingExpressions the attributes represent the output of the groupby expressions + * @param groupingExpressions the attributes representing the output of the grouping expressions * @param originalProjection Unbound Aggregate Function List. * @param child the input data source. */ @@ -118,26 +130,30 @@ case class AggregatePreShuffle( * Create Iterator for the in-memory hash map. */ private[this] def createIterator( - functions: Array[AggregateExpression2], - iterator: Iterator[BufferSeens]) = { + functions: Array[AggregateExpression2], + iterator: Iterator[BufferSeens]) = { new Iterator[Row] { override final def hasNext: Boolean = iterator.hasNext override final def next(): Row = { - val keybuffer = iterator.next() + val keyBuffer = iterator.next() var idx = 0 while (idx < functions.length) { - functions(idx).terminatePartial(keybuffer.buffer) + // terminatePartial is for Hive UDAF; it + // provides an opportunity to transform its internal aggregate buffer into + // the catalyst data.
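Editor's note: a minimal sketch (not part of the patch; the buffer widths are example values) of how initializedAndGetAggregates packs every function's buffer slots into one flat row by handing out consecutive BoundReference ordinals, which is exactly the "hack way" binding the comments above describe.

// Illustrative sketch only, not part of the patch.
object BufferLayoutSketch {
  def main(args: Array[String]): Unit = {
    val bufferWidths = Seq(2, 1) // e.g. AVG -> (count, sum), MAX -> (max)

    // Assign each aggregate a contiguous ordinal range, mirroring the
    // running `pos` counter in initializedAndGetAggregates.
    var pos = 0
    val ordinals = bufferWidths.map { width =>
      val range = pos until (pos + width)
      pos += width
      range
    }
    assert(ordinals == Seq(0 until 2, 2 until 3))

    // The shared buffer row has one slot per ordinal; in AggregatePreShuffle
    // the grouping key values are appended after these slots.
    assert(new Array[Any](pos).length == 3)
  }
}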
+ functions(idx).terminatePartial(keyBuffer.buffer) idx += 1 } - keybuffer.buffer + keyBuffer.buffer } } } override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => + // the input is every single row val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) @@ -149,6 +165,7 @@ case class AggregatePreShuffle( results(keys) match { case null => val buffer = new GenericMutableRow(output.length) + // handle the aggregate buffers var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -156,6 +173,7 @@ ae.update(currentRow, buffer, null) idx += 1 } + // handle the grouping expressions var idx2 = 0 while (idx2 < keys.length) { buffer(idx) = keys(idx2) @@ -174,6 +192,7 @@ } } + // The output is the (Aggregate Buffers + Grouping Expression Values) createIterator(aggregates, results.iterator.map(_._2)) } } @@ -194,6 +213,7 @@ case class AggregatePostShuffle( override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => + // The input Row in the format of (AggregateBuffers + GroupingExpression Values) val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(rewrittenProjection)) val finalProjection = new InterpretedMutableProjection(rewrittenProjection, childOutput) @@ -206,8 +226,20 @@ val keys = groupByProjection(currentRow) results(keys) match { case null => - // TODO currentRow seems most likely a MutableRow + // TODO actually what we need to copy is the grouping expression values + // as the aggregate buffer will be reset. val buffer = currentRow.makeMutable() + // The reason why we need to reset it first and then merge with the input row is + // that, for Hive UDAF, we need to provide an opportunity for the buffer to be a + // custom type; otherwise, the Hive UDAF would wrap/unwrap on every merge() method + // call, which is very expensive + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + ae.merge(currentRow, buffer) + idx += 1 + } results(keys.copy()) = new BufferSeens(buffer, null) case pair => var idx = 0 @@ -219,7 +251,9 @@ } } - results.iterator.map(it => finalProjection(it._2.buffer)) + // the final projection is simply a rewritten version of the output expression list, + // which produces the final output + results.iterator.map { it => finalProjection(it._2.buffer) } } } } @@ -242,17 +276,23 @@ case class DistinctAggregate( ClusteredDistribution(groupingExpressions) :: Nil } + // binding the expression, which takes the child's output as input val aggregateExpressions: Seq[NamedExpression] = originalProjection.map { BindReferences.bindReference(_, childOutput) } override def execute() = attachTree(this, "execute") { child.execute().mapPartitions { iter => - val aggregates = initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) + // initialize the aggregate functions for input rows (update/terminatePartial will be called) + val aggregates = + initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) - val outputSchema: Seq[Attribute] = bufferSchema(aggregates) ++ groupingExpressions.map(_.toAttribute) + val outputSchema = bufferSchema(aggregates) ++ groupingExpressions.map(_.toAttribute) + // initialize the aggregate functions for the final output (merge/terminate will be called)
initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenProjection)) + // binding the output projection, which takes the aggregate buffer and grouping keys + // as the input row. val finalProjection = new InterpretedMutableProjection(rewrittenProjection, outputSchema) val results = new OpenHashMap[Row, BufferSeens]() @@ -267,6 +307,7 @@ case class DistinctAggregate( val buffer = new GenericMutableRow(aggregates.length + keys.length) val seens = new Array[JSet[Any]](aggregates.length) + // handle the aggregate buffers var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -280,6 +321,8 @@ case class DistinctAggregate( ae.update(currentRow, buffer, seens(idx)) idx += 1 } + + // handle the grouping expression value var idx2 = 0 while (idx2 < keys.length) { buffer(idx) = keys(idx2) From 483b381ec9b946cf169b7c71a3f2ea6d593b0196 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 05:40:43 +0800 Subject: [PATCH 11/20] more style issues --- .../expressions/aggregate2/aggregates.scala | 6 +++--- .../spark/sql/catalyst/planning/patterns.scala | 6 +++++- .../sql/execution/aggregate2/Aggregate.scala | 16 +++++++++------- .../org/apache/spark/sql/hive/hiveUdfs.scala | 5 +++-- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 01b7cba9a942..a996c65dfb91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -512,10 +512,10 @@ case class First(child: Expression, distinct: Boolean = false) case class Last(child: Expression, distinct: Boolean = false) extends UnaryAggregateExpression { - override def nullable = true - override def dataType = child.dataType + override def nullable: Boolean = true + override def dataType: DataType = child.dataType override def bufferDataType: Seq[DataType] = dataType :: Nil - override def toString = s"LAST($child)" + override def toString: String = s"LAST($child)" /* The below code will be called in executors, be sure to mark the instance as transient */ @transient var aggr: BoundReference = _ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 2fa6548085ab..7a3f637351ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -150,7 +150,11 @@ object AggregateExpressionSubsitution extends AggregateExpressionSubsitution */ object PartialAggregation2 { type ReturnType = - (Seq[NamedExpression], Seq[aggregate2.AggregateExpression2], Seq[NamedExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], + Seq[aggregate2.AggregateExpression2], + Seq[NamedExpression], + Seq[NamedExpression], + LogicalPlan) def unapply(plan: LogicalPlan) : Option[ReturnType] = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 05a4e3dd71e4..0438e4938fde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate2 import java.util.{HashSet=>JHashSet, Set=>JSet} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate2._ import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} @@ -123,8 +124,9 @@ case class AggregatePreShuffle( BindReferences.bindReference(_, childOutput) } - override def requiredChildDistribution = UnspecifiedDistribution :: Nil - override def output = bufferSchema(aggregateExpressions) ++ groupingExpressions.map(_.toAttribute) + override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: Nil + override def output: Seq[Attribute] = + bufferSchema(aggregateExpressions) ++ groupingExpressions.map(_.toAttribute) /** * Create Iterator for the in-memory hash map. @@ -151,7 +153,7 @@ case class AggregatePreShuffle( } } - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // the input is every single row val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) @@ -203,7 +205,7 @@ case class AggregatePostShuffle( rewrittenProjection: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with PostShuffle { - override def output = rewrittenProjection.map(_.toAttribute) + override def output: Seq[Attribute] = rewrittenProjection.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { AllTuples :: Nil @@ -211,7 +213,7 @@ case class AggregatePostShuffle( ClusteredDistribution(groupingExpressions) :: Nil } - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // The input Row in the format of (AggregateBuffers + GroupingExpression Values) val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(rewrittenProjection)) @@ -268,7 +270,7 @@ case class DistinctAggregate( rewrittenProjection: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with PostShuffle { - override def output = rewrittenProjection.map(_.toAttribute) + override def output: Seq[Attribute] = rewrittenProjection.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { AllTuples :: Nil @@ -281,7 +283,7 @@ case class DistinctAggregate( BindReferences.bindReference(_, childOutput) } - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // initialize the aggregate functions for input rows (update/terminatePartial will be called) val aggregates = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d51221e9bd2a..2ed7f2101d01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -512,7 +512,8 @@ private[hive] case class HiveGenericUdaf2( val annotation = evaluator.getClass().getAnnotation(classOf[HiveUDFType]) if (annotation == null || !annotation.distinctLike()) false else true } - override def toString = 
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" // Aggregation Buffer Data Type, We assume only 1 element for the Hive Aggregation Buffer // It will be StructType if more than 1 element (Actually will be StructSettableObjectInspector) @@ -526,7 +527,7 @@ private[hive] case class HiveGenericUdaf2( /////////////////////////////////////////////////////////////////////////////////////////////// @transient var bound: BoundReference = _ - override def initialize(buffers: Seq[BoundReference]) = { + override def initialize(buffers: Seq[BoundReference]): Unit = { bound = buffers(0) mode match { case FINAL => evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(bufferObjectInspector)) From 58b1481f54ed2dd3552d5ccddadebb5e05728892 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 06:00:23 +0800 Subject: [PATCH 12/20] fix bug in the for unit test --- .../sql/execution/aggregate2/Aggregate.scala | 394 +++++++++++------- .../sql/hive/execution/AggregateSuite.scala | 28 +- 2 files changed, 267 insertions(+), 155 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 0438e4938fde..05d33574b379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -82,7 +82,7 @@ sealed trait Aggregate { // The schema of the aggregate buffers, as we lines those buffers // in a single row. - def bufferSchema(aggregates: Seq[AggregateExpression2]): Seq[Attribute] = + def bufferSchemaFromAggregate(aggregates: Seq[AggregateExpression2]): Seq[Attribute] = aggregates.zipWithIndex.flatMap { case (ca, idx) => ca.bufferDataType.zipWithIndex.map { case (dt, i) => // the attribute names is useless here, as we bind the attribute @@ -95,7 +95,7 @@ sealed trait PostShuffle extends Aggregate { self: Product => // extract the aggregate function from the projection - def computedAggregates(projectionList: Seq[NamedExpression]): Seq[AggregateExpression2] = { + def computedAggregates(projectionList: Seq[Expression]): Seq[AggregateExpression2] = { projectionList.flatMap { expr => expr.collect { case ae: AggregateExpression2 => ae @@ -120,13 +120,13 @@ case class AggregatePreShuffle( child: SparkPlan) extends UnaryNode with Aggregate { - val aggregateExpressions: Seq[AggregateExpression2] = originalProjection.map { + private val aggregateExpressions: Seq[AggregateExpression2] = originalProjection.map { BindReferences.bindReference(_, childOutput) } + private val buffersSchema = bufferSchemaFromAggregate(aggregateExpressions) override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: Nil - override def output: Seq[Attribute] = - bufferSchema(aggregateExpressions) ++ groupingExpressions.map(_.toAttribute) + override val output: Seq[Attribute] = buffersSchema ++ groupingExpressions.map(_.toAttribute) /** * Create Iterator for the in-memory hash map. 
@@ -154,48 +154,78 @@ case class AggregatePreShuffle( } override def execute(): RDD[Row] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // the input is every single row - val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) - - val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) - val results = new OpenHashMap[Row, BufferSeens]() - while (iter.hasNext) { - val currentRow = iter.next() - - val keys = groupByProjection(currentRow) - results(keys) match { - case null => - val buffer = new GenericMutableRow(output.length) - // handle the aggregate buffers - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - ae.update(currentRow, buffer, null) - idx += 1 - } - // handle the grouping expressions - var idx2 = 0 - while (idx2 < keys.length) { - buffer(idx) = keys(idx2) - idx2 += 1 - idx += 1 - } - - results(keys.copy()) = new BufferSeens(buffer, null) - case inputbuffer => - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.update(currentRow, inputbuffer.buffer, null) - idx += 1 - } + if (groupingExpressions.length == 0) { + child.execute().mapPartitions { iter => + // the input is every single row + val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) + // without group by keys + val buffer = new GenericMutableRow(output.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + idx += 1 + } + + while (iter.hasNext) { + val currentRow = iter.next() + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.update(currentRow, buffer, null) + idx += 1 + } } + + createIterator(aggregates, Iterator(new BufferSeens().withBuffer(buffer))) } + } else { + child.execute().mapPartitions { iter => + // the input is every single row + val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) + + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + + val results = new OpenHashMap[Row, BufferSeens]() + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(output.length) + // fill the aggregate buffers + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + ae.update(currentRow, buffer, null) + idx += 1 + } + + // fill the grouping expression values into the new row + idx = buffersSchema.length + var idx2 = 0 + while (idx < output.length) { + buffer(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } + + results(keys.copy()) = new BufferSeens(buffer, null) + case inputbuffer => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.update(currentRow, inputbuffer.buffer, null) + idx += 1 + } + } + } - // The output is the (Aggregate Buffers + Grouping Expression Values) - createIterator(aggregates, results.iterator.map(_._2)) + // The output is the (Aggregate Buffers + Grouping Expression Values) + createIterator(aggregates, results.iterator.map(_._2)) + } } } } @@ -214,48 +244,78 @@ case class AggregatePostShuffle( } override def execute(): RDD[Row] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // The input Row in the format of (AggregateBuffers + GroupingExpression Values) - val aggregates = initializedAndGetAggregates(FINAL, computedAggregates(rewrittenProjection)) - - val 
finalProjection = new InterpretedMutableProjection(rewrittenProjection, childOutput) - - val results = new OpenHashMap[Row, BufferSeens]() - val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) - - while (iter.hasNext) { - val currentRow = iter.next() - val keys = groupByProjection(currentRow) - results(keys) match { - case null => - // TODO actually what we need to copy is the grouping expression values - // as the aggregate buffer will be reset. - val buffer = currentRow.makeMutable() - // The reason why need to reset it first and merge with the input row is, - // in Hive UDAF, we need to provide an opportunity that the buffer can be the - // custom type, Otherwise, HIVE UDAF will wrap/unwrap in every merge() method - // calls, which is every expensive - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.reset(buffer) - ae.merge(currentRow, buffer) - idx += 1 - } - results(keys.copy()) = new BufferSeens(buffer, null) - case pair => - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.merge(currentRow, pair.buffer) - idx += 1 - } + if (groupingExpressions.length == 0) { + child.execute().mapPartitions { iter => + // The input Row in the format of (AggregateBuffers) + val aggregates = + initializedAndGetAggregates(FINAL, computedAggregates(rewrittenProjection)) + val finalProjection = new InterpretedMutableProjection(rewrittenProjection, childOutput) + + val buffer = new GenericMutableRow(childOutput.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + idx += 1 + } + + while (iter.hasNext) { + val currentRow = iter.next() + + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, buffer) + idx += 1 + } } + + Iterator(finalProjection(buffer)) } + } else { + child.execute().mapPartitions { iter => + // The input Row in the format of (AggregateBuffers + GroupingExpression Values) + val aggregates = + initializedAndGetAggregates(FINAL, computedAggregates(rewrittenProjection)) + val finalProjection = new InterpretedMutableProjection(rewrittenProjection, childOutput) + + val results = new OpenHashMap[Row, BufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + + while (iter.hasNext) { + val currentRow = iter.next() + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + // TODO actually what we need to copy is the grouping expression values + // as the aggregate buffer will be reset. 
+ val buffer = currentRow.makeMutable() + // The reason why we need to reset it first and then merge with the input row is + // that, for Hive UDAF, we need to provide an opportunity for the buffer to be a + // custom type; otherwise, the Hive UDAF would wrap/unwrap on every merge() method + // call, which is very expensive + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + ae.merge(currentRow, buffer) + idx += 1 + } + results(keys.copy()) = new BufferSeens(buffer, null) + case pair => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, pair.buffer) + idx += 1 + } + } + } - // the final projection is simply a rewritten version of the output expression list, - // which produces the final output - results.iterator.map { it => finalProjection(it._2.buffer)} + // the final projection is simply a rewritten version of the output expression list, + // which produces the final output + results.iterator.map { it => finalProjection(it._2.buffer)} + } } } } @@ -279,10 +339,10 @@ case class DistinctAggregate( } // binding the expression, which takes the child's output as input - val aggregateExpressions: Seq[NamedExpression] = originalProjection.map { - BindReferences.bindReference(_, childOutput) + private val aggregateExpressions: Seq[Expression] = originalProjection.map { + BindReferences.bindReference(_: Expression, childOutput) } override def execute(): RDD[Row] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // initialize the aggregate functions for input rows - // (update/terminatePartial will be called) - val aggregates = - initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) - - val buffersSchema = bufferSchemaFromAggregate(aggregates) - - // initialize the aggregate functions for the final output (merge/terminate will be called) - initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenProjection)) - // binding the output projection, which takes the aggregate buffer and grouping keys - // as the input row. + if (groupingExpressions.length == 0) { + child.execute().mapPartitions { iter => + // initialize the aggregate functions for input rows + // (update/terminatePartial will be called) + val aggregates = + initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) + + val buffersSchema = bufferSchemaFromAggregate(aggregates) + + // initialize the aggregate functions for the final output (merge/terminate will be called) + initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenProjection)) + // binding the output projection, which takes the aggregate buffer and grouping keys + // as the input row.
+ val finalProjection = new InterpretedMutableProjection(rewrittenProjection, buffersSchema) + + val buffer = new GenericMutableRow(buffersSchema.length) + val seens = new Array[JSet[Any]](aggregates.length) + + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + + if (ae.distinct) { + seens(idx) = new JHashSet[Any]() + } + + idx += 1 + } + val ibs = new BufferSeens().withBuffer(buffer).withSeens(seens) - ae.update(currentRow, buffer, seens(idx)) - idx += 1 - } - - // handle the grouping expression value - var idx2 = 0 - while (idx2 < keys.length) { - buffer(idx) = keys(idx2) - idx2 += 1 - idx += 1 - } - results(keys.copy()) = new BufferSeens(buffer, seens) - - case bufferSeens => - var idx = 0 - while (idx < aggregates.length) { - val ae = aggregates(idx) - ae.update(currentRow, bufferSeens.buffer, bufferSeens.seens(idx)) - - idx += 1 - } + while (iter.hasNext) { + val currentRow = iter.next() + + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.update(currentRow, buffer, seens(idx)) + + idx += 1 + } } + + Iterator(finalProjection(ibs.buffer)) } + } else { + child.execute().mapPartitions { iter => + // initialize the aggregate functions for input rows + // (update/terminatePartial will be called) + val aggregates = + initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) + + val buffersSchema = bufferSchemaFromAggregate(aggregates) + val outputSchema = buffersSchema ++ groupingExpressions.map(_.toAttribute) + + // initialize the aggregate functions for the final output (merge/terminate will be called) + initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenProjection)) + // binding the output projection, which takes the aggregate buffer and grouping keys + // as the input row. 
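The distinct path above allocates one JHashSet per distinct aggregate and only lets a value reach the buffer the first time the set admits it. Here is that invariant in isolation, as a self-contained sketch (the real update() plumbing is more involved than this):

import java.util.{HashSet => JHashSet}

object DistinctUpdateSketch {
  def main(args: Array[String]): Unit = {
    val inputs = Seq(1, 2, 2, 3, 1)
    val seen = new JHashSet[Any]()   // the per-group "seens(idx)" for this aggregate
    var count = 0L                   // buffer slot for COUNT(DISTINCT x)
    for (v <- inputs) {
      if (seen.add(v)) {             // add() returns false for values already seen
        count += 1
      }
    }
    println(s"count(distinct) = $count") // 3
  }
}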
+ val finalProjection = new InterpretedMutableProjection(rewrittenProjection, outputSchema) + + val results = new OpenHashMap[Row, BufferSeens]() + val groupByProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results(keys) match { + case null => + val buffer = new GenericMutableRow(aggregates.length + keys.length) + val seens = new Array[JSet[Any]](aggregates.length) + + // handle the aggregate buffers + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + + if (ae.distinct) { + val seen = new JHashSet[Any]() + seens(idx) = seen + } + + ae.update(currentRow, buffer, seens(idx)) + idx += 1 + } + + // fill the grouping expression values into the new row + idx = buffersSchema.length + var idx2 = 0 + while (idx < output.length) { + buffer(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } + results(keys.copy()) = new BufferSeens(buffer, seens) - results.iterator.map(it => finalProjection(it._2.buffer)) + case bufferSeens => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.update(currentRow, bufferSeens.buffer, bufferSeens.seens(idx)) + + idx += 1 + } + } + } + + results.iterator.map(it => finalProjection(it._2.buffer)) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala index 0ddfe7e807b5..d1c0502bfa3e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala @@ -65,8 +65,8 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { """ |SELECT | count(distinct value), - | max(distinct key), - | min(distinct key), + | max(key), + | min(key), | sum(distinct key) |FROM src """.stripMargin, false) @@ -85,8 +85,8 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { """ |SELECT | count(distinct value) + 4, - | max(distinct key) + 2, - | min(distinct key) + 3, + | max(key) + 2, + | min(key) + 3, | sum(distinct key) + 4 |FROM src """.stripMargin, false) @@ -146,15 +146,17 @@ class AggregateSuite extends HiveComparisonTest with BeforeAndAfter { |ORDER BY a, b LIMIT 5 """.stripMargin, false) - createQueryTest("aggregation with group by expressions #7", - """ - |SELECT - | stddev_pop(distinct key) as a, - | stddev_samp(distinct key) as b - |FROM src - |GROUP BY key + 3, value - |ORDER BY a, b LIMIT 5 - """.stripMargin, false) +// TODO the parser does not currently support DISTINCT +// in Hive UDAF calls +// createQueryTest("aggregation with group by expressions #7", +// """ +// |SELECT +// | stddev_pop(distinct key) as a, +// | stddev_samp(distinct key) as b +// |FROM src +// |GROUP BY key + 3, value +// |ORDER BY a, b LIMIT 5 +// """.stripMargin, false) createQueryTest("aggregation with group by expressions #8", """ From 39a6243c76e2f021faf5100068f33955db3ded93 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 08:58:55 +0800 Subject: [PATCH 13/20] use BufferAndKey class to manually maintain the MutableRow --- .../sql/execution/aggregate2/Aggregate.scala | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index
05d33574b379..9b275ea068a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -47,6 +47,23 @@ sealed class BufferSeens(var buffer: MutableRow, var seens: Array[JSet[Any]] = n } } +// A MutableRow for AggregateBuffers and GroupingExpression Values +sealed class BufferAndKey(leftLen: Int, rightLen: Int) + extends GenericMutableRow(leftLen + rightLen) { + + def this(leftLen: Int, keys: Row) = { + this(leftLen, keys.length) + // copy the keys to the end of the row + var idx = leftLen + var idx2 = 0 + while (idx < keys.length) { + this.values(idx) = keys(idx2) + idx2 += 1 + idx += 1 + } + } +} + sealed trait Aggregate { self: Product => // HACK: Generators don't correctly preserve their output through serializations so we grab @@ -159,7 +176,7 @@ case class AggregatePreShuffle( // the input is every single row val aggregates = initializedAndGetAggregates(PARTIAL1, aggregateExpressions) // without group by keys - val buffer = new GenericMutableRow(output.length) + val buffer = new GenericMutableRow(buffersSchema.length) var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -193,8 +210,8 @@ case class AggregatePreShuffle( val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new GenericMutableRow(output.length) - // fill the aggregate buffers + val buffer = new BufferAndKey(output.length, keys) + // update the aggregate buffers var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -203,15 +220,6 @@ case class AggregatePreShuffle( idx += 1 } - // fill the grouping expression values into the new row - idx = buffersSchema.length - var idx2 = 0 - while (idx < output.length) { - buffer(idx) = keys(idx2) - idx2 += 1 - idx += 1 - } - results(keys.copy()) = new BufferSeens(buffer, null) case inputbuffer => var idx = 0 @@ -362,6 +370,7 @@ case class DistinctAggregate( val buffer = new GenericMutableRow(buffersSchema.length) val seens = new Array[JSet[Any]](aggregates.length) + // reset the aggregate buffer var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -375,6 +384,7 @@ case class DistinctAggregate( } val ibs = new BufferSeens().withBuffer(buffer).withSeens(seens) + // update the aggregate buffer while (iter.hasNext) { val currentRow = iter.next() @@ -387,19 +397,21 @@ case class DistinctAggregate( } } + // a single output row for the non-grouping-keys case Iterator(finalProjection(ibs.buffer)) } } else { child.execute().mapPartitions { iter => // initialize the aggregate functions for input rows - // (update/terminatePartial will be called) + // (update will be called) val aggregates = initializedAndGetAggregates(COMPLETE, computedAggregates(aggregateExpressions)) + val buffersSchema = bufferSchemaFromAggregate(aggregates) val outputSchema = buffersSchema ++ groupingExpressions.map(_.toAttribute) - // initialize the aggregate functions for the final output (merge/terminate will be called) + // initialize the aggregate functions for the final output + // (merge/terminate will be called) initializedAndGetAggregates(COMPLETE, computedAggregates(rewrittenProjection)) // binding the output projection, which takes the aggregate buffer and grouping keys // as the input row.
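To make the BufferAndKey layout concrete: it is one flat row whose first leftLen slots hold the aggregate buffers and whose tail holds a copy of the grouping key values. A stand-alone model with a toy class (not Catalyst's GenericMutableRow) — note that the copy loop must run until idx2 exhausts the keys; the version added above tests idx instead, which the next patch in the series corrects:

class BufferAndKeySketch(leftLen: Int, keys: Array[Any]) {
  val values = new Array[Any](leftLen + keys.length)
  // copy the keys into the tail of the row, advancing on idx2 (the key index)
  private var idx = leftLen
  private var idx2 = 0
  while (idx2 < keys.length) {
    values(idx) = keys(idx2)
    idx2 += 1
    idx += 1
  }
}

// e.g. two aggregate slots plus the key ("a", 1):
// new BufferAndKeySketch(2, Array[Any]("a", 1)).values holds [null, null, "a", 1]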
@@ -414,10 +426,10 @@ case class DistinctAggregate( val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new GenericMutableRow(aggregates.length + keys.length) + val buffer = new BufferAndKey(aggregates.length, keys) val seens = new Array[JSet[Any]](aggregates.length) - // handle the aggregate buffers + // update the aggregate buffers var idx = 0 while (idx < aggregates.length) { val ae = aggregates(idx) @@ -432,14 +444,6 @@ case class DistinctAggregate( idx += 1 } - // fill the grouping expression values into the new row - idx = buffersSchema.length - var idx2 = 0 - while (idx < output.length) { - buffer(idx) = keys(idx2) - idx2 += 1 - idx += 1 - } results(keys.copy()) = new BufferSeens(buffer, seens) case bufferSeens => From 5b015184fef3ef3a414f84238dfd04ccb1246ab2 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 17 Apr 2015 09:36:01 +0800 Subject: [PATCH 14/20] fix bugs with BufferAndKey --- .../apache/spark/sql/execution/aggregate2/Aggregate.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 9b275ea068a7..41c99470868f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -56,7 +56,7 @@ sealed class BufferAndKey(leftLen: Int, rightLen: Int) // copy the keys to the end of the row var idx = leftLen var idx2 = 0 - while (idx < keys.length) { + while (idx2 < keys.length) { this.values(idx) = keys(idx2) idx2 += 1 idx += 1 @@ -210,7 +210,7 @@ case class AggregatePreShuffle( val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new BufferAndKey(output.length, keys) + val buffer = new BufferAndKey(buffersSchema.length, keys) // update the aggregate buffers var idx = 0 while (idx < aggregates.length) { @@ -426,7 +426,7 @@ case class DistinctAggregate( val keys = groupByProjection(currentRow) results(keys) match { case null => - val buffer = new BufferAndKey(aggregates.length, keys) + val buffer = new BufferAndKey(buffersSchema.length, keys) val seens = new Array[JSet[Any]](aggregates.length) // update the aggregate buffers From feac4d0b9f66cd9bd2e62c2d88564e18f2d1ceaa Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 16 Apr 2015 18:37:10 -0700 Subject: [PATCH 15/20] Add golden files --- ...roup by expressions #1-0-f94fcc218d98298e058589b40b66e54a | 5 +++++ ...roup by expressions #2-0-9e3a5b01c29dc63023bde64d85c2b7e7 | 5 +++++ ...roup by expressions #3-0-bb30af32082fac87c7e2720e40978c87 | 5 +++++ ...roup by expressions #5-0-68ffb9106a13d35ba8c36741e894a93f | 5 +++++ ...roup by expressions #6-0-5f3e67d7a3abd388c85220eb3af07976 | 5 +++++ ...roup by expressions #8-0-d9161a0e40862ba94a35e5b65daea51a | 5 +++++ ...group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 | 1 + ...roup by expressions #2-0-75a3974aac80b9c47f23519da6a68876 | 1 + ...roup by expressions #3-0-8341e7bf739124bef28729aabb9fe542 | 1 + ...roup by expressions #4-0-8341e7bf739124bef28729aabb9fe542 | 1 + ...roup by expressions #5-0-1e35f970b831ecfffdaff828428aea51 | 1 + ...roup by expressions #6-0-9f51fa0a008712e35c70d7187a55ee35 | 1 + 12 files changed, 36 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-f94fcc218d98298e058589b40b66e54a create mode 100644
sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-9e3a5b01c29dc63023bde64d85c2b7e7 create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-bb30af32082fac87c7e2720e40978c87 create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-68ffb9106a13d35ba8c36741e894a93f create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-5f3e67d7a3abd388c85220eb3af07976 create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #8-0-d9161a0e40862ba94a35e5b65daea51a create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-8341e7bf739124bef28729aabb9fe542 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-9f51fa0a008712e35c70d7187a55ee35 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-f94fcc218d98298e058589b40b66e54a b/sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-f94fcc218d98298e058589b40b66e54a new file mode 100644 index 000000000000..aa9769a5e9ab --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-f94fcc218d98298e058589b40b66e54a @@ -0,0 +1,5 @@ +3 3 0 0 +5 1 2 2 +7 1 4 4 +8 3 5 5 +11 1 8 8 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-9e3a5b01c29dc63023bde64d85c2b7e7 b/sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-9e3a5b01c29dc63023bde64d85c2b7e7 new file mode 100644 index 000000000000..26ae72edc75d --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-9e3a5b01c29dc63023bde64d85c2b7e7 @@ -0,0 +1,5 @@ +3 3 0 0 0 +5 1 2 2 2 +7 1 4 4 4 +8 3 5 5 15 +11 1 8 8 8 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-bb30af32082fac87c7e2720e40978c87 b/sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-bb30af32082fac87c7e2720e40978c87 new file mode 100644 index 000000000000..59abfa88c9da --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-bb30af32082fac87c7e2720e40978c87 @@ -0,0 +1,5 @@ +3 1 0 0 0 +5 1 2 2 2 +7 1 4 4 4 +8 1 5 5 5 +11 1 8 8 8 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-68ffb9106a13d35ba8c36741e894a93f b/sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-68ffb9106a13d35ba8c36741e894a93f new file mode 100644 index 000000000000..29bba1d1d6ea --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-68ffb9106a13d35ba8c36741e894a93f @@ -0,0 +1,5 @@ +6 4 6 6 6 +10 6 10 10 10 +14 8 14 14 14 +16 9 16 16 16 +22 12 22 22 22 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-5f3e67d7a3abd388c85220eb3af07976 
b/sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-5f3e67d7a3abd388c85220eb3af07976 new file mode 100644 index 000000000000..1a4fc65a86c4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-5f3e67d7a3abd388c85220eb3af07976 @@ -0,0 +1,5 @@ +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #8-0-d9161a0e40862ba94a35e5b65daea51a b/sql/hive/src/test/resources/golden/aggregation with group by expressions #8-0-d9161a0e40862ba94a35e5b65daea51a new file mode 100644 index 000000000000..0daf0e02b136 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #8-0-d9161a0e40862ba94a35e5b65daea51a @@ -0,0 +1,5 @@ +4 +6 +8 +9 +12 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 new file mode 100644 index 000000000000..5111dd17161f --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 @@ -0,0 +1 @@ +500 498 0 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 new file mode 100644 index 000000000000..b4d2e5cc256d --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 @@ -0,0 +1 @@ +500 498 0 130091 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 new file mode 100644 index 000000000000..276664a61678 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 @@ -0,0 +1 @@ +309 498 0 79136 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-8341e7bf739124bef28729aabb9fe542 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-8341e7bf739124bef28729aabb9fe542 new file mode 100644 index 000000000000..276664a61678 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-8341e7bf739124bef28729aabb9fe542 @@ -0,0 +1 @@ +309 498 0 79136 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51 new file mode 100644 index 000000000000..ce71b00ee105 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51 @@ -0,0 +1 @@ +503 499 2 130096 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-9f51fa0a008712e35c70d7187a55ee35 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-9f51fa0a008712e35c70d7187a55ee35 new file mode 100644 index 000000000000..418739f242dd --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-9f51fa0a008712e35c70d7187a55ee35 @@ -0,0 
+1 @@ +313 500 3 79140 From ec7deaab6a573bec35fbed7ff3c8ec2245f563b8 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 20 Apr 2015 21:10:48 -0700 Subject: [PATCH 16/20] enable more unit test --- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 ++ .../Aggregate2CompatibilitySuite.scala | 111 ------------------ .../execution/HiveCompatibilitySuite.scala | 12 ++ .../sql/hive/execution/HiveQuerySuite.scala | 14 ++- .../sql/hive/execution/SQLQuerySuite.scala | 13 ++ 5 files changed, 50 insertions(+), 112 deletions(-) delete mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5babc4332cc7..89141677de0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1365,3 +1365,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } } + +class SQLQuerySuite2 extends SQLQuerySuite { + override def beforeAll() { + super.beforeAll() + sqlCtx.setConf(SQLConf.AGGREGATE_2, "true") + } + + override def afterAll() { + sqlCtx.setConf(SQLConf.AGGREGATE_2, "false") + super.afterAll() + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala deleted file mode 100644 index 31023dffa4ae..000000000000 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/Aggregate2CompatibilitySuite.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive - -/** - * Test the aggregation framework2. 
- */ -class Aggregate2CompatibilitySuite extends HiveCompatibilitySuite { - override def beforeAll() { - super.beforeAll() - TestHive.setConf(SQLConf.AGGREGATE_2, "true") - } - - override def afterAll() { - TestHive.setConf(SQLConf.AGGREGATE_2, "false") - super.afterAll() - } - - override def whiteList = Seq( - "groupby1", - "groupby11", - "groupby12", - "groupby1_limit", - "groupby_grouping_id1", - "groupby_grouping_id2", - "groupby_grouping_sets1", - "groupby_grouping_sets2", - "groupby_grouping_sets3", - "groupby_grouping_sets4", - "groupby_grouping_sets5", - "groupby1_map", - "groupby1_map_nomap", - "groupby1_map_skew", - "groupby1_noskew", - "groupby2", - "groupby2_limit", - "groupby2_map", - "groupby2_map_skew", - "groupby2_noskew", - "groupby4", - "groupby4_map", - "groupby4_map_skew", - "groupby4_noskew", - "groupby5", - "groupby5_map", - "groupby5_map_skew", - "groupby5_noskew", - "groupby6", - "groupby6_map", - "groupby6_map_skew", - "groupby6_noskew", - "groupby7", - "groupby7_map", - "groupby7_map_multi_single_reducer", - "groupby7_map_skew", - "groupby7_noskew", - "groupby7_noskew_multi_single_reducer", - "groupby8", - "groupby8_map", - "groupby8_map_skew", - "groupby8_noskew", - "groupby9", - "groupby_distinct_samekey", - "groupby_map_ppr", - "groupby_multi_insert_common_distinct", - "groupby_multi_single_reducer2", - "groupby_multi_single_reducer3", - "groupby_mutli_insert_common_distinct", - "groupby_neg_float", - "groupby_ppd", - "groupby_ppr", - "groupby_sort_10", - "groupby_sort_2", - "groupby_sort_3", - "groupby_sort_4", - "groupby_sort_5", - "groupby_sort_6", - "groupby_sort_7", - "groupby_sort_8", - "groupby_sort_9", - "groupby_sort_test_1", - "having", - "udaf_collect_set", - "udaf_corr", - "udaf_covar_pop", - "udaf_covar_samp", - "udaf_histogram_numeric", - "udaf_number_format", - "udf_sum", - "udf_avg", - "udf_count" - ) -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 048f78b4daa8..b698a1777750 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -1007,3 +1007,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "view_inputs" ) } + +class HiveCompatibilitySuite2 extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.AGGREGATE_2, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.AGGREGATE_2, "false") + super.afterAll() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6d8d99ebc816..8eaf1a4cabbe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfter import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{SQLConf, AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project 
import org.apache.spark.sql.hive._ @@ -1134,3 +1134,15 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // for SPARK-2180 test case class HavingRow(key: Int, value: String, attr: Int) + +class HiveQuerySuite2 extends HiveQuerySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.AGGREGATE_2, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.AGGREGATE_2, "false") + super.afterAll() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 40a35674e4cb..f963c0825cf2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException @@ -28,6 +30,7 @@ import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -905,3 +908,13 @@ class SQLQuerySuite extends QueryTest { } } } + +class SQLQuerySuite2 extends SQLQuerySuite with BeforeAndAfter { + def beforeAll() { + TestHive.setConf(SQLConf.AGGREGATE_2, "true") + } + + def afterAll() { + TestHive.setConf(SQLConf.AGGREGATE_2, "false") + } +} From 393b0d1973f5dd25c9d81240f064cd786c4d6158 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 20 Apr 2015 23:14:58 -0700 Subject: [PATCH 17/20] disable the codegen for aggregate2 in unit test --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 89141677de0e..4e1924788968 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -174,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true case _ => } - if (!hasGeneratedAgg) { + if (!hasGeneratedAgg && conf.aggregate2 == false) { fail( s""" |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. 
From 021431f0e77cd9eea976c4b7f6b5b289b58439d6 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 24 Apr 2015 00:22:16 -0700 Subject: [PATCH 18/20] Add more unit tests --- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 86 ++++++++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 37 +++++++- 4 files changed, 121 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4e1924788968..82afb6649b7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1366,7 +1366,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } -class SQLQuerySuite2 extends SQLQuerySuite { +class SQLNewUDAFQuerySuite extends SQLQuerySuite { override def beforeAll() { super.beforeAll() sqlCtx.setConf(SQLConf.AGGREGATE_2, "true") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b698a1777750..7ec0be473a5d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -1008,7 +1008,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { ) } -class HiveCompatibilitySuite2 extends HiveCompatibilitySuite { +class HiveNewUDAFCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() TestHive.setConf(SQLConf.AGGREGATE_2, "true") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 8eaf1a4cabbe..9b295deeb1b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -47,6 +47,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ override def beforeAll() { + TestHive.reset() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -1135,14 +1136,93 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // for SPARK-2180 test case class HavingRow(key: Int, value: String, attr: Int) -class HiveQuerySuite2 extends HiveQuerySuite { +// TODO ideally this class should inherit from HiveQuerySuite. +// However, the tables/configuration cannot be reset between runs, which +// causes exceptions such as 'table already exists'.
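Each of these companion suites re-runs an existing test body with the new aggregation path switched on, using the same toggle-and-restore choreography in beforeAll/afterAll. The pattern distilled into a generic helper — a hedged sketch: the string key below is a placeholder, since the patch only ever refers to the constant SQLConf.AGGREGATE_2:

object Aggregate2Toggle {
  // "spark.sql.aggregate2.enabled" is a made-up placeholder key, not the
  // actual value behind SQLConf.AGGREGATE_2.
  def withAggregate2[T](setConf: (String, String) => Unit)(body: => T): T = {
    setConf("spark.sql.aggregate2.enabled", "true")
    try {
      body
    } finally {
      setConf("spark.sql.aggregate2.enabled", "false") // restore the default
    }
  }
}

// usage sketch: Aggregate2Toggle.withAggregate2(TestHive.setConf)(runAllQueries())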
+class HiveNewUDAFQuerySuite extends HiveComparisonTest with BeforeAndAfter { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + override def beforeAll() { + TestHive.reset() + TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + TestHive.setConf(SQLConf.AGGREGATE_2, "true") + } + + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.setConf(SQLConf.AGGREGATE_2, "false") + } + + def isExplanation(result: DataFrame): Boolean = { + val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } + explanation.contains("== Physical Plan ==") + } + + createQueryTest("having no references", + "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + + createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", + "SELECT AVG(0), SUM(0), COUNT(null), COUNT(value) FROM src GROUP BY key") + + test("SPARK-1704: Explain commands as a DataFrame") { + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + + val df = sql("explain select key, count(value) from src group by key") + assert(isExplanation(df)) + + TestHive.reset() + } + + test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { + val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) + .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") + val results = + sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + .collect() + .map(x => Pair(x.getString(0), x.getInt(1))) + + assert(results === Array(Pair("foo", 4))) + TestHive.reset() + } + + test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { + sql("select key, count(*) c from src group by key having c").collect() + } + + test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { + assert(sql("select key from src having key > 490").collect().size < 100) + } + + test("Query Hive native command execution result") { + val databaseName = "test_native_commands" + + assertResult(0) { + sql(s"DROP DATABASE IF EXISTS $databaseName").count() + } + + assertResult(0) { + sql(s"CREATE DATABASE $databaseName").count() + } + + assert( + sql("SHOW DATABASES") + .select('result) + .collect() + .map(_.getString(0)) + .contains(databaseName)) + + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) + + TestHive.reset() + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f963c0825cf2..bd80f146ea9f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -909,7 +909,10 @@ class SQLQuerySuite extends QueryTest { } } -class SQLQuerySuite2 extends SQLQuerySuite with BeforeAndAfter { +// TODO ideally this class should inherit from SQLQuerySuite. +// However, the tables/configuration cannot be reset between runs, which +// causes exceptions such as 'table already exists'.
+class SQLNewUDAFQuerySuite extends QueryTest with BeforeAndAfter { def beforeAll() { TestHive.setConf(SQLConf.AGGREGATE_2, "true") } @@ -917,4 +920,36 @@ class SQLQuerySuite2 extends SQLQuerySuite with BeforeAndAfter { def afterAll() { TestHive.setConf(SQLConf.AGGREGATE_2, "false") } + + test("ordering not in agg") { + checkAnswer( + sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), + sql( """ + SELECT key + FROM ( + SELECT key, value + FROM src + GROUP BY key, value + ORDER BY value) a""").collect().toSeq) + } + + test("SPARK-2554 SumDistinct partial aggregation") { + checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), + sql("SELECT distinct key FROM src order by key").collect().toSeq) + } + + test("SPARK-4296 Grouping field with Hive UDF as sub expression") { + val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + jsonRDD(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), + Row("str-1", 1970)) + + dropTempTable("data") + + jsonRDD(rdd).registerTempTable("data") + checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) + + dropTempTable("data") + } } From 8ad5fc549fbfaa9ee00676e5d8e57cf98da32adc Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 18 May 2015 12:14:08 +0800 Subject: [PATCH 19/20] rebase to the latest master --- .../spark/sql/catalyst/expressions/UnsafeRow.java | 10 ++++++++++ .../apache/spark/sql/catalyst/planning/patterns.scala | 2 +- .../spark/sql/execution/aggregate2/Aggregate.scala | 6 +++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ec97fe603c44..a8c243874535 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -339,6 +339,16 @@ public Row copy() { } @Override + public MutableRow makeMutable() { + GenericMutableRow mr = new GenericMutableRow(this.size()); + for (int i = 0; i < mr.size(); ++i) { + mr.update(i, get(i)); + } + + return mr; + } + + @Override public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 7a3f637351ff..405c7f3f2d36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -190,7 +190,7 @@ object PartialAggregation2 { // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) 
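The patterns.scala hunk below swaps GetField for ExtractValue inside a transform-based substitution. For readers unfamiliar with that rewrite style, here is a toy tree with an analogous rule-driven rewrite (toy types, not Catalyst's Expression API; this sketch rewrites bottom-up, while the real library offers both directions):

sealed trait Expr {
  def children: Seq[Expr]
  def withChildren(cs: Seq[Expr]): Expr
  // rewrite children first, then apply the rule at this node
  def transform(rule: PartialFunction[Expr, Expr]): Expr = {
    val rewritten = withChildren(children.map(_.transform(rule)))
    rule.applyOrElse(rewritten, (e: Expr) => e)
  }
}
case class Attr(name: String) extends Expr {
  def children: Seq[Expr] = Nil
  def withChildren(cs: Seq[Expr]): Expr = this
}
case class Field(child: Expr, name: String) extends Expr {
  def children: Seq[Expr] = Seq(child)
  def withChildren(cs: Seq[Expr]): Expr = copy(child = cs.head)
}

// Field(Attr("s"), "f").transform { case Field(_, n) => Attr(n) } == Attr("f"),
// i.e. the field access is replaced by a reference to a precomputed attribute.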
substitutions - .get(e.transform { case Alias(g: GetField, _) => g }) + .get(e.transform { case Alias(g: ExtractValue, _) => g }) .map(_.toAttribute) .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 41c99470868f..125d67544e83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -170,7 +170,7 @@ case class AggregatePreShuffle( } } - override def execute(): RDD[Row] = attachTree(this, "execute") { + override def doExecute(): RDD[Row] = attachTree(this, "execute") { if (groupingExpressions.length == 0) { child.execute().mapPartitions { iter => // the input is every single row @@ -251,7 +251,7 @@ case class AggregatePostShuffle( ClusteredDistribution(groupingExpressions) :: Nil } - override def execute(): RDD[Row] = attachTree(this, "execute") { + override def doExecute(): RDD[Row] = attachTree(this, "execute") { if (groupingExpressions.length == 0) { child.execute().mapPartitions { iter => // The input Row in the format of (AggregateBuffers) @@ -351,7 +351,7 @@ case class DistinctAggregate( BindReferences.bindReference(_: Expression, childOutput) } - override def execute(): RDD[Row] = attachTree(this, "execute") { + override def doExecute(): RDD[Row] = attachTree(this, "execute") { if (groupingExpressions.length == 0) { child.execute().mapPartitions { iter => // initialize the aggregate functions for input rows From 68dd625d9ef753086760796dcc26de186eaf8825 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 8 Jun 2015 06:10:04 -0700 Subject: [PATCH 20/20] rebase again --- .../apache/spark/sql/execution/aggregate2/Aggregate.scala | 2 +- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala index 125d67544e83..b714435ee8c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.aggregate2 -import java.util.{HashSet=>JHashSet, Set=>JSet} +import java.util.{HashSet => JHashSet, Set => JSet} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 82afb6649b7b..39fef43d49f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -174,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true case _ => } - if (!hasGeneratedAgg && conf.aggregate2 == false) { + if (!hasGeneratedAgg && sqlContext.getConf(SQLConf.AGGREGATE_2) == false) { fail( s""" |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. 
@@ -1369,11 +1369,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { class SQLNewUDAFQuerySuite extends SQLQuerySuite { override def beforeAll() { super.beforeAll() - sqlCtx.setConf(SQLConf.AGGREGATE_2, "true") + sqlContext.setConf(SQLConf.AGGREGATE_2, "true") } override def afterAll() { - sqlCtx.setConf(SQLConf.AGGREGATE_2, "false") + sqlContext.setConf(SQLConf.AGGREGATE_2, "false") super.afterAll() } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 2ed7f2101d01..da3d7b2b6953 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import java.util.{Set=>JSet} +import java.util.{Set => JSet} import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
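A closing note on patch 19's UnsafeRow.makeMutable(): it copies a read-only row field by field into a generic mutable row so the aggregate buffers can be updated in place. The same copy modeled with self-contained toy row types (not the Catalyst classes):

object MakeMutableSketch {
  trait ToyRow { def size: Int; def get(i: Int): Any }

  class ToyMutableRow(val size: Int) extends ToyRow {
    private val values = new Array[Any](size)
    def get(i: Int): Any = values(i)
    def update(i: Int, v: Any): Unit = values(i) = v
  }

  // copy a read-only row into a mutable one, field by field
  def makeMutable(row: ToyRow): ToyMutableRow = {
    val mr = new ToyMutableRow(row.size)
    var i = 0
    while (i < mr.size) {
      mr(i) = row.get(i) // mr(i) = v desugars to mr.update(i, v)
      i += 1
    }
    mr
  }
}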