From c36b9dc0be38aa59eb4679245371583dbe62e804 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 28 Oct 2016 15:36:35 -0700 Subject: [PATCH 01/10] Initial draft --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 124 ++++++++++----- .../sql/hive/execution/HiveUDAFSuite.scala | 144 ++++++++++++++++++ 2 files changed, 229 insertions(+), 39 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 42033080dc34b..943d0be43502c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.hive +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.exec._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, - ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.internal.Logging @@ -35,6 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.HiveUDAFFunction.AggregationBufferSerDe import org.apache.spark.sql.types._ @@ -58,7 +61,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -75,7 +78,7 @@ private[hive] case class HiveSimpleUDF( @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( - method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) + method.getGenericReturnType, ObjectInspectorOptions.JAVA)) @transient private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @@ -273,7 +276,7 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -293,61 +296,43 @@ private[hive] case class HiveUDAFFunction( private lazy val inspectors = children.map(toInspector).toArray @transient - private lazy val functionAndInspector = { + private lazy val function = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - val f = resolver.getEvaluator(parameterInfo) - f -> 
f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + resolver.getEvaluator(parameterInfo) } @transient - private lazy val function = functionAndInspector._1 + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray @transient - private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + private lazy val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @transient - private lazy val returnInspector = functionAndInspector._2 + private lazy val partialResultInspector = + function.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) @transient - private lazy val unwrapper = unwrapperFor(returnInspector) + private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) @transient - private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _ + private lazy val partialResultUnwrapper = unwrapperFor(partialResultInspector) - override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer)) + @transient + private lazy val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) @transient - private lazy val inputProjection = new InterpretedProjection(children) + private lazy val resultUnwrapper = unwrapperFor(returnInspector) @transient - private lazy val cached = new Array[AnyRef](children.length) + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @transient private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation - // buffer for it. - override def aggBufferSchema: StructType = StructType(Nil) - - override def update(_buffer: InternalRow, input: InternalRow): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) - } - - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "Hive UDAF doesn't support partial aggregate") - } - - override def initialize(_buffer: InternalRow): Unit = { - buffer = function.getNewAggregationBuffer - } - - override val aggBufferAttributes: Seq[AttributeReference] = Nil - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val bufferSerDe: AggregationBufferSerDe = + new AggregationBufferSerDe( + function, partialResultDataType, partialResultUnwrapper, partialResultWrapper) // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. 
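For orientation before the next hunk: the methods it overrides come from Spark's `TypedImperativeAggregate[T]`, with `T` fixed to Hive's `AggregationBuffer`. The sketch below is illustrative only — the trait name `AggregationStateOps` is hypothetical — but the method set and signatures mirror the overrides added in this patch.

    import org.apache.spark.sql.catalyst.InternalRow

    // Hypothetical trait mirroring the TypedImperativeAggregate[T] methods implemented below
    // for Hive UDAFs (there, T = GenericUDAFEvaluator.AggregationBuffer).
    trait AggregationStateOps[T] {
      def createAggregationBuffer(): T                  // fresh per-group aggregation state
      def update(buffer: T, input: InternalRow): Unit   // fold one input row into the state
      def merge(buffer: T, input: T): Unit              // combine two partial aggregation states
      def eval(buffer: T): Any                          // compute the final result value
      def serialize(buffer: T): Array[Byte]             // encode the state for shuffling
      def deserialize(bytes: Array[Byte]): T            // decode a shuffled state
    }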
@@ -365,4 +350,65 @@ private[hive] case class HiveUDAFFunction( val distinct = if (isDistinct) "DISTINCT " else " " s"$name($distinct${children.map(_.sql).mkString(", ")})" } + + override def createAggregationBuffer(): AggregationBuffer = function.getNewAggregationBuffer + + @transient + private lazy val inputProjection = new InterpretedProjection(children) + + override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { + function.iterate(buffer, wrap(inputProjection(input), wrappers, cached, inputDataTypes)) + } + + override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + function.merge(buffer, partialResultUnwrapper(input)) + } + + override def eval(buffer: AggregationBuffer): Any = resultUnwrapper(function.evaluate(buffer)) + + override def serialize(buffer: AggregationBuffer): Array[Byte] = { + bufferSerDe.serialize(buffer) + } + + override def deserialize(bytes: Array[Byte]): AggregationBuffer = { + bufferSerDe.deserialize(bytes) + } +} + +private[hive] object HiveUDAFFunction { + class AggregationBufferSerDe( + function: GenericUDAFEvaluator, + partialResultDataType: DataType, + partialResultUnwrapper: Any => Any, + partialResultWrapper: Any => Any) + extends HiveInspectors { + + private val projection = UnsafeProjection.create(Array(partialResultDataType)) + + private val mutableRow = new GenericInternalRow(1) + + def serialize(buffer: AggregationBuffer): Array[Byte] = { + // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object + // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`. + // Then we can unwrap it to a Spark SQL value. + mutableRow.update(0, partialResultUnwrapper(function.terminatePartial(buffer))) + val unsafeRow = projection(mutableRow) + val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes) + unsafeRow.writeTo(bytes) + bytes.array() + } + + def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // `GenericUDAFEvaluator` doesn't provide any method that is capable to convert an object + // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The + // workaround here is creating an initial `AggregationBuffer` first and then merge the + // deserialized object into the buffer. + val buffer = function.getNewAggregationBuffer + val unsafeRow = new UnsafeRow + unsafeRow.pointTo(bytes, bytes.length) + val partialResult = unsafeRow.get(0, partialResultDataType) + function.merge(buffer, partialResultWrapper(partialResult)) + buffer + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala new file mode 100644 index 0000000000000..4a3dda635ed52 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} +import org.apache.hadoop.hive.ql.util.JavaDataModel +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sql(s"CREATE FUNCTION mock AS '${classOf[MockUDAF].getName}'") + sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") + + Seq( + (0: Integer) -> "val_0", + (1: Integer) -> "val_1", + (2: Integer) -> null, + (3: Integer) -> null + ).toDF("key", "value").repartition(2).createOrReplaceTempView("t") + } + + test("built-in Hive UDAF") { + val df = sql("SELECT hive_max(key) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: SortAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(2), + Row(3) + )) + } + + test("customized Hive UDAF") { + val df = sql("SELECT mock(value) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: SortAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. 
+ assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(Row(1, 1)), + Row(Row(1, 1)) + )) + } +} + +class MockUDAF extends AbstractGenericUDAFResolver { + override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator +} + +class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long) + extends GenericUDAFEvaluator.AbstractAggregationBuffer { + + override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2 +} + +class MockUDAFEvaluator extends GenericUDAFEvaluator { + private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val bufferOI = { + val fieldNames = Seq("nonNullCount", "nullCount").asJava + val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava + ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs) + } + + private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount") + + private val nullCountField = bufferOI.getStructFieldRef("nullCount") + + override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L) + + override def reset(agg: AggregationBuffer): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount = 0L + buffer.nullCount = 0L + } + + override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI + + override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + if (parameters.head eq null) { + buffer.nullCount += 1L + } else { + buffer.nonNullCount += 1L + } + } + + override def merge(agg: AggregationBuffer, partial: Object): Unit = { + if (partial ne null) { + val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField)) + val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField)) + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount += nonNullCount + buffer.nullCount += nullCount + } + } + + override def terminatePartial(agg: AggregationBuffer): AnyRef = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long) + } + + override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg) +} From cf639c81dffdd488d0d5e6e93ed52c91f813d38a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 31 Oct 2016 15:23:28 -0700 Subject: [PATCH 02/10] Fix test failures --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 943d0be43502c..cdcb2c5526a78 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -23,9 +23,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.exec._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} @@ -266,8 +266,19 @@ private[hive] case class HiveGenericUDTF( } /** - * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt - * performance a lot. + * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following + * three formats: + * + * 1. a Spark SQL value, or + * 2. an instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class, or + * 3. a Java object that can be inspected using the `ObjectInspector` returned by the + * `GenericUDAFEvaluator.init()` method. + * + * We may use the following methods to convert the aggregation state back and forth: + * + * - `wrap()`/`wrapperFor()`: from 3 to 1 + * - `unwrap()`/`unwrapperFor()`: from 1 to 3 + * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3 */ private[hive] case class HiveUDAFFunction( name: String, @@ -305,12 +316,18 @@ private[hive] case class HiveUDAFFunction( private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray @transient - private lazy val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + private lazy val returnInspector = + function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @transient private lazy val partialResultInspector = function.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) + // The following two lines initializes `function: GenericUDAFEvaluator` eagerly. These two fields + // are declared as `@transient lazy val` only for the purpose of serialization. + returnInspector + partialResultInspector + @transient private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) @@ -340,7 +357,7 @@ private[hive] case class HiveUDAFFunction( override def nullable: Boolean = true - override def supportsPartial: Boolean = false + override def supportsPartial: Boolean = true override lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -361,7 +378,7 @@ private[hive] case class HiveUDAFFunction( } override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { - function.merge(buffer, partialResultUnwrapper(input)) + function.merge(buffer, function.terminatePartial(input)) } override def eval(buffer: AggregationBuffer): Any = resultUnwrapper(function.evaluate(buffer)) @@ -404,7 +421,7 @@ private[hive] object HiveUDAFFunction { // workaround here is creating an initial `AggregationBuffer` first and then merge the // deserialized object into the buffer. 
val buffer = function.getNewAggregationBuffer - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(1) unsafeRow.pointTo(bytes, bytes.length) val partialResult = unsafeRow.get(0, partialResultDataType) function.merge(buffer, partialResultWrapper(partialResult)) From 6b4908b851b9492d73d0ceee0cf8d171c4bb203a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 31 Oct 2016 16:22:49 -0700 Subject: [PATCH 03/10] Comment update --- .../org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 4a3dda635ed52..fcb8882b70a5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -81,6 +81,9 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } +/** + * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column. + */ class MockUDAF extends AbstractGenericUDAFResolver { override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator } From f6a080efb71fa26ffe7cd34af6871c013d5616c6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Nov 2016 11:04:12 -0700 Subject: [PATCH 04/10] Remove the @transient lazy val hack --- .../main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index cdcb2c5526a78..a29cd0dbc9058 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -316,18 +316,13 @@ private[hive] case class HiveUDAFFunction( private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray @transient - private lazy val returnInspector = + private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @transient - private lazy val partialResultInspector = + private val partialResultInspector = function.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) - // The following two lines initializes `function: GenericUDAFEvaluator` eagerly. These two fields - // are declared as `@transient lazy val` only for the purpose of serialization. - returnInspector - partialResultInspector - @transient private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) From 689684a58b4e4130c6dac7dc6734ed421e907356 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Nov 2016 15:30:07 -0700 Subject: [PATCH 05/10] Properly initialize Hive UDAF evaluators Hive UDAFs are sensitive to aggregation mode, and must be initialized with proper modes before being used. Basically, it means that you can't use an evaluator initialized with mode PARTIAL1 to merge two aggregation states (although it still works for aggregate functions whose partial result type is the same as the final result type). 
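To make the mode handling concrete, the pattern the diff below adopts can be sketched as follows. This is a condensed sketch rather than a drop-in implementation; the names (`resolver`, `parameterInfo`, `inputInspectors`) follow those introduced in the patch. Each aggregation mode gets its own evaluator instance, initialized exactly once with the object inspectors that mode consumes.

    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode

    // PARTIAL1: consumes raw input rows via iterate() and emits partial results via
    // terminatePartial().
    val partial1ModeEvaluator: GenericUDAFEvaluator = resolver.getEvaluator(parameterInfo)
    val partialResultInspector = partial1ModeEvaluator.init(Mode.PARTIAL1, inputInspectors)

    // PARTIAL2: merges partial results into partial results; note that it is initialized with
    // the partial-result inspector, not with the raw input inspectors.
    val partial2ModeEvaluator: GenericUDAFEvaluator = resolver.getEvaluator(parameterInfo)
    partial2ModeEvaluator.init(Mode.PARTIAL2, Array(partialResultInspector))

    // FINAL: merges partial results and produces the final value via terminate().
    val finalModeEvaluator: GenericUDAFEvaluator = resolver.getEvaluator(parameterInfo)
    val returnInspector = finalModeEvaluator.init(Mode.FINAL, Array(partialResultInspector))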
--- .../org/apache/spark/sql/hive/hiveUDFs.scala | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a29cd0dbc9058..3ffeef1d4302e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.HiveUDAFFunction.AggregationBufferSerDe import org.apache.spark.sql.types._ @@ -304,33 +303,41 @@ private[hive] case class HiveUDAFFunction( } @transient - private lazy val inspectors = children.map(toInspector).toArray + private lazy val inputInspectors = children.map(toInspector).toArray @transient - private lazy val function = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) - } + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) @transient - private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + private lazy val partial1ModeEvaluator = resolver.getEvaluator(parameterInfo) @transient - private val returnInspector = - function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + private val partialResultInspector = partial1ModeEvaluator.init( + GenericUDAFEvaluator.Mode.PARTIAL1, + inputInspectors + ) @transient - private val partialResultInspector = - function.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) + private lazy val partial2ModeEvaluator = { + val evaluator = resolver.getEvaluator(parameterInfo) + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) + evaluator + } @transient - private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + private lazy val finalModeEvaluator = resolver.getEvaluator(parameterInfo) @transient - private lazy val partialResultUnwrapper = unwrapperFor(partialResultInspector) + private val returnInspector = finalModeEvaluator.init( + GenericUDAFEvaluator.Mode.FINAL, + Array(partialResultInspector) + ) @transient - private lazy val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + + @transient + private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) @transient private lazy val resultUnwrapper = unwrapperFor(returnInspector) @@ -342,9 +349,7 @@ private[hive] case class HiveUDAFFunction( private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray @transient - private lazy val bufferSerDe: AggregationBufferSerDe = - new AggregationBufferSerDe( - function, partialResultDataType, partialResultUnwrapper, partialResultWrapper) + private lazy val bufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. 
@@ -363,20 +368,24 @@ private[hive] case class HiveUDAFFunction( s"$name($distinct${children.map(_.sql).mkString(", ")})" } - override def createAggregationBuffer(): AggregationBuffer = function.getNewAggregationBuffer + override def createAggregationBuffer(): AggregationBuffer = + partial1ModeEvaluator.getNewAggregationBuffer @transient private lazy val inputProjection = new InterpretedProjection(children) override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { - function.iterate(buffer, wrap(inputProjection(input), wrappers, cached, inputDataTypes)) + partial1ModeEvaluator.iterate( + buffer, wrap(inputProjection(input), wrappers, cached, inputDataTypes)) } override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { - function.merge(buffer, function.terminatePartial(input)) + partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) } - override def eval(buffer: AggregationBuffer): Any = resultUnwrapper(function.evaluate(buffer)) + override def eval(buffer: AggregationBuffer): Any = { + resultUnwrapper(finalModeEvaluator.terminate(buffer)) + } override def serialize(buffer: AggregationBuffer): Array[Byte] = { bufferSerDe.serialize(buffer) @@ -385,15 +394,11 @@ private[hive] case class HiveUDAFFunction( override def deserialize(bytes: Array[Byte]): AggregationBuffer = { bufferSerDe.deserialize(bytes) } -} -private[hive] object HiveUDAFFunction { - class AggregationBufferSerDe( - function: GenericUDAFEvaluator, - partialResultDataType: DataType, - partialResultUnwrapper: Any => Any, - partialResultWrapper: Any => Any) - extends HiveInspectors { + private class AggregationBufferSerDe { + private val partialResultUnwrapper = unwrapperFor(partialResultInspector) + + private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) private val projection = UnsafeProjection.create(Array(partialResultDataType)) @@ -403,7 +408,7 @@ private[hive] object HiveUDAFFunction { // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`. // Then we can unwrap it to a Spark SQL value. - mutableRow.update(0, partialResultUnwrapper(function.terminatePartial(buffer))) + mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer))) val unsafeRow = projection(mutableRow) val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes) unsafeRow.writeTo(bytes) @@ -415,11 +420,11 @@ private[hive] object HiveUDAFFunction { // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The // workaround here is creating an initial `AggregationBuffer` first and then merge the // deserialized object into the buffer. 
- val buffer = function.getNewAggregationBuffer + val buffer = partial2ModeEvaluator.getNewAggregationBuffer val unsafeRow = new UnsafeRow(1) unsafeRow.pointTo(bytes, bytes.length) val partialResult = unsafeRow.get(0, partialResultDataType) - function.merge(buffer, partialResultWrapper(partialResult)) + partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult)) buffer } } From a6206af70f9cfe0d49aca6ae5331bcf8975b2a89 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Nov 2016 15:56:45 -0700 Subject: [PATCH 06/10] Update comments --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 3ffeef1d4302e..4082edbd2c95a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -302,43 +302,61 @@ private[hive] case class HiveUDAFFunction( funcWrapper.createFunction[AbstractGenericUDAFResolver]() } + // Hive `ObjectInspector`s for all child expressions (input parameters of the function). @transient private lazy val inputInspectors = children.map(toInspector).toArray + // Spark SQL data types of input parameters. @transient - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. @transient - private lazy val partial1ModeEvaluator = resolver.getEvaluator(parameterInfo) + private lazy val partial1ModeEvaluator = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) + } + // Hive `ObjectInspector` used to inspect partial aggregation results. @transient private val partialResultInspector = partial1ModeEvaluator.init( GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors ) + // The UDAF evaluator used to merge partial aggregation results. @transient private lazy val partial2ModeEvaluator = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) val evaluator = resolver.getEvaluator(parameterInfo) evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) evaluator } + // Spark SQL data type of partial aggregation results @transient - private lazy val finalModeEvaluator = resolver.getEvaluator(parameterInfo) + private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + // The UDAF evaluator used to compute the final result from a partial aggregation result objects. + @transient + private lazy val finalModeEvaluator = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) + } + + // Hive `ObjectInspector` used to inspect the final aggregation result object. @transient private val returnInspector = finalModeEvaluator.init( GenericUDAFEvaluator.Mode.FINAL, Array(partialResultInspector) ) + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format. 
@transient - private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray - - @transient - private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into + // Spark SQL specific format. @transient private lazy val resultUnwrapper = unwrapperFor(returnInspector) @@ -346,10 +364,7 @@ private[hive] case class HiveUDAFFunction( private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - @transient - private lazy val bufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe + private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. @@ -376,7 +391,7 @@ private[hive] case class HiveUDAFFunction( override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { partial1ModeEvaluator.iterate( - buffer, wrap(inputProjection(input), wrappers, cached, inputDataTypes)) + buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) } override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { @@ -388,13 +403,14 @@ private[hive] case class HiveUDAFFunction( } override def serialize(buffer: AggregationBuffer): Array[Byte] = { - bufferSerDe.serialize(buffer) + aggBufferSerDe.serialize(buffer) } override def deserialize(bytes: Array[Byte]): AggregationBuffer = { - bufferSerDe.deserialize(bytes) + aggBufferSerDe.deserialize(bytes) } + // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects private class AggregationBufferSerDe { private val partialResultUnwrapper = unwrapperFor(partialResultInspector) From 0ab0a06bfcdfb7e491ab15d397b5a8f5f45f9b04 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Nov 2016 17:43:03 -0700 Subject: [PATCH 07/10] Should create testing UDAFs as temporary functions and drop them while cleaning up --- .../apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index fcb8882b70a5f..1c2de52d2cf4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -35,8 +35,8 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { import testImplicits._ protected override def beforeAll(): Unit = { - sql(s"CREATE FUNCTION mock AS '${classOf[MockUDAF].getName}'") - sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") + sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") Seq( (0: Integer) -> "val_0", @@ -46,6 +46,11 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { ).toDF("key", "value").repartition(2).createOrReplaceTempView("t") } + protected override def afterAll(): Unit = { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS 
hive_max") + } + test("built-in Hive UDAF") { val df = sql("SELECT hive_max(key) FROM t GROUP BY key % 2") From b418cd75d1edf7e1f48d3112d7c42d6f95985878 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 9 Nov 2016 16:35:03 -0800 Subject: [PATCH 08/10] Minor refactoring --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4082edbd2c95a..0f67c7617da57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -310,13 +310,15 @@ private[hive] case class HiveUDAFFunction( @transient private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. - @transient - private lazy val partial1ModeEvaluator = { + private def newEvaluator(): GenericUDAFEvaluator = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) resolver.getEvaluator(parameterInfo) } + // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. + @transient + private lazy val partial1ModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect partial aggregation results. @transient private val partialResultInspector = partial1ModeEvaluator.init( @@ -327,8 +329,7 @@ private[hive] case class HiveUDAFFunction( // The UDAF evaluator used to merge partial aggregation results. @transient private lazy val partial2ModeEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - val evaluator = resolver.getEvaluator(parameterInfo) + val evaluator = newEvaluator() evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) evaluator } @@ -339,10 +340,7 @@ private[hive] case class HiveUDAFFunction( // The UDAF evaluator used to compute the final result from a partial aggregation result objects. @transient - private lazy val finalModeEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - resolver.getEvaluator(parameterInfo) - } + private lazy val finalModeEvaluator = newEvaluator() // Hive `ObjectInspector` used to inspect the final aggregation result object. @transient From e88db5c9638678592f0f8836097e88d47558ffc1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 10 Nov 2016 14:14:50 -0800 Subject: [PATCH 09/10] Fix test failures We're now using ObjectHashAggregateExec instead of SortAggregateExec to evaluate Hive UDAFs. 
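An illustrative way to observe the planner's choice from a test (assuming a TestHive-backed `spark` session with the `t` view and the `hive_max` function registered as in HiveUDAFSuite):

    import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec

    val df = spark.sql("SELECT hive_max(key) FROM t GROUP BY key % 2")

    // Hive UDAFs now carry object-typed aggregation buffers, so the planner selects
    // ObjectHashAggregateExec: one operator for the partial aggregation and one for the
    // final (global) aggregation.
    val objectHashAggs = df.queryExecution.executedPlan.collect {
      case agg: ObjectHashAggregateExec => agg
    }
    assert(objectHashAggs.length == 2)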
--- .../org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 1c2de52d2cf4e..c637e93c7a230 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -55,7 +55,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val df = sql("SELECT hive_max(key) FROM t GROUP BY key % 2") val aggs = df.queryExecution.executedPlan.collect { - case agg: SortAggregateExec => agg + case agg: ObjectHashAggregateExec => agg } // There should be two aggregate operators, one for partial aggregation, and the other for @@ -72,7 +72,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val df = sql("SELECT mock(value) FROM t GROUP BY key % 2") val aggs = df.queryExecution.executedPlan.collect { - case agg: SortAggregateExec => agg + case agg: ObjectHashAggregateExec => agg } // There should be two aggregate operators, one for partial aggregation, and the other for From ca3978c7f368225d349d91477ca13aa2d5b9a3fa Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 16 Nov 2016 11:17:29 -0800 Subject: [PATCH 10/10] Address PR comments --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 46 ++++++++++++++----- .../sql/hive/execution/HiveUDAFSuite.scala | 12 ++--- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 0f67c7617da57..32edd4aec2865 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -268,11 +268,27 @@ private[hive] case class HiveGenericUDTF( * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following * three formats: * - * 1. a Spark SQL value, or - * 2. an instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class, or - * 3. a Java object that can be inspected using the `ObjectInspector` returned by the + * 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class + * + * This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator` + * methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format. + * We call these methods to evaluate Hive UDAFs. + * + * 2. A Java object that can be inspected using the `ObjectInspector` returned by the * `GenericUDAFEvaluator.init()` method. * + * Hive uses this format to produce a serializable aggregation state so that it can shuffle + * partial aggregation results. 
Whenever we need to convert a Hive `AggregationBuffer` instance + * into a Spark SQL value, we have to convert it to this format first and then do the conversion + * with the help of `ObjectInspector`s. + * + * 3. A Spark SQL value + * + * We use this format for serializing Hive UDAF aggregation states on Spark side. To be more + * specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into + * `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization + * results. + * * We may use the following methods to convert the aggregation state back and forth: * * - `wrap()`/`wrapperFor()`: from 3 to 1 @@ -294,14 +310,6 @@ private[hive] case class HiveUDAFFunction( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - @transient - private lazy val resolver = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[AbstractGenericUDAFResolver]() - } - // Hive `ObjectInspector`s for all child expressions (input parameters of the function). @transient private lazy val inputInspectors = children.map(toInspector).toArray @@ -311,6 +319,12 @@ private[hive] case class HiveUDAFFunction( private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray private def newEvaluator(): GenericUDAFEvaluator = { + val resolver = if (isUDAFBridgeRequired) { + new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) + } else { + funcWrapper.createFunction[AbstractGenericUDAFResolver]() + } + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) resolver.getEvaluator(parameterInfo) } @@ -385,7 +399,7 @@ private[hive] case class HiveUDAFFunction( partial1ModeEvaluator.getNewAggregationBuffer @transient - private lazy val inputProjection = new InterpretedProjection(children) + private lazy val inputProjection = UnsafeProjection.create(children) override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { partial1ModeEvaluator.iterate( @@ -393,6 +407,10 @@ private[hive] case class HiveUDAFFunction( } override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation + // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts + // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and + // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) } @@ -401,10 +419,14 @@ private[hive] case class HiveUDAFFunction( } override def serialize(buffer: AggregationBuffer): Array[Byte] = { + // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can + // shuffle it for global aggregation later. aggBufferSerDe.serialize(buffer) } override def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare + // for global aggregation by merging multiple partial aggregation results within a single group. 
aggBufferSerDe.deserialize(bytes) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index c637e93c7a230..c9ef72ee112cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -52,7 +52,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("built-in Hive UDAF") { - val df = sql("SELECT hive_max(key) FROM t GROUP BY key % 2") + val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2") val aggs = df.queryExecution.executedPlan.collect { case agg: ObjectHashAggregateExec => agg @@ -63,13 +63,13 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(aggs.length == 2) checkAnswer(df, Seq( - Row(2), - Row(3) + Row(0, 2), + Row(1, 3) )) } test("customized Hive UDAF") { - val df = sql("SELECT mock(value) FROM t GROUP BY key % 2") + val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2") val aggs = df.queryExecution.executedPlan.collect { case agg: ObjectHashAggregateExec => agg @@ -80,8 +80,8 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(aggs.length == 2) checkAnswer(df, Seq( - Row(Row(1, 1)), - Row(Row(1, 1)) + Row(0, Row(1, 1)), + Row(1, Row(1, 1)) )) } }
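For completeness, a hypothetical end-to-end usage of the testing UDAF outside the suite (not part of the patch; it assumes a Hive-enabled SparkSession named `spark` and the `MockUDAF` class from HiveUDAFSuite on the classpath):

    import org.apache.spark.sql.hive.execution.MockUDAF
    import spark.implicits._

    // Register the Hive UDAF as a temporary function, just as HiveUDAFSuite does in beforeAll().
    spark.sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")

    Seq((0, "val_0"), (1, "val_1"), (2, null), (3, null))
      .toDF("key", "value")
      .createOrReplaceTempView("t")

    // With HiveUDAFFunction implemented as a TypedImperativeAggregate, this query runs as a
    // partial aggregation followed by a final aggregation instead of a single complete-mode pass.
    spark.sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2").show()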