[SPARK-14275][SQL] Reimplement TypedAggregateExpression to DeclarativeAggregate #12067
Changes from all commits: c70d920, 905234e, 4bbd508, 5f6510e, 7a136c5, 045a9be, 4ee5ac1
ReferenceToExpressions.scala (new file in org.apache.spark.sql.catalyst.expressions, +77 lines):
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

/**
 * A special expression that evaluates [[BoundReference]]s by given expressions instead of the
 * input row.
 *
 * @param result The expression that contains [[BoundReference]] and produces the final output.
 * @param children The expressions that are used as input values for [[BoundReference]].
 */
case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ => // other nodes carry no ordinal; skip them to avoid a MatchError
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  // Projects the child expressions into a row that the BoundReferences in `result` can read.
  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    val childrenGen = children.map(_.gen(ctx))
    val childrenVars = childrenGen.zip(children).map {
      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
    }

    // Replace each BoundReference with the generated variable of the corresponding child.
    val resultGen = result.transform {
      case b: BoundReference => childrenVars(b.ordinal)
    }.gen(ctx)

    ev.value = resultGen.value
    ev.isNull = resultGen.isNull

    childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code
  }
}
```
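In short, `ReferenceToExpressions` evaluates a `result` tree full of `BoundReference`s against a row built from arbitrary `children` expressions rather than the incoming row. A minimal sketch of the eval path (the `Add`/`Literal` inputs are illustrative, not from this PR):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// `result` refers to inputs by ordinal; `children` supply those inputs.
val ref0 = BoundReference(0, IntegerType, nullable = false)
val ref1 = BoundReference(1, IntegerType, nullable = false)
val expr = ReferenceToExpressions(Add(ref0, ref1), Seq(Literal(1), Literal(2)))

// eval projects `children` into an UnsafeRow and evaluates `result` against it:
assert(expr.eval(InternalRow.empty) == 3)
```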
TypedAggregateExpression.scala (org.apache.spark.sql.execution.aggregate):
```diff
@@ -19,133 +19,153 @@ package org.apache.spark.sql.execution.aggregate

+import scala.language.existentials
+
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._

 object TypedAggregateExpression {
-  def apply[A, B : Encoder, C : Encoder](
-      aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
+  def apply[BUF : Encoder, OUT : Encoder](
+      aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
+    val bufferEncoder = encoderFor[BUF]
+    // We will insert the deserializer and function call expression at the bottom of each
+    // serializer expression while executing `TypedAggregateExpression`, which means multiple
+    // serializer expressions will all evaluate the same sub-expression at the bottom. To avoid
+    // re-evaluating it, here we always use one single serializer expression to serialize the
+    // buffer object into a single-field row, no matter whether the encoder is flat or not. We
+    // also need to update the deserializer to read in all fields from that single-field row.
+    // TODO: remove this trick after we have better integration of subexpression elimination and
+    // whole-stage codegen.
+    val bufferSerializer = if (bufferEncoder.flat) {
+      bufferEncoder.namedExpressions.head
+    } else {
+      Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
+    }
+
+    val bufferDeserializer = if (bufferEncoder.flat) {
+      bufferEncoder.deserializer transformUp {
+        case b: BoundReference => bufferSerializer.toAttribute
+      }
+    } else {
+      bufferEncoder.deserializer transformUp {
+        case UnresolvedAttribute(nameParts) =>
+          assert(nameParts.length == 1)
+          UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
+        case BoundReference(ordinal, dt, _) =>
+          GetStructField(bufferSerializer.toAttribute, ordinal)
+      }
+    }
+
+    val outputEncoder = encoderFor[OUT]
+    val outputType = if (outputEncoder.flat) {
+      outputEncoder.schema.head.dataType
+    } else {
+      outputEncoder.schema
+    }
+
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
-      encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
-      encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
-      Nil,
-      0,
-      0)
+      bufferSerializer,
+      bufferDeserializer,
+      outputEncoder.serializer,
+      outputEncoder.deserializer.dataType,
+      outputType)
   }
 }
```
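A rough illustration of what the two branches produce (the shapes below are hypothetical, not exact analyzer output):

```scala
// Flat buffer encoder (e.g. BUF = Long): reuse its single named serializer expression.
//   bufferSerializer   ~ the one named column the encoder already produces
//   bufferDeserializer ~ rewired to read that single column back
//
// Non-flat buffer encoder (e.g. BUF = (Double, Long)): pack all fields into one struct column.
//   bufferSerializer   ~ Alias(CreateStruct(<field serializers>), "buffer")
//   bufferDeserializer ~ field accesses rewritten to UnresolvedExtractValue(buffer, "_1"),
//                        GetStructField(buffer, 0), etc., so everything reads from "buffer"
```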
```diff
 /**
- * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
- * the following limitations:
- *  - It assumes the aggregator has a zero, `0`.
+ * A helper class to hook [[Aggregator]] into the aggregation system.
  */
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
-    aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
-    unresolvedBEncoder: ExpressionEncoder[Any],
-    cEncoder: ExpressionEncoder[Any],
-    children: Seq[Attribute],
-    mutableAggBufferOffset: Int,
-    inputAggBufferOffset: Int)
-  extends ImperativeAggregate with Logging {
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
+    inputDeserializer: Option[Expression],
+    bufferSerializer: NamedExpression,
+    bufferDeserializer: Expression,
+    outputSerializer: Seq[Expression],
+    outputExternalType: DataType,
+    dataType: DataType) extends DeclarativeAggregate with NonSQLExpression {

   override def nullable: Boolean = true

-  override def dataType: DataType = if (cEncoder.flat) {
-    cEncoder.schema.head.dataType
-  } else {
-    cEncoder.schema
-  }
-
   override def deterministic: Boolean = true

-  override lazy val resolved: Boolean = aEncoder.isDefined
-
-  override lazy val inputTypes: Seq[DataType] = Nil
+  override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer
```
Contributor: Why is `bufferDeserializer` part of `children`?

Author: We need to resolve the deserializers in the Analyzer; making them children here means they can be transformed during analysis.
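The author's point, sketched with plain Catalyst expressions: Catalyst transformations such as the Analyzer's resolution rules only visit nodes reachable through `children`, so a deserializer stored in a field that `children` did not expose would never get rewritten:

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

val tree: Expression = Add(Literal(1), Literal(2))
val rewritten = tree transformUp {
  case Literal(1, IntegerType) => Literal(10)
}
// rewritten == Add(Literal(10), Literal(2)): the literal was rewritten
// precisely because it was reachable via `children`.
```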
```diff
-  override val aggBufferSchema: StructType = unresolvedBEncoder.schema
+  override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved

-  override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+  override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
```
Contributor: Should this be `AttributeSet(inputDeserializer.flatMap(_.references))`?
```diff
-  val bEncoder = unresolvedBEncoder
-    .resolve(aggBufferAttributes, OuterScopes.outerScopes)
-    .bind(aggBufferAttributes)
+  override def inputTypes: Seq[AbstractDataType] = Nil
```
Contributor: ?

Author: This is from the original code: https://github.com/apache/spark/pull/12067/files#diff-585631b931d2b881e65c585dec11cadbL78 — it's because …
```diff
-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
+  private def aggregatorLiteral =
+    Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]]))

-  // We let the dataset do the binding for us.
-  lazy val boundA = aEncoder.get
+  private def bufferExternalType = bufferDeserializer.dataType

-  private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
-    var i = 0
-    while (i < aggBufferAttributes.length) {
-      val offset = mutableAggBufferOffset + i
-      aggBufferSchema(i).dataType match {
-        case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
-        case ByteType => buffer.setByte(offset, value.getByte(i))
-        case ShortType => buffer.setShort(offset, value.getShort(i))
-        case IntegerType => buffer.setInt(offset, value.getInt(i))
-        case LongType => buffer.setLong(offset, value.getLong(i))
-        case FloatType => buffer.setFloat(offset, value.getFloat(i))
-        case DoubleType => buffer.setDouble(offset, value.getDouble(i))
-        case other => buffer.update(offset, value.get(i, other))
-      }
-      i += 1
-    }
-  }
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil

-  override def initialize(buffer: MutableRow): Unit = {
-    val zero = bEncoder.toRow(aggregator.zero)
-    updateBuffer(buffer, zero)
+  override lazy val initialValues: Seq[Expression] = {
+    val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
+    ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil
   }
```
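To make `initialValues` concrete: the aggregator's zero is captured as an object literal and pushed through the buffer serializer (the buffer type below is a hypothetical example):

```scala
// Hypothetical buffer type (Double, Long) with aggregator.zero == (0.0, 0L):
//   val zero = Literal.fromObject((0.0, 0L), ObjectType(classOf[(Double, Long)]))
//   initialValues == Seq(ReferenceToExpressions(bufferSerializer, Seq(zero)))
// The initial aggregation buffer row is therefore the serialized zero object,
// built from a literal rather than by re-encoding aggregator.zero imperatively.
```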
```diff
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val inputA = boundA.fromRow(input)
-    val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
-    val merged = aggregator.reduce(currentB, inputA)
-    val returned = bEncoder.toRow(merged)
-
-    updateBuffer(buffer, returned)
+  override lazy val updateExpressions: Seq[Expression] = {
+    val reduced = Invoke(
+      aggregatorLiteral,
+      "reduce",
+      bufferExternalType,
+      bufferDeserializer :: inputDeserializer.get :: Nil)
+
+    ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil
   }
```
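Taken together, each update is equivalent to the following round trip (pseudocode in comments; `serialize`/`deserialize` stand for the encoder expressions above):

```scala
// Pseudocode for one update step, as encoded by the expressions above:
//   newBufferRow = serialize( aggregator.reduce( deserialize(bufferRow),
//                                                deserialize(inputRow) ) )
// Invoke(aggregatorLiteral, "reduce", ...) generates the method call on the
// Aggregator object, and ReferenceToExpressions(bufferSerializer, reduced :: Nil)
// binds the serializer's BoundReference to the reduced object before writing it back.
```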
```diff
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
-    val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
-    val merged = aggregator.merge(b1, b2)
-    val returned = bEncoder.toRow(merged)
-
-    updateBuffer(buffer1, returned)
+  override lazy val mergeExpressions: Seq[Expression] = {
+    val leftBuffer = bufferDeserializer transform {
+      case a: AttributeReference => a.left
+    }
+    val rightBuffer = bufferDeserializer transform {
+      case a: AttributeReference => a.right
+    }
+    val merged = Invoke(
+      aggregatorLiteral,
+      "merge",
+      bufferExternalType,
+      leftBuffer :: rightBuffer :: Nil)
+
+    ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil
   }
```
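The `a.left` / `a.right` calls come from the `RichAttribute` implicit that `DeclarativeAggregate` provides: during merge the two buffers sit side by side in one row, and `left`/`right` pick the copy of a buffer attribute on the mutable side and the input side. Roughly (paraphrased, not part of this diff):

```scala
// Paraphrased from DeclarativeAggregate (not part of this PR's diff); it is a member of the
// trait, so aggBufferAttributes and inputAggBufferAttributes are in scope.
implicit class RichAttribute(a: AttributeReference) {
  // The attribute as it appears in the mutable aggregation buffer (the "left" side of a merge).
  def left: AttributeReference = a
  // The corresponding attribute in the incoming buffer being merged in (the "right" side).
  def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
```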
```diff
-  override def eval(buffer: InternalRow): Any = {
-    val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
-    val result = cEncoder.toRow(aggregator.finish(b))
+  override lazy val evaluateExpression: Expression = {
+    val resultObj = Invoke(
+      aggregatorLiteral,
+      "finish",
+      outputExternalType,
+      bufferDeserializer :: Nil)
+
     dataType match {
-      case _: StructType => result
-      case _ => result.get(0, dataType)
+      case s: StructType =>
+        ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+      case _ =>
+        assert(outputSerializer.length == 1)
+        outputSerializer.head transform {
+          case b: BoundReference => resultObj
+        }
     }
   }
```
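The two branches of `evaluateExpression` mirror the buffer trick: a struct output keeps every serializer field under one `CreateStruct`, while a flat output reuses its single serializer expression with the `BoundReference` swapped for the finished object. An illustrative sketch of the flat case:

```scala
// Flat output (e.g. OUT = Double), illustrative shapes only:
//   outputSerializer.head ~ a cast/packing expression built around a BoundReference
//   evaluateExpression    ~ the same expression with that BoundReference replaced by
//                           Invoke(aggregatorLiteral, "finish", outputExternalType, ...)
// Struct outputs instead keep all serializer fields under one CreateStruct, evaluated
// against the finished object via ReferenceToExpressions.
```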
```diff
   override def toString: String = {
-    s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
+    val input = inputDeserializer match {
+      case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
```
Contributor: Is the deserializer always resolved here?

Author: Oh, sorry, I missed this; that case should be removed.

Contributor: Could you remove it when merging?

Author: Just realized this line is needed. The input deserializer is set by …
```diff
+      case Some(deserializer) => deserializer.dataType.simpleString
+      case _ => "unknown"
+    }
+
+    s"$nodeName($input)"
   }

-  override def nodeName: String = aggregator.getClass.getSimpleName
+  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
 }
```
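For context, a small `Aggregator` that exercises both the non-flat buffer path and the flat output path; a minimal sketch against the public API of the time (the implicit-encoder form of `toColumn` is assumed from the 2.0-era API, and `TypedAverage` is an illustrative name):

```scala
import org.apache.spark.sql.expressions.Aggregator

// Buffer type (Double, Long) is not flat, so the CreateStruct branch above applies;
// output type Double is flat, so evaluateExpression takes the single-serializer branch.
object TypedAverage extends Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)                                 // -> initialValues
  def reduce(b: (Double, Long), a: Double): (Double, Long) =           // -> updateExpressions
    (b._1 + a, b._2 + 1)
  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =  // -> mergeExpressions
    (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Double, Long)): Double = b._1 / b._2                  // -> evaluateExpression
}

// With spark.implicits._ in scope:
//   ds.select(TypedAverage.toColumn)
// toColumn picks up the implicit buffer/output encoders, and
// TypedAggregateExpression.apply builds the (de)serializer expressions from them.
```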
Contributor: Can we hold this PR a little bit? Let me think about how to do subexpression elimination in aggregates.