From c70d920bf70693974fac951762be360e20e4bb80 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2016 13:11:17 +0800 Subject: [PATCH 1/5] Reimplement TypedAggregateExpression to DeclarativeAggregate --- .../sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 3 +- .../apache/spark/sql/types/ObjectType.scala | 2 + .../scala/org/apache/spark/sql/Column.scala | 16 +- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 3 +- .../aggregate/TypedAggregateExpression.scala | 189 +++++++++--------- .../spark/sql/expressions/Aggregator.scala | 31 ++- .../org/apache/spark/sql/QueryTest.scala | 4 + 9 files changed, 132 insertions(+), 122 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a24a5db8d49cd..718bb4b118cea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -185,7 +185,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. */ - def prettyName: String = getClass.getSimpleName.toLowerCase + def prettyName: String = nodeName.toLowerCase private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e6804d096cd96..7fd4bc3066cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -60,7 +60,8 @@ object Literal { * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. */ - def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType) + def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass)) def fromJSON(json: JValue): Literal = { val dataType = DataType.parseDataType(json \ "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 06ee0fbfe9642..b7b1acc58242e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -41,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType { throw new UnsupportedOperationException("No size estimation available for objects.") def asNullable: DataType = this + + override def simpleString: String = cls.getName } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d64736e11110b..bd96941da798d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -59,14 +59,14 @@ class TypedColumn[-T, U]( * on a decoded object. 
*/ private[sql] def withInputType( - inputEncoder: ExpressionEncoder[_], - schema: Seq[Attribute]): TypedColumn[T, U] = { - val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U]( - expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy(aEncoder = Some(boundEncoder), children = schema) - }, - encoder) + inputDeserializer: Expression, + inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { + val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes) + val newExpr = expr transform { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => + ta.copy(inputDeserializer = Some(unresolvedDeserializer)) + } + new TypedColumn[T, U](newExpr, encoder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2854d5f9daf20..7a4958e7832c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -993,7 +993,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - boundTEncoder, + unresolvedTEncoder.deserializer, logicalPlan.output).named :: Nil, logicalPlan), implicitly[Encoder[U1]]) @@ -1007,7 +1007,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) + columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index f19ad6e707526..05e13e66d137c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -209,8 +209,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map( - _.withInputType(resolvedVEncoder, dataAttributes).named) + columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named) val keyColumn = if (resolvedKEncoder.flat) { assert(groupingAttributes.length == 1) groupingAttributes.head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 9abae5357973f..bb122fb3130df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,133 +19,138 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.encoders.encoderFor import 
org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { - def apply[A, B : Encoder, C : Encoder]( - aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + def apply[BUF : Encoder, OUT : Encoder]( + aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { + val bufferEncoder = encoderFor[BUF] + val bufferSerializer = bufferEncoder.namedExpressions + val bufferDeserializer = UnresolvedDeserializer( + bufferEncoder.deserializer, + bufferSerializer.map(_.toAttribute)) + + val outputEncoder = encoderFor[OUT] + val outputType = if (outputEncoder.flat) { + outputEncoder.schema.head.dataType + } else { + outputEncoder.schema + } + new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], None, - encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], - encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], - Nil, - 0, - 0) + bufferSerializer, + bufferDeserializer, + outputEncoder.serializer, + outputEncoder.deserializer.dataType, + outputType) } } /** - * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has - * the following limitations: - * - It assumes the aggregator has a zero, `0`. + * A helper class to hook [[Aggregator]] into the aggregation system. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - unresolvedBEncoder: ExpressionEncoder[Any], - cEncoder: ExpressionEncoder[Any], - children: Seq[Attribute], - mutableAggBufferOffset: Int, - inputAggBufferOffset: Int) - extends ImperativeAggregate with Logging { + inputDeserializer: Option[Expression], + bufferSerializer: Seq[NamedExpression], + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + outputExternalType: DataType, + dataType: DataType) extends DeclarativeAggregate with NonSQLExpression { - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) + override def nullable: Boolean = true - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) + override def deterministic: Boolean = true - override def nullable: Boolean = true + override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer - override def dataType: DataType = if (cEncoder.flat) { - cEncoder.schema.head.dataType - } else { - cEncoder.schema - } + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved - override def deterministic: Boolean = true + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) - override lazy val resolved: Boolean = aEncoder.isDefined - - override lazy val inputTypes: Seq[DataType] = Nil - - override val aggBufferSchema: StructType = unresolvedBEncoder.schema - - override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes - - val bEncoder = unresolvedBEncoder - .resolve(aggBufferAttributes, OuterScopes.outerScopes) - .bind(aggBufferAttributes) - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization 
ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - // We let the dataset do the binding for us. - lazy val boundA = aEncoder.get - - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - var i = 0 - while (i < aggBufferAttributes.length) { - val offset = mutableAggBufferOffset + i - aggBufferSchema(i).dataType match { - case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) - case ByteType => buffer.setByte(offset, value.getByte(i)) - case ShortType => buffer.setShort(offset, value.getShort(i)) - case IntegerType => buffer.setInt(offset, value.getInt(i)) - case LongType => buffer.setLong(offset, value.getLong(i)) - case FloatType => buffer.setFloat(offset, value.getFloat(i)) - case DoubleType => buffer.setDouble(offset, value.getDouble(i)) - case other => buffer.update(offset, value.get(i, other)) - } - i += 1 - } - } + override def inputTypes: Seq[AbstractDataType] = Nil - override def initialize(buffer: MutableRow): Unit = { - val zero = bEncoder.toRow(aggregator.zero) - updateBuffer(buffer, zero) - } + private def aggregatorLiteral = + Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]])) - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val inputA = boundA.fromRow(input) - val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) - val merged = aggregator.reduce(currentB, inputA) - val returned = bEncoder.toRow(merged) + private def bufferExternalType = bufferDeserializer.dataType - updateBuffer(buffer, returned) + override lazy val aggBufferAttributes: Seq[AttributeReference] = + bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference]) + + override lazy val initialValues: Seq[Expression] = { + val zero = Literal.fromObject(aggregator.zero, bufferExternalType) + bufferSerializer.map(_ transform { + case b: BoundReference => zero + }) } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) - val merged = aggregator.merge(b1, b2) - val returned = bEncoder.toRow(merged) + override lazy val updateExpressions: Seq[Expression] = { + val reduced = Invoke( + aggregatorLiteral, + "reduce", + bufferExternalType, + bufferDeserializer :: inputDeserializer.get :: Nil) + + bufferSerializer.map(_ transform { + case b: BoundReference => reduced + }) + } - updateBuffer(buffer1, returned) + override lazy val mergeExpressions: Seq[Expression] = { + val leftBuffer = bufferDeserializer transform { + case a: AttributeReference => a.left + } + val rightBuffer = bufferDeserializer transform { + case a: AttributeReference => a.right + } + val merged = Invoke( + aggregatorLiteral, + "merge", + bufferExternalType, + leftBuffer :: rightBuffer :: Nil) + + bufferSerializer.map(_ transform { + case b: BoundReference => merged + }) } - override def eval(buffer: InternalRow): Any = { - val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) - val result = cEncoder.toRow(aggregator.finish(b)) + override lazy val evaluateExpression: Expression = { + val resultObj = Invoke( + aggregatorLiteral, + "finish", + outputExternalType, + bufferDeserializer :: Nil) + + val result = outputSerializer.map(_ transform { + case b: BoundReference => resultObj + }) + dataType match { - case _: StructType => result - case _ => result.get(0, dataType) + case s: StructType => CreateStruct(result) + case 
_ => + assert(result.length == 1) + result.head } } override def toString: String = { - s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + val input = inputDeserializer match { + case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString + case Some(deserializer) => deserializer.dataType.simpleString + case _ => "unknown" + } + + s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName + override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 9cb356f1ca375..1f083f0552fed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} +import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression /** - * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] - * operations to take all of the elements of a group and reduce them to a single value. + * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take + * all of the elements of a group and reduce them to a single value. * * For example, the following aggregator extracts an `int` from a specific class and adds them up: * {{{ @@ -43,52 +43,51 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam I The input type for the aggregation. - * @tparam B The type of the intermediate value of the reduction. - * @tparam O The type of the final output result. + * @tparam IN The input type for the aggregation. + * @tparam BUF The type of the intermediate value of the reduction. + * @tparam OUT The type of the final output result. * @since 1.6.0 */ -abstract class Aggregator[-I, B, O] extends Serializable { +abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** * A zero value for this aggregation. Should satisfy the property that any b + zero = b. * @since 1.6.0 */ - def zero: B + def zero: BUF /** * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. * @since 1.6.0 */ - def reduce(b: B, a: I): B + def reduce(b: BUF, a: IN): BUF /** * Merge two intermediate values. * @since 1.6.0 */ - def merge(b1: B, b2: B): B + def merge(b1: BUF, b2: BUF): BUF /** * Transform the output of the reduction. * @since 1.6.0 */ - def finish(reduction: B): O + def finish(reduction: BUF): OUT /** - * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] - * operations. + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] operations. 
* @since 1.6.0 */ def toColumn( - implicit bEncoder: Encoder[B], - cEncoder: Encoder[O]): TypedColumn[I, O] = { + implicit bufferEnc: Encoder[BUF], + outputEnc: Encoder[OUT]): TypedColumn[IN, OUT] = { val expr = AggregateExpression( TypedAggregateExpression(this), Complete, isDistinct = false) - new TypedColumn[I, O](expr, encoderFor[O]) + new TypedColumn[IN, OUT](expr, encoderFor[OUT]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 48a077d0e551a..23a0ce215ff3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -29,9 +29,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { @@ -204,6 +206,8 @@ abstract class QueryTest extends PlanTest { case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return + case _: TypedAggregateExpression => return + case Literal(_, _: ObjectType) => return } // bypass hive tests before we fix all corner cases in hive module. From 905234e628f342051953f3a1c09748d28881ad01 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2016 17:14:31 +0800 Subject: [PATCH 2/5] avoid re-evaluating subexpressions in typed UDAF --- .../aggregate/TypedAggregateExpression.scala | 100 ++++++++++++++---- .../execution/WholeStageCodegenSuite.scala | 17 ++- 2 files changed, 92 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index bb122fb3130df..061e48b868bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ @@ -32,9 +34,29 @@ object TypedAggregateExpression { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] val bufferSerializer = bufferEncoder.namedExpressions - val bufferDeserializer = UnresolvedDeserializer( - bufferEncoder.deserializer, - bufferSerializer.map(_.toAttribute)) + + // To avoid re-calculating the deserializer expression and 
function call expression while + // evaluating each buffer serializer expression, we serialize the buffer object to a single + // struct field, not multiply fields, no matter whether the encoder is flat or not. So for + // buffer deserializer, we should add one extra level at bottom, to use the buffer attribute of + // struct type as input. + // TODO: remove this trick after we have better integration of subexpression elimination and + // whole stage codegen. + val bufferAttr = if (bufferEncoder.flat) { + AttributeReference("buffer", bufferEncoder.schema.head.dataType, nullable = false)() + } else { + AttributeReference("buffer", bufferEncoder.schema, nullable = false)() + } + val bufferDeserializer = if (bufferEncoder.flat) { + bufferEncoder.deserializer + } else { + bufferEncoder.deserializer transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(bufferAttr, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(bufferAttr, ordinal) + } + } val outputEncoder = encoderFor[OUT] val outputType = if (outputEncoder.flat) { @@ -46,8 +68,9 @@ object TypedAggregateExpression { new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], None, + bufferAttr, bufferSerializer, - bufferDeserializer, + UnresolvedDeserializer(bufferDeserializer, bufferAttr :: Nil), outputEncoder.serializer, outputEncoder.deserializer.dataType, outputType) @@ -60,6 +83,7 @@ object TypedAggregateExpression { case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], inputDeserializer: Option[Expression], + bufferAttr: AttributeReference, bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, outputSerializer: Seq[Expression], @@ -83,14 +107,21 @@ case class TypedAggregateExpression( private def bufferExternalType = bufferDeserializer.dataType - override lazy val aggBufferAttributes: Seq[AttributeReference] = - bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference]) + override lazy val aggBufferAttributes: Seq[AttributeReference] = bufferAttr :: Nil + + private def generateBuffer(inputObj: Expression): Seq[Expression] = { + if (bufferSerializer.length > 1) { + EvaluateOnce(bufferSerializer, inputObj, bufferAttr.dataType) :: Nil + } else { + bufferSerializer.head.transform { + case b: BoundReference => inputObj + } :: Nil + } + } override lazy val initialValues: Seq[Expression] = { val zero = Literal.fromObject(aggregator.zero, bufferExternalType) - bufferSerializer.map(_ transform { - case b: BoundReference => zero - }) + generateBuffer(zero) } override lazy val updateExpressions: Seq[Expression] = { @@ -100,9 +131,7 @@ case class TypedAggregateExpression( bufferExternalType, bufferDeserializer :: inputDeserializer.get :: Nil) - bufferSerializer.map(_ transform { - case b: BoundReference => reduced - }) + generateBuffer(reduced) } override lazy val mergeExpressions: Seq[Expression] = { @@ -118,9 +147,7 @@ case class TypedAggregateExpression( bufferExternalType, leftBuffer :: rightBuffer :: Nil) - bufferSerializer.map(_ transform { - case b: BoundReference => merged - }) + generateBuffer(merged) } override lazy val evaluateExpression: Expression = { @@ -130,15 +157,13 @@ case class TypedAggregateExpression( outputExternalType, bufferDeserializer :: Nil) - val result = outputSerializer.map(_ transform { - case b: BoundReference => resultObj - }) - dataType match { - case s: StructType => CreateStruct(result) + case s: StructType => EvaluateOnce(outputSerializer, resultObj, 
s) case _ => - assert(result.length == 1) - result.head + assert(outputSerializer.length == 1) + outputSerializer.head transform { + case b: BoundReference => resultObj + } } } @@ -154,3 +179,32 @@ case class TypedAggregateExpression( override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") } + +/** + * Combines serializer expressions into one single expression that outputs a struct, evaluate the + * object expression only once and use the result as input for all serializer expressions. + */ +case class EvaluateOnce(serializer: Seq[Expression], obj: Expression, dataType: DataType) + extends UnaryExpression with NonSQLExpression { + + override def nullable: Boolean = false + override def child: Expression = obj + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val evalObj = obj.gen(ctx) + val objRef = LambdaVariable(evalObj.value, evalObj.isNull, obj.dataType) + + val result = CreateStruct(serializer.map(_ transform { + case b: BoundReference => objRef + })) + + val evalResult = result.gen(ctx) + ev.value = evalResult.value + ev.isNull = evalResult.isNull + + evalObj.code + "\n" + evalResult.code + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f73ca887f165a..16fde3777c165 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution -import org.apache.spark.api.java.function.MapFunction -import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -82,4 +82,17 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } + + test("simple typed UDAF should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS() + .groupByKey(_._1).agg(typed.sum(_._2)) + + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) + assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) + } } From 4bbd5084b45db9ad2bb675e79bb0a28537db0aea Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2016 19:56:29 +0800 Subject: [PATCH 3/5] add benchmark --- .../apache/spark/sql/DatasetBenchmark.scala | 74 +++++++++++++++++-- 1 file changed, 66 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 6eb952445f221..3e1cd737c38e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql import org.apache.spark.SparkContext +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StringType import org.apache.spark.util.Benchmark @@ -28,16 +31,10 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) - + def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { import sqlContext.implicits._ - val numRows = 10000000 val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val numChains = 10 - val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) @@ -61,7 +58,7 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,6 +69,55 @@ object DatasetBenchmark { res.foreach(_ => Unit) } + benchmark + } + + object ComplexAggregator extends Aggregator[Data, Data, Long] { + override def zero: Data = Data(0, "") + + override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "") + + override def finish(reduction: Data): Long = reduction.l + + override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "") + } + + def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("aggregate", numRows) + + benchmark.addCase("DataFrame sum") { iter => + df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset sum using Aggregator") { iter => + df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset complex Aggregator") { iter => + df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit) + } + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD sum") { iter => + rdd.aggregate(0L)(_ + _.l, _ + _) + } + + benchmark + } + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + val numRows = 10000000 + val numChains = 10 + + val benchmark = backToBackMap(sqlContext, numRows, numChains) + val benchmark2 = aggregate(sqlContext, numRows) + /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz @@ -82,5 +128,17 @@ object DatasetBenchmark { RDD 216 / 237 46.3 21.6 4.2X */ benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + DataFrame sum 137 / 314 72.7 13.7 1.0X + Dataset sum using Aggregator 506 / 542 19.8 50.6 0.3X + Dataset complex Aggregator 959 / 1051 10.4 95.9 0.1X + RDD sum 203 / 217 
49.2 20.3 0.7X + */ + benchmark2.run() } } From 7a136c5e88bbd6e17e4b2eadb6a5a558a0d9f333 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 8 Apr 2016 09:49:31 +0800 Subject: [PATCH 4/5] increase datasize and reorder benchmark to run RDD first, then DataFrame, then Dataset --- .../apache/spark/sql/DatasetBenchmark.scala | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index ccb4b1d864f14..006b3d4f45a63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -36,16 +36,17 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) - val func = (d: Data) => Data(d.l + 1, d.s) - benchmark.addCase("Dataset") { iter => - var res = df.as[Data] + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < numChains) { - res = res.map(func) + res = rdd.map(func) i += 1 } - res.queryExecution.toRdd.foreach(_ => Unit) + res.foreach(_ => Unit) } benchmark.addCase("DataFrame") { iter => @@ -58,15 +59,14 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) - benchmark.addCase("RDD") { iter => - var res = rdd + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < numChains) { - res = rdd.map(func) + res = res.map(func) i += 1 } - res.foreach(_ => Unit) + res.queryExecution.toRdd.foreach(_ => Unit) } benchmark @@ -77,19 +77,20 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) - val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - benchmark.addCase("Dataset") { iter => - var res = df.as[Data] + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < numChains) { - res = res.filter(funcs(i)) + res = rdd.filter(funcs(i)) i += 1 } - res.queryExecution.toRdd.foreach(_ => Unit) + res.foreach(_ => Unit) } benchmark.addCase("DataFrame") { iter => @@ -102,15 +103,14 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) - benchmark.addCase("RDD") { iter => - var res = rdd + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < numChains) { - res = rdd.filter(funcs(i)) + res = res.filter(funcs(i)) i += 1 } - res.foreach(_ => Unit) + res.queryExecution.toRdd.foreach(_ => Unit) } benchmark @@ -132,15 +132,15 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("aggregate", numRows) - benchmark.addCase("DataFrame sum") { iter => - df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit) - } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD sum") { iter => rdd.aggregate(0L)(_ + _.l, _ + _) } + 
benchmark.addCase("DataFrame sum") { iter => + df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit) + } + benchmark.addCase("Dataset sum using Aggregator") { iter => df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit) } @@ -156,7 +156,7 @@ object DatasetBenchmark { val sparkContext = new SparkContext("local[*]", "Dataset benchmark") val sqlContext = new SQLContext(sparkContext) - val numRows = 10000000 + val numRows = 100000000 val numChains = 10 val benchmark = backToBackMap(sqlContext, numRows, numChains) @@ -168,9 +168,9 @@ object DatasetBenchmark { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Dataset 902 / 995 11.1 90.2 1.0X - DataFrame 132 / 167 75.5 13.2 6.8X - RDD 216 / 237 46.3 21.6 4.2X + RDD 1935 / 2105 51.7 19.3 1.0X + DataFrame 756 / 799 132.3 7.6 2.6X + Dataset 7359 / 7506 13.6 73.6 0.3X */ benchmark.run() @@ -179,9 +179,9 @@ object DatasetBenchmark { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Dataset 585 / 628 17.1 58.5 1.0X - DataFrame 62 / 80 160.7 6.2 9.4X - RDD 205 / 220 48.7 20.5 2.8X + RDD 1974 / 2036 50.6 19.7 1.0X + DataFrame 103 / 127 967.4 1.0 19.1X + Dataset 4343 / 4477 23.0 43.4 0.5X */ benchmark2.run() @@ -190,10 +190,10 @@ object DatasetBenchmark { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - DataFrame sum 137 / 314 72.7 13.7 1.0X - RDD sum 203 / 217 49.2 20.3 0.7X - Dataset sum using Aggregator 506 / 542 19.8 50.6 0.3X - Dataset complex Aggregator 959 / 1051 10.4 95.9 0.1X + RDD sum 2130 / 2166 46.9 21.3 1.0X + DataFrame sum 92 / 128 1085.3 0.9 23.1X + Dataset sum using Aggregator 4111 / 4282 24.3 41.1 0.5X + Dataset complex Aggregator 8782 / 9036 11.4 87.8 0.2X */ benchmark3.run() } From 4ee5ac178bfb14eb54d4a669deebac1df86425a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 11 Apr 2016 13:34:08 +0800 Subject: [PATCH 5/5] update --- .../expressions/ReferenceToExpressions.scala | 77 ++++++++++++++++ .../aggregate/TypedAggregateExpression.scala | 87 +++++-------------- .../apache/spark/sql/DatasetBenchmark.scala | 4 + 3 files changed, 105 insertions(+), 63 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala new file mode 100644 index 0000000000000..22645c952e722 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +/** + * A special expression that evaluates [[BoundReference]]s by given expressions instead of the + * input row. + * + * @param result The expression that contains [[BoundReference]] and produces the final output. + * @param children The expressions that used as input values for [[BoundReference]]. + */ +case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) + extends Expression { + + override def nullable: Boolean = result.nullable + override def dataType: DataType = result.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (result.references.nonEmpty) { + return TypeCheckFailure("The result expression cannot reference to any attributes.") + } + + var maxOrdinal = -1 + result foreach { + case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal + } + if (maxOrdinal > children.length) { + return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " + + s"there are only ${children.length} inputs.") + } + + TypeCheckSuccess + } + + private lazy val projection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + result.eval(projection(input)) + } + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val childrenGen = children.map(_.gen(ctx)) + val childrenVars = childrenGen.zip(children).map { + case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) + } + + val resultGen = result.transform { + case b: BoundReference => childrenVars(b.ordinal) + }.gen(ctx) + + ev.value = resultGen.value + ev.isNull = resultGen.isNull + + childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 061e48b868bb7..535e64cb34442 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -20,12 +20,10 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate -import 
org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ @@ -33,28 +31,30 @@ object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] - val bufferSerializer = bufferEncoder.namedExpressions - - // To avoid re-calculating the deserializer expression and function call expression while - // evaluating each buffer serializer expression, we serialize the buffer object to a single - // struct field, not multiply fields, no matter whether the encoder is flat or not. So for - // buffer deserializer, we should add one extra level at bottom, to use the buffer attribute of - // struct type as input. + // We will insert the deserializer and function call expression at the bottom of each serializer + // expression while executing `TypedAggregateExpression`, which means multiply serializer + // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating, + // here we always use one single serializer expression to serialize the buffer object into a + // single-field row, no matter whether the encoder is flat or not. We also need to update the + // deserializer to read in all fields from that single-field row. // TODO: remove this trick after we have better integration of subexpression elimination and // whole stage codegen. - val bufferAttr = if (bufferEncoder.flat) { - AttributeReference("buffer", bufferEncoder.schema.head.dataType, nullable = false)() + val bufferSerializer = if (bufferEncoder.flat) { + bufferEncoder.namedExpressions.head } else { - AttributeReference("buffer", bufferEncoder.schema, nullable = false)() + Alias(CreateStruct(bufferEncoder.serializer), "buffer")() } + val bufferDeserializer = if (bufferEncoder.flat) { - bufferEncoder.deserializer + bufferEncoder.deserializer transformUp { + case b: BoundReference => bufferSerializer.toAttribute + } } else { bufferEncoder.deserializer transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) - UnresolvedExtractValue(bufferAttr, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetStructField(bufferAttr, ordinal) + UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal) } } @@ -68,9 +68,8 @@ object TypedAggregateExpression { new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], None, - bufferAttr, bufferSerializer, - UnresolvedDeserializer(bufferDeserializer, bufferAttr :: Nil), + bufferDeserializer, outputEncoder.serializer, outputEncoder.deserializer.dataType, outputType) @@ -83,8 +82,7 @@ object TypedAggregateExpression { case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], inputDeserializer: Option[Expression], - bufferAttr: AttributeReference, - bufferSerializer: Seq[NamedExpression], + bufferSerializer: NamedExpression, bufferDeserializer: Expression, outputSerializer: Seq[Expression], outputExternalType: DataType, @@ -107,21 +105,12 @@ case class TypedAggregateExpression( private def bufferExternalType = bufferDeserializer.dataType - override lazy val aggBufferAttributes: Seq[AttributeReference] = bufferAttr :: Nil - - private def generateBuffer(inputObj: Expression): Seq[Expression] = { - if (bufferSerializer.length > 1) { - EvaluateOnce(bufferSerializer, inputObj, 
bufferAttr.dataType) :: Nil - } else { - bufferSerializer.head.transform { - case b: BoundReference => inputObj - } :: Nil - } - } + override lazy val aggBufferAttributes: Seq[AttributeReference] = + bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil override lazy val initialValues: Seq[Expression] = { val zero = Literal.fromObject(aggregator.zero, bufferExternalType) - generateBuffer(zero) + ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil } override lazy val updateExpressions: Seq[Expression] = { @@ -131,7 +120,7 @@ case class TypedAggregateExpression( bufferExternalType, bufferDeserializer :: inputDeserializer.get :: Nil) - generateBuffer(reduced) + ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil } override lazy val mergeExpressions: Seq[Expression] = { @@ -147,7 +136,7 @@ case class TypedAggregateExpression( bufferExternalType, leftBuffer :: rightBuffer :: Nil) - generateBuffer(merged) + ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil } override lazy val evaluateExpression: Expression = { @@ -158,7 +147,8 @@ case class TypedAggregateExpression( bufferDeserializer :: Nil) dataType match { - case s: StructType => EvaluateOnce(outputSerializer, resultObj, s) + case s: StructType => + ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil) case _ => assert(outputSerializer.length == 1) outputSerializer.head transform { @@ -179,32 +169,3 @@ case class TypedAggregateExpression( override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") } - -/** - * Combines serializer expressions into one single expression that outputs a struct, evaluate the - * object expression only once and use the result as input for all serializer expressions. - */ -case class EvaluateOnce(serializer: Seq[Expression], obj: Expression, dataType: DataType) - extends UnaryExpression with NonSQLExpression { - - override def nullable: Boolean = false - override def child: Expression = obj - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val evalObj = obj.gen(ctx) - val objRef = LambdaVariable(evalObj.value, evalObj.isNull, obj.dataType) - - val result = CreateStruct(serializer.map(_ transform { - case b: BoundReference => objRef - })) - - val evalResult = result.gen(ctx) - ev.value = evalResult.value - ev.isNull = evalResult.isNull - - evalObj.code + "\n" + evalResult.code - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 006b3d4f45a63..ae9fb80c68f42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -124,6 +124,10 @@ object DatasetBenchmark { override def finish(reduction: Data): Long = reduction.l override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "") + + override def bufferEncoder: Encoder[Data] = Encoders.product[Data] + + override def outputEncoder: Encoder[Long] = Encoders.scalaLong } def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = {
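
Usage sketch (not part of the patch series): the example below shows how an Aggregator written against the renamed IN/BUF/OUT signature from patch 1 could be driven through toColumn on a grouped Dataset, mirroring the groupByKey(...).agg(...) pattern exercised by the new WholeStageCodegenSuite test above. The SumOfValues object, the example class name, and the local SparkContext setup are illustrative assumptions, not code taken from the patches.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical aggregator: sums the Int component of (String, Int) pairs.
// IN = (String, Int), BUF = Long, OUT = Long.
object SumOfValues extends Aggregator[(String, Int), Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: (String, Int)): Long = b + a._2
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
}

object TypedAggregateExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "typed-aggregate-example")
    val sqlContext = new SQLContext(sc)
    // Provides the implicit Encoder[Long] instances that toColumn needs for BUF and OUT.
    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // toColumn wraps the aggregator in a TypedAggregateExpression, which after this patch
    // series is a DeclarativeAggregate and can therefore participate in whole-stage codegen.
    val result = ds.groupByKey(_._1).agg(SumOfValues.toColumn).collect()
    result.foreach(println) // expected: (a,10), (b,3), (c,1)

    sc.stop()
  }
}

Under this sketch's assumptions, the aggregation should plan as a TungstenAggregate inside a WholeStageCodegen node, which is the behavior the new WholeStageCodegenSuite test asserts and the motivation for the Dataset aggregate cases added to DatasetBenchmark.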