From 6b90328d7c654e8d6eba8e217392ba3792e4e591 Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Wed, 13 Feb 2019 17:38:51 -0600
Subject: [PATCH 1/6] [SPARK-24935]: Problem with Executing Hive UDFs from Spark 2.2 Onwards

Created a new abstract class, HiveTypedImperativeAggregate, which is a
framework for Hive-related aggregation functions. There also seems to be a
bug in SortBasedAggregator, where merge was called on aggregate buffers
without initializing them; this PR fixes that as well.
---
 .../expressions/aggregate/interfaces.scala    | 126 +++++++++++++++++-
 .../aggregate/ObjectAggregationIterator.scala |  30 +++--
 .../org/apache/spark/sql/hive/hiveUDFs.scala  |   5 +-
 3 files changed, 144 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 56c2ee6b53fe5..4e3022830143e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -524,7 +524,125 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */
   def deserialize(storageFormat: Array[Byte]): T
 
+  override def initialize(buffer: InternalRow): Unit = {
+    buffer(mutableAggBufferOffset) = createAggregationBuffer()
+  }
+
+  override def update(buffer: InternalRow, input: InternalRow): Unit = {
+    buffer(mutableAggBufferOffset) = update(getBufferObject(buffer), input)
+  }
+
+  override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
+    val bufferObject = getBufferObject(buffer)
+    // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
+    val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
+    buffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
+  }
+
+  override def eval(buffer: InternalRow): Any = {
+    eval(getBufferObject(buffer))
+  }
+
+  private[this] val anyObjectType = ObjectType(classOf[AnyRef])
+  private def getBufferObject(bufferRow: InternalRow): T = {
+    bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[T]
+  }
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = {
+    // Underlying storage type for the aggregation buffer object
+    Seq(AttributeReference("buf", BinaryType)())
+  }
+
+  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  /**
+   * In-place replaces the aggregation buffer object stored at buffer's index
+   * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format
+   * (BinaryType).
+   *
+   * This is only called when doing Partial or PartialMerge mode aggregation, before the framework
+   * shuffles out aggregate buffers.
+   */
+  def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
+    buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
+  }
+}
+
+/**
+ * Aggregation function which allows **arbitrary** user-defined java object to be used as internal
+ * aggregation buffer for Hive.
+ */
+abstract class HiveTypedImperativeAggregate[T] extends TypedImperativeAggregate[T] {
+
+  /**
+   * Creates an empty aggregation buffer object for Partial1 mode. This is called
+   * before processing each key group (group-by key).
+   *
+   * @return an aggregation buffer object
+   */
+  def createAggregationBuffer(): T
+
+  /**
+   * Creates an empty aggregation buffer object for Partial2 mode.
+   *
+   * @return an aggregation buffer object
+   */
+  def createPartial2ModeAggregationBuffer(): T
+
+  var partial2ModeBuffer: InternalRow = _
+
+  /**
+   * Updates the aggregation buffer object with an input row and returns a new buffer object. For
+   * performance, the function may do in-place update and return it instead of constructing new
+   * buffer object.
+   *
+   * This is typically called when doing Partial or Complete mode aggregation.
+   *
+   * @param buffer the aggregation buffer object
+   * @param input an input row
+   */
+  def update(buffer: T, input: InternalRow): T
+
+  /**
+   * Merges an input aggregation object into aggregation buffer object and returns a new buffer
+   * object. For performance, the function may do in-place merge and return it instead of
+   * constructing new buffer object.
+   *
+   * This is typically called when doing PartialMerge or Final mode aggregation.
+   *
+   * @param buffer the aggregation buffer object used to store the aggregation result
+   * @param input an input aggregation object. Input aggregation object can be produced by
+   *              de-serializing the partial aggregate's output from the mapper side.
+   */
+  def merge(buffer: T, input: T): T
+
+  /**
+   * Generates the final aggregation result value for current key group with the aggregation buffer
+   * object.
+   *
+   * Developer note: the only return types accepted by Spark are:
+   *   - primitive types
+   *   - InternalRow and subclasses
+   *   - ArrayData
+   *   - MapData
+   *
+   * @param buffer aggregation buffer object
+   * @return the aggregation result of current key group
+   */
+  def eval(buffer: T): Any
+
+  /** Serializes the aggregation buffer object T to Array[Byte] */
+  def serialize(buffer: T): Array[Byte]
+
+  /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */
+  def deserialize(storageFormat: Array[Byte]): T
 
   final override def initialize(buffer: InternalRow): Unit = {
+    partial2ModeBuffer = buffer.copy()
+    partial2ModeBuffer(mutableAggBufferOffset) = createPartial2ModeAggregationBuffer()
     buffer(mutableAggBufferOffset) = createAggregationBuffer()
   }
 
@@ -533,14 +651,14 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   }
 
   final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
-    val bufferObject = getBufferObject(buffer)
+    val bufferObject = getBufferObject(partial2ModeBuffer)
     // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
     val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
-    buffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
+    partial2ModeBuffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
   }
 
   final override def eval(buffer: InternalRow): Any = {
-    eval(getBufferObject(buffer))
+    eval(getBufferObject(partial2ModeBuffer))
   }
 
   private[this] val anyObjectType = ObjectType(classOf[AnyRef])
@@ -566,7 +684,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
    * This is only called when doing Partial or PartialMerge mode aggregation, before the framework
    * shuffle out aggregate buffers.
    */
-  final def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
+  final override def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
    buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
   }
 }
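The net effect of the new class is easiest to see end to end. The following standalone sketch (plain Scala, no Spark types, all names hypothetical) re-enacts the two-buffer flow that HiveTypedImperativeAggregate implements: the map side updates and serializes a Partial1-mode object, while the merge side deserializes into a separately created Partial2-mode buffer and never touches the Partial1 object again.

    object TwoBufferFlow {
      final class P1Buf(var n: Long) // buffer class handed out in PARTIAL1 mode
      final class P2Buf(var n: Long) // buffer class handed out in PARTIAL2/FINAL mode

      def createAggregationBuffer(): P1Buf = new P1Buf(0L)
      def createPartial2ModeAggregationBuffer(): P2Buf = new P2Buf(0L)

      def update(buf: P1Buf, v: Long): P1Buf = { buf.n += v; buf }
      def serialize(buf: P1Buf): Array[Byte] = BigInt(buf.n).toByteArray
      def deserialize(bytes: Array[Byte]): P2Buf = new P2Buf(BigInt(bytes).toLong)
      def merge(buf: P2Buf, in: P2Buf): P2Buf = { buf.n += in.n; buf }
      def eval(buf: P2Buf): Long = buf.n

      def main(args: Array[String]): Unit = {
        // Map side (PARTIAL1): update the typed object, then serialize for the shuffle.
        val shuffled = serialize(update(update(createAggregationBuffer(), 1L), 1L))
        // Reduce side (PARTIAL2/FINAL): merge into the separately created
        // Partial2-mode buffer, never into the PARTIAL1 object.
        val result = eval(merge(createPartial2ModeAggregationBuffer(), deserialize(shuffled)))
        println(result) // 2
      }
    }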
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 43514f5271ac8..9f91fc0b689ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -59,7 +59,8 @@ class ObjectAggregationIterator(
   private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
 
   // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
-  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
+  var (sortBasedAggExpressions, sortBasedAggFunctions): (
+    Seq[AggregateExpression], Array[AggregateFunction]) = {
     val newExpressions = aggregateExpressions.map {
       case agg @ AggregateExpression(_, Partial, _, _) =>
         agg.copy(mode = PartialMerge)
@@ -67,9 +68,12 @@ class ObjectAggregationIterator(
         agg.copy(mode = Final)
       case other => other
     }
-    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
-    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
-    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
+    (newExpressions, initializeAggregateFunctions(newExpressions, 0))
+  }
+
+  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
+    val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes)
+    generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes)
   }
 
   /**
@@ -93,7 +97,7 @@ class ObjectAggregationIterator(
    */
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     if (groupingExpressions.isEmpty) {
-      val defaultAggregationBuffer = createNewAggregationBuffer()
+      val defaultAggregationBuffer = createNewAggregationBuffer(aggregateFunctions)
       generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
     } else {
       throw new IllegalStateException(
@@ -106,18 +110,20 @@ class ObjectAggregationIterator(
   //
   // - when creating aggregation buffer for a new group in the hash map, and
   // - when creating the re-used buffer for sort-based aggregation
-  private def createNewAggregationBuffer(): SpecificInternalRow = {
-    val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
+  private def createNewAggregationBuffer(
+      functions: Array[AggregateFunction]): SpecificInternalRow = {
+    val bufferFieldTypes = functions.flatMap(_.aggBufferAttributes.map(_.dataType))
     val buffer = new SpecificInternalRow(bufferFieldTypes)
-    initAggregationBuffer(buffer)
+    initAggregationBuffer(buffer, functions)
     buffer
   }
 
-  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
+  private def initAggregationBuffer(
+      buffer: SpecificInternalRow, functions: Array[AggregateFunction]): Unit = {
     // Initializes declarative aggregates' buffer values
     expressionAggInitialProjection.target(buffer)(EmptyRow)
     // Initializes imperative aggregates' buffer values
-    aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
+    functions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
   }
 
   private def getAggregationBufferByKey(
@@ -125,7 +131,7 @@ class ObjectAggregationIterator(
     var aggBuffer = hashMap.getAggregationBuffer(groupingKey)
 
     if (aggBuffer == null) {
-      aggBuffer = createNewAggregationBuffer()
+      aggBuffer = createNewAggregationBuffer(aggregateFunctions)
       hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
     }
 
@@ -183,7 +189,7 @@ class ObjectAggregationIterator(
       StructType.fromAttributes(groupingAttributes),
       processRow,
       mergeAggregationBuffers,
-      createNewAggregationBuffer())
+      createNewAggregationBuffer(sortBasedAggFunctions))
 
     while (inputRows.hasNext) {
       // NOTE: The input row is always UnsafeRow
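The mode rewrite at the top of this diff is the key to the sort-based fallback: buffers spilled by the hash aggregate are already partially aggregated, so they must be consumed in merge mode. A standalone sketch of that rewrite, using hypothetical simplified types (the real code pattern-matches on AggregateExpression):

    sealed trait AggMode
    case object Partial extends AggMode
    case object PartialMerge extends AggMode
    case object Complete extends AggMode
    case object Final extends AggMode

    final case class AggExpr(name: String, mode: AggMode)

    object ModeRewrite {
      // Partial aggregates are re-planned as PartialMerge, Complete as Final;
      // expressions already on the merge side pass through unchanged.
      def toMergeModes(exprs: Seq[AggExpr]): Seq[AggExpr] = exprs.map {
        case e @ AggExpr(_, Partial)  => e.copy(mode = PartialMerge)
        case e @ AggExpr(_, Complete) => e.copy(mode = Final)
        case other                    => other
      }

      def main(args: Array[String]): Unit = {
        println(toMergeModes(Seq(AggExpr("mock", Partial), AggExpr("hive_max", Complete))))
        // List(AggExpr(mock,PartialMerge), AggExpr(hive_max,Final))
      }
    }

Note that the same rewritten functions now also initialize the fallback's aggregation buffer (the `createNewAggregationBuffer(sortBasedAggFunctions)` call above), which is the fix for the merge-before-initialize bug the commit message describes.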
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4a8450901e3a7..77cd651449a21 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -311,7 +311,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends HiveTypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
   with HiveInspectors
   with UserDefinedExpression {
 
@@ -404,6 +404,9 @@ private[hive] case class HiveUDAFFunction(
   override def createAggregationBuffer(): AggregationBuffer =
     partial1HiveEvaluator.evaluator.getNewAggregationBuffer
 
+  override def createPartial2ModeAggregationBuffer(): AggregationBuffer =
+    partial2ModeEvaluator.getNewAggregationBuffer
+
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 

From 4cbdf27fa3753cb3bd42d80713b098c5c132e1f0 Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Thu, 14 Feb 2019 19:50:43 -0600
Subject: [PATCH 2/6] [SPARK-24935]: Fixing unit tests

---
 .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 4e3022830143e..e9862520eb3d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -648,9 +648,11 @@ abstract class HiveTypedImperativeAggregate[T] extends TypedImperativeAggregate[
 
   final override def update(buffer: InternalRow, input: InternalRow): Unit = {
     buffer(mutableAggBufferOffset) = update(getBufferObject(buffer), input)
+    partial2ModeBuffer = buffer.copy()
   }
 
   final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
+    partial2ModeBuffer(mutableAggBufferOffset) = createPartial2ModeAggregationBuffer()
     val bufferObject = getBufferObject(partial2ModeBuffer)
     // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
     val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
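Why all this machinery? A Hive GenericUDAFEvaluator is initialized once per mode, and getNewAggregationBuffer may return a different buffer class in PARTIAL1 than in PARTIAL2/FINAL. If Spark reuses the PARTIAL1 object on the merge side, the evaluator's merge-side cast fails at runtime, which is the symptom SPARK-24935 reports. A hypothetical, self-contained illustration of the failure mode:

    object BufferMixup {
      final class P1Buf(var n: Long) // what the evaluator returns in PARTIAL1 mode
      final class P2Buf(var n: Long) // what it returns in PARTIAL2/FINAL mode

      // Merge-side code written against the PARTIAL2 buffer class, as Hive
      // evaluators typically are.
      def merge(agg: AnyRef, in: Long): Unit = agg.asInstanceOf[P2Buf].n += in

      def main(args: Array[String]): Unit = {
        merge(new P2Buf(0L), 1L)     // fine: the merge side got a PARTIAL2 buffer
        try merge(new P1Buf(0L), 1L) // PARTIAL1 object wrongly reused for merging
        catch { case e: ClassCastException => println(s"fails as reported: $e") }
      }
    }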
From 21371d34a4ee7db7e6c8c7205111845005429d72 Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Fri, 15 Feb 2019 09:48:45 -0600
Subject: [PATCH 3/6] [SPARK-24935]: Removing a redundant code line that was failing unit tests

---
 .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index e9862520eb3d8..820408e339a98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -652,7 +652,6 @@ abstract class HiveTypedImperativeAggregate[T] extends TypedImperativeAggregate[
   }
 
   final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
-    partial2ModeBuffer(mutableAggBufferOffset) = createPartial2ModeAggregationBuffer()
     val bufferObject = getBufferObject(partial2ModeBuffer)
     // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
     val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
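After patches 2 and 3, `update` refreshes `partial2ModeBuffer` with `buffer.copy()` on every call. The `copy()` matters because it takes a snapshot that is decoupled from later in-place mutation of the original row. A hypothetical miniature of that semantics (MutableRow stands in for Spark's InternalRow; not Spark code):

    object CopyVsAlias {
      final class MutableRow(val values: Array[Any]) {
        def copy(): MutableRow = new MutableRow(values.clone())
      }

      def main(args: Array[String]): Unit = {
        val buffer = new MutableRow(Array[Any]("partial1-object"))
        val alias = buffer           // aliased: sees every later mutation
        val snapshot = buffer.copy() // snapshot: fixed at copy time
        buffer.values(0) = "overwritten in place"
        println(alias.values(0))     // overwritten in place
        println(snapshot.values(0))  // partial1-object
      }
    }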
From 7253983c15153822e6b447de69976dc99ba19fa3 Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Fri, 22 Feb 2019 18:09:48 -0600
Subject: [PATCH 4/6] [SPARK-24935]: Adding unit tests

---
 .../sql/hive/execution/HiveUDAFSuite.scala | 196 ++++++++++++++++++
 1 file changed, 196 insertions(+)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index fe3deceb08067..a78abcceecb5a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
+import org.apache.spark.SparkException
 import test.org.apache.spark.sql.MyDoubleAvg
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
@@ -40,6 +41,8 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     super.beforeAll()
     sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
     sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION mock2 AS '${classOf[MockUDAF2].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION mock3 AS '${classOf[MockUDAF3].getName}'")
 
     Seq(
       (0: Integer) -> "val_0",
@@ -92,6 +95,42 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
+  test("customized Hive UDAF2") {
+    intercept[SparkException] {
+      val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }
+
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)
+
+      checkAnswer(df, Seq(
+        Row(0, Row(1, 1)),
+        Row(1, Row(1, 1))
+      ))
+    }
+  }
+
+  test("customized Hive UDAF3") {
+    val df = sql("SELECT key % 2, mock3(value) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, Row(1, 1)),
+      Row(1, Row(1, 1))
+    ))
+  }
+
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
@@ -127,12 +166,26 @@ class MockUDAF extends AbstractGenericUDAFResolver {
   override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
 }
 
+class MockUDAF2 extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator2
+}
+
+class MockUDAF3 extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator3
+}
+
 class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
   extends GenericUDAFEvaluator.AbstractAggregationBuffer {
 
   override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
 }
 
+class MockUDAFBuffer2(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
 class MockUDAFEvaluator extends GenericUDAFEvaluator {
   private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
 
@@ -184,3 +237,146 @@ class MockUDAFEvaluator extends GenericUDAFEvaluator {
 
   override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
 }
+
+// Same as MockUDAFEvaluator, but it uses two aggregation buffer classes, one for PARTIAL1 and
+// the other for PARTIAL2. Since its merge-side code still casts to the PARTIAL1 buffer class,
+// queries over this UDAF are expected to throw an exception.
+class MockUDAFEvaluator2 extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+  private var aggMode: Mode = null
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = {
+    var aggBuffer: AggregationBuffer = null
+    if (aggMode == Mode.PARTIAL1) {
+      aggBuffer = new MockUDAFBuffer(0L, 0L)
+    } else if (aggMode == Mode.PARTIAL2) {
+      aggBuffer = new MockUDAFBuffer2(0L, 0L)
+    }
+    aggBuffer
+  }
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
+    aggMode = mode
+    bufferOI
+  }
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
+}
+
+// This class implements the Hive UDAF contract for partial aggregation.
+class MockUDAFEvaluator3 extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+  private var aggMode: Mode = null
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = {
+    var aggBuffer: AggregationBuffer = null
+    if (aggMode == Mode.PARTIAL1) {
+      aggBuffer = new MockUDAFBuffer(0L, 0L)
+    } else if (aggMode == Mode.PARTIAL2) {
+      aggBuffer = new MockUDAFBuffer2(0L, 0L)
+    }
+    aggBuffer
+  }
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
+    aggMode = mode
+    bufferOI
+  }
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  // Because this method is called in both the Partial1 and Partial2 states, it has to check
+  // which aggregation buffer class it received before casting.
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    var result: AnyRef = null
+    if (agg.getClass.toString.contains("MockUDAFBuffer2")) {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    } else {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    }
+    result
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+}
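As an aside, the class check in terminatePartial above can be written more idiomatically with Scala type patterns instead of inspecting the class name. A hedged, standalone sketch of the same behavior (Buf1/Buf2 are hypothetical stand-ins for the mock buffer classes):

    object TerminateByType {
      final class Buf1(var nonNullCount: Long, var nullCount: Long) // stands in for MockUDAFBuffer
      final class Buf2(var nonNullCount: Long, var nullCount: Long) // stands in for MockUDAFBuffer2

      // Type patterns replace the getClass.toString.contains("...") check.
      def terminateCounts(agg: AnyRef): Array[Object] = agg match {
        case b: Buf2 => Array[Object](b.nonNullCount: java.lang.Long, b.nullCount: java.lang.Long)
        case b: Buf1 => Array[Object](b.nonNullCount: java.lang.Long, b.nullCount: java.lang.Long)
      }

      def main(args: Array[String]): Unit = {
        println(terminateCounts(new Buf1(2L, 1L)).mkString(",")) // 2,1
        println(terminateCounts(new Buf2(3L, 0L)).mkString(",")) // 3,0
      }
    }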
From fdfd67c8add9f57ff7bb9699b259f4fbd115e624 Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Fri, 22 Feb 2019 18:25:32 -0600
Subject: [PATCH 5/6] [SPARK-24935]: Fixing Scalastyle tests

---
 .../org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index a78abcceecb5a..63680e8812ce8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -26,9 +26,10 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
-import org.apache.spark.SparkException
 import test.org.apache.spark.sql.MyDoubleAvg
 
+import org.apache.spark.SparkException
+
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton

From bd57d0f8523546ccc8de604ce00b8f5c9280bbdc Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Fri, 22 Feb 2019 18:59:22 -0600
Subject: [PATCH 6/6] [SPARK-24935]: Removing an empty line

---
 .../org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index 63680e8812ce8..f5a96568a2ac8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -29,7 +29,6 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg
 
 import org.apache.spark.SparkException
-
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton