sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -185,7 +185,7 @@ abstract class Expression extends TreeNode[Expression] {
   * Returns a user-facing string representation of this expression's name.
   * This should usually match the name of the function in SQL.
   */
-  def prettyName: String = getClass.getSimpleName.toLowerCase
+  def prettyName: String = nodeName.toLowerCase

   private def flatArguments = productIterator.flatMap {
     case t: Traversable[_] => t
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala (new file)
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

/**
 * A special expression that evaluates [[BoundReference]]s against the given expressions instead
 * of the input row.
 *
 * @param result The expression that contains [[BoundReference]] and produces the final output.
 * @param children The expressions that are used as input values for [[BoundReference]].
 */
case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ => // other nodes carry no ordinals to check
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression needs $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    val childrenGen = children.map(_.gen(ctx))
    val childrenVars = childrenGen.zip(children).map {
      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
    }

    val resultGen = result.transform {
      case b: BoundReference => childrenVars(b.ordinal)
    }.gen(ctx)

    ev.value = resultGen.value
    ev.isNull = resultGen.isNull

    childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code
  }
}
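For intuition, a minimal usage sketch of the new expression (illustrative only, not part of the diff; it assumes a REPL with catalyst internals on the classpath):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// `result` reads input[0] and adds 1; the single child supplies that input.
val expr = ReferenceToExpressions(
  Add(BoundReference(0, IntegerType, nullable = false), Literal(1)),
  Literal(41) :: Nil)

// eval projects the children into a row, then evaluates `result` against that row.
expr.eval(InternalRow.empty)  // returns 42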
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -60,7 +60,8 @@ object Literal {
   * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object
   * into code generation.
   */
-  def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
+  def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType)
+  def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass))

   def fromJSON(json: JValue): Literal = {
     val dataType = DataType.parseDataType(json \ "dataType")
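A quick hedged illustration of the two overloads (the Buffer class is made up for the example):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.ObjectType

class Buffer                 // hypothetical external class
val buf = new Buffer

Literal.fromObject(buf)                            // type inferred: ObjectType(classOf[Buffer])
Literal.fromObject(buf, ObjectType(classOf[Any]))  // caller-supplied type; the overload
                                                   // TypedAggregateExpression uses for aggregator.zero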
sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -41,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType {
     throw new UnsupportedOperationException("No size estimation available for objects.")

   def asNullable: DataType = this
+
+  override def simpleString: String = cls.getName
 }
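A hedged one-liner showing the effect (so tree strings print the JVM class rather than a generic type name):

import org.apache.spark.sql.types.ObjectType

ObjectType(classOf[java.lang.String]).simpleString  // "java.lang.String"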
sql/core/src/main/scala/org/apache/spark/sql/Column.scala (8 additions, 8 deletions)
@@ -59,14 +59,14 @@ class TypedColumn[-T, U](
   * on a decoded object.
   */
   private[sql] def withInputType(
-      inputEncoder: ExpressionEncoder[_],
-      schema: Seq[Attribute]): TypedColumn[T, U] = {
-    val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
-    new TypedColumn[T, U](
-      expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
-        ta.copy(aEncoder = Some(boundEncoder), children = schema)
-      },
-      encoder)
+      inputDeserializer: Expression,
+      inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
+    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val newExpr = expr transform {
+      case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
+        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+    }
+    new TypedColumn[T, U](newExpr, encoder)
   }
 }
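A hedged note on the analysis step assumed here: UnresolvedDeserializer is a placeholder that the Analyzer later resolves by binding the deserializer's references to the given input attributes, roughly:

// Illustrative sketch only (not code from this PR):
// UnresolvedDeserializer(enc.deserializer, inputAttributes)
//   ==(analysis)==>  enc.deserializer with its attribute references bound to inputAttributes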

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (2 additions, 2 deletions)
@@ -993,7 +993,7 @@ class Dataset[T] private[sql](
       sqlContext,
       Project(
         c1.withInputType(
-          boundTEncoder,
+          unresolvedTEncoder.deserializer,
           logicalPlan.output).named :: Nil,
         logicalPlan),
       implicitly[Encoder[U1]])
@@ -1007,7 +1007,7 @@
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
+      columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
     val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))

     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -209,8 +209,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(
-        _.withInputType(resolvedVEncoder, dataAttributes).named)
+      columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named)
     val keyColumn = if (resolvedKEncoder.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -19,133 +19,153 @@ package org.apache.spark.sql.execution.aggregate

 import scala.language.existentials

-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._

 object TypedAggregateExpression {
-  def apply[A, B : Encoder, C : Encoder](
-      aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
+  def apply[BUF : Encoder, OUT : Encoder](
+      aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
+    val bufferEncoder = encoderFor[BUF]
+    // We will insert the deserializer and function call expression at the bottom of each
+    // serializer expression while executing `TypedAggregateExpression`, which means multiple
+    // serializer expressions would all evaluate the same sub-expression at the bottom. To avoid
+    // that re-evaluation, here we always use one single serializer expression to serialize the
+    // buffer object into a single-field row, no matter whether the encoder is flat or not. We
+    // also need to update the deserializer to read in all fields from that single-field row.
+    // TODO: remove this trick after we have better integration of subexpression elimination and
[Review comment, Contributor]: Can we hold this PR a little bit? Let me think of how to do
subexpression elimination in aggregate.
+    // whole stage codegen.
+    val bufferSerializer = if (bufferEncoder.flat) {
+      bufferEncoder.namedExpressions.head
+    } else {
+      Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
+    }
+
+    val bufferDeserializer = if (bufferEncoder.flat) {
+      bufferEncoder.deserializer transformUp {
+        case b: BoundReference => bufferSerializer.toAttribute
+      }
+    } else {
+      bufferEncoder.deserializer transformUp {
+        case UnresolvedAttribute(nameParts) =>
+          assert(nameParts.length == 1)
+          UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
+        case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal)
+      }
+    }
+
+    val outputEncoder = encoderFor[OUT]
+    val outputType = if (outputEncoder.flat) {
+      outputEncoder.schema.head.dataType
+    } else {
+      outputEncoder.schema
+    }
+
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
-      encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
-      encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
-      Nil,
-      0,
-      0)
+      bufferSerializer,
+      bufferDeserializer,
+      outputEncoder.serializer,
+      outputEncoder.deserializer.dataType,
+      outputType)
   }
 }
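For intuition, a hedged sketch of the buffer column each branch above produces (the encoder choices are illustrative):

// flat buffer encoder, e.g. Encoder[Long]:
//   bufferSerializer -> value#0: bigint (the encoder's single named serializer expression)
// non-flat buffer encoder, e.g. Encoder[(Long, Long)]:
//   bufferSerializer -> buffer#0: struct<_1: bigint, _2: bigint> (CreateStruct over all fields),
//   with the deserializer's field reads rewritten to UnresolvedExtractValue/GetStructField
//   over that single column.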

 /**
- * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
- * the following limitations:
- *  - It assumes the aggregator has a zero, `0`.
+ * A helper class to hook [[Aggregator]] into the aggregation system.
  */
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
-    aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
-    unresolvedBEncoder: ExpressionEncoder[Any],
-    cEncoder: ExpressionEncoder[Any],
-    children: Seq[Attribute],
-    mutableAggBufferOffset: Int,
-    inputAggBufferOffset: Int)
-  extends ImperativeAggregate with Logging {
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
+    inputDeserializer: Option[Expression],
+    bufferSerializer: NamedExpression,
+    bufferDeserializer: Expression,
+    outputSerializer: Seq[Expression],
+    outputExternalType: DataType,
+    dataType: DataType) extends DeclarativeAggregate with NonSQLExpression {

   override def nullable: Boolean = true

-  override def dataType: DataType = if (cEncoder.flat) {
-    cEncoder.schema.head.dataType
-  } else {
-    cEncoder.schema
-  }
-
   override def deterministic: Boolean = true

-  override lazy val resolved: Boolean = aEncoder.isDefined
-
-  override lazy val inputTypes: Seq[DataType] = Nil
+  override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer
[Review comment, Contributor]: Why is bufferDeserializer part of children?
[Reply, Author]: We need to resolve the deserializers in the Analyzer; making them children here
means they can be transformed.

-  override val aggBufferSchema: StructType = unresolvedBEncoder.schema
+  override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved

-  override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+  override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
[Review comment, Contributor]: Should this be AttributeSet(inputDeserializer.flatMap(_.references))?
-  val bEncoder = unresolvedBEncoder
-    .resolve(aggBufferAttributes, OuterScopes.outerScopes)
-    .bind(aggBufferAttributes)
+  override def inputTypes: Seq[AbstractDataType] = Nil
[Review comment, Contributor]: ?
[Reply, Author]: This is from the original code:
https://github.com/apache/spark/pull/12067/files#diff-585631b931d2b881e65c585dec11cadbL78
It's because AggregateFunction implements ImplicitCastInputTypes, so we have to provide an
inputTypes here.

-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
+  private def aggregatorLiteral =
+    Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]]))

-  // We let the dataset do the binding for us.
-  lazy val boundA = aEncoder.get
+  private def bufferExternalType = bufferDeserializer.dataType

-  private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
-    var i = 0
-    while (i < aggBufferAttributes.length) {
-      val offset = mutableAggBufferOffset + i
-      aggBufferSchema(i).dataType match {
-        case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
-        case ByteType => buffer.setByte(offset, value.getByte(i))
-        case ShortType => buffer.setShort(offset, value.getShort(i))
-        case IntegerType => buffer.setInt(offset, value.getInt(i))
-        case LongType => buffer.setLong(offset, value.getLong(i))
-        case FloatType => buffer.setFloat(offset, value.getFloat(i))
-        case DoubleType => buffer.setDouble(offset, value.getDouble(i))
-        case other => buffer.update(offset, value.get(i, other))
-      }
-      i += 1
-    }
-  }
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil

-  override def initialize(buffer: MutableRow): Unit = {
-    val zero = bEncoder.toRow(aggregator.zero)
-    updateBuffer(buffer, zero)
+  override lazy val initialValues: Seq[Expression] = {
+    val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
+    ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil
   }

-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val inputA = boundA.fromRow(input)
-    val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
-    val merged = aggregator.reduce(currentB, inputA)
-    val returned = bEncoder.toRow(merged)
+  override lazy val updateExpressions: Seq[Expression] = {
+    val reduced = Invoke(
+      aggregatorLiteral,
+      "reduce",
+      bufferExternalType,
+      bufferDeserializer :: inputDeserializer.get :: Nil)

-    updateBuffer(buffer, returned)
+    ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil
   }

-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
-    val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
-    val merged = aggregator.merge(b1, b2)
-    val returned = bEncoder.toRow(merged)
+  override lazy val mergeExpressions: Seq[Expression] = {
+    val leftBuffer = bufferDeserializer transform {
+      case a: AttributeReference => a.left
+    }
+    val rightBuffer = bufferDeserializer transform {
+      case a: AttributeReference => a.right
+    }
+    val merged = Invoke(
+      aggregatorLiteral,
+      "merge",
+      bufferExternalType,
+      leftBuffer :: rightBuffer :: Nil)

-    updateBuffer(buffer1, returned)
+    ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil
   }
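The a.left / a.right rewrites rely on DeclarativeAggregate's RichAttribute implicit; the sketch below is recalled from Spark's interfaces.scala and should be verified there:

// implicit class RichAttribute(a: AttributeReference) {
//   def left: AttributeReference = a                // this attribute in the mutable buffer
//   def right: AttributeReference =                 // the matching attribute in the incoming
//     inputAggBufferAttributes(aggBufferAttributes.indexOf(a))   // partial buffer
// }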

-  override def eval(buffer: InternalRow): Any = {
-    val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
-    val result = cEncoder.toRow(aggregator.finish(b))
+  override lazy val evaluateExpression: Expression = {
+    val resultObj = Invoke(
+      aggregatorLiteral,
+      "finish",
+      outputExternalType,
+      bufferDeserializer :: Nil)

     dataType match {
-      case _: StructType => result
-      case _ => result.get(0, dataType)
+      case s: StructType =>
+        ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+      case _ =>
+        assert(outputSerializer.length == 1)
+        outputSerializer.head transform {
+          case b: BoundReference => resultObj
+        }
     }
   }

   override def toString: String = {
-    s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
+    val input = inputDeserializer match {
+      case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
[Review comment, Contributor]: Is deserializer always resolved?
[Reply, Author]: Oh sorry, I missed this one; this case should be removed.
[Review comment, Contributor]: Could you remove it when merging?
[Reply, Author]: Just realized this line is needed. The input deserializer is set by
TypedColumn.withInputType and is unresolved in the first place.

+      case Some(deserializer) => deserializer.dataType.simpleString
+      case _ => "unknown"
+    }

+    s"$nodeName($input)"
   }

-  override def nodeName: String = aggregator.getClass.getSimpleName
+  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
 }
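For context, a hedged end-to-end sketch of an Aggregator flowing through this class (it assumes the Aggregator API of this era, with implicit Long encoders in scope):

import org.apache.spark.sql.expressions.Aggregator

// Buffer and output are both Long, so the buffer encoder is flat and bufferSerializer
// takes the single-named-expression branch in TypedAggregateExpression.apply above.
object SumLong extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L                           // wrapped by Literal.fromObject in initialValues
  def reduce(b: Long, a: Long): Long = b + a    // called via Invoke(aggregatorLiteral, "reduce", ...)
  def merge(b1: Long, b2: Long): Long = b1 + b2 // called in mergeExpressions
  def finish(b: Long): Long = b                 // called in evaluateExpression
}

// ds.select(SumLong.toColumn): TypedColumn.withInputType later injects the input
// deserializer, completing the TypedAggregateExpression before analysis.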