From 6be2ebe98b667e57511b9623062d54f43b04e617 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Sat, 20 Aug 2016 00:34:56 +0800 Subject: [PATCH 1/6] percentile approximate --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../aggregate/ApproximatePercentile.scala | 302 +++++++++++++++++ .../ApproximatePercentileSuite.scala | 317 ++++++++++++++++++ .../sql/ApproximatePercentileQuerySuite.scala | 226 +++++++++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 3 +- .../sql/catalyst/ExpressionToSQLSuite.scala | 1 + 6 files changed, 848 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 35fd800df4a4f..b05f4f61f6a3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -250,6 +250,7 @@ object FunctionRegistry { expression[Average]("mean"), expression[Min]("min"), expression[Skewness]("skewness"), + expression[ApproximatePercentile]("percentile_approx"), expression[StddevSamp]("std"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala new file mode 100644 index 0000000000000..1d6d2cd7b6d7b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflectionLock} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} +import org.apache.spark.sql.types._ + +/** + * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given + * percentage(s). A percentile is a watermark value below which a given percentage of the column + * values fall. For example, the percentile of column `col` at percentage 50% is the median of + * column `col`. + * + * This function supports partial aggregation. + * + * @param child child expression that can produce column value with `child.eval(inputRow)` + * @param percentageExpression Expression that represents a single percentage value or + * a array of percentage values. Each percentage value must be between + * 0.0 and 1.0. + * @param accuracyExpression Integer literal expression of approximation accuracy. Higher value + * yield better accuracy, the default value is + * DEFAULT_PERCENTILE_ACCURACY. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric + column `col` at the given percentage. The value of percentage must be between 0.0 + and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which + controls approximation accuracy at the cost of memory. Higher value of `accuracy` yield + better accuracy, `1.0/accuracy` is the relative error of the approximation. + + _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate + percentile array of column `col` at the given percentage array. Each value of the + percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is + a positive integer literal which controls approximation accuracy at the cost of memory. + Higher value of `accuracy` yield better accuracy, `1.0/accuracy` is the relative error of + the approximation. + """) +case class ApproximatePercentile( + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] { + + def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { + this(child, percentageExpression, accuracyExpression, 0, 0) + } + + def this(child: Expression, percentageExpression: Expression) = { + this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + } + + // Mark as lazy so that accuracyExpression is not evaluated during tree transformation. 
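+  // (Eager evaluation could also fail here: analysis rules such as ImplicitTypeCasts, noted
+  // below for percentageExpression, may not have rewritten the argument yet at that point.)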
+ private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int] + + private lazy val serializer: PercentileDigestSerializer = new PercentileDigestSerializer + + override def inputTypes: Seq[AbstractDataType] = { + Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType) + } + + // Mark as lazy so that percentageExpression is not evaluated during tree transformation. + private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = { + (percentageExpression.dataType, percentageExpression.eval()) match { + // Rule ImplicitTypeCasts can cast other numeric types to double + case (_, num: Double) => (false, Array(num)) + case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => + val numericArray = arrayData.toArray(baseType)(baseType.classTag) + (true, numericArray.map(baseType.numeric.toDouble)) + case other => + throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (accuracy <= 0) { + TypeCheckFailure( + s"The accuracy provided must be a positive integer literal (current value = $accuracy)") + } else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) { + TypeCheckFailure( + s"All percentage values must be between 0.0 and 1.0 " + + s"(current = ${percentages.mkString(", ")})") + } else { + TypeCheckSuccess + } + } + + override def createAggregationBuffer(): PercentileDigest = { + val relativeError = 1.0D / accuracy + new PercentileDigest(relativeError) + } + + override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = { + val value = child.eval(inputRow) + // Ignore empty rows, for example: percentile_approx(null) + if (value != null) { + buffer.add(value.asInstanceOf[Double]) + } + } + + override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = { + buffer.merge(other) + } + + override def eval(buffer: PercentileDigest): Any = { + val result = buffer.getPercentiles(percentages) + if (result.length == 0) { + null + } else if (returnPercentileArray) { + new GenericArrayData(result) + } else { + result(0) + } + } + + override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile = + copy(inputAggBufferOffset = newOffset) + + override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression) + + // Returns null for empty inputs + override def nullable: Boolean = true + + override def dataType: DataType = { + if (returnPercentileArray) ArrayType(DoubleType) else DoubleType + } + + override def prettyName: String = "percentile_approx" + + override def serialize(obj: PercentileDigest): Array[Byte] = { + serializer.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): PercentileDigest = { + serializer.deserialize(bytes) + } +} + +object ApproximatePercentile { + + // Default accuracy of Percentile approximation. Larger value means better accuracy. + // The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY + val DEFAULT_PERCENTILE_ACCURACY: Int = 10000 + + /** + * PercentileDigest is a probabilistic data structure used for approximating percentiles + * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. 
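+   *
+   * A sketch of the intended lifecycle (the relative error and input values below are
+   * illustrative only):
+   * {{{
+   *   val digest = new PercentileDigest(relativeError = 0.01)
+   *   (1 to 100).foreach(v => digest.add(v))
+   *   val partial = new PercentileDigest(relativeError = 0.01)
+   *   (101 to 200).foreach(v => partial.add(v))
+   *   digest.merge(partial)
+   *   val Array(median) = digest.getPercentiles(Array(0.5))
+   * }}}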
+ * + * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. + * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the + * underlying quantileSummaries is compressed. + */ + class PercentileDigest( + private var summaries: QuantileSummaries, + private var isCompressed: Boolean) { + + // Trigger compression if the QuantileSummaries's buffer length exceeds + // compressThresHoldBufferLength. The buffer length can be get by + // quantileSummaries.sampled.length + private[this] final val compressThresHoldBufferLength: Int = { + // Max buffer length after compression. + val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 + // A safe upper bound for buffer length before compression + maxBufferLengthAfterCompression * 2 + } + + def this(relativeError: Double) = { + this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + } + + /** Returns compressed object of [[QuantileSummaries]] */ + def quantileSummaries: QuantileSummaries = { + if (!isCompressed) compress() + summaries + } + + /** Insert an observation value into the PercentileDigest data structure. */ + def add(value: Double): Unit = { + summaries = summaries.insert(value) + // The result of QuantileSummaries.insert is un-compressed + isCompressed = false + + // Currently, QuantileSummaries ignores the construction parameter compressThresHold, + // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here + // to make sure QuantileSummaries doesn't occupy infinite memory. + // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold + if (summaries.sampled.length >= compressThresHoldBufferLength) compress() + } + + /** In-place merges in another PercentileDigest. */ + def merge(other: PercentileDigest): Unit = { + if (!isCompressed) compress() + summaries = summaries.merge(other.quantileSummaries) + } + + /** + * Returns the approximate percentiles of all observation values at the given percentages. + * A percentile is a watermark value below which a given percentage of observation values fall. + * For example, the following code returns the 25th, median, and 75th percentiles of + * all observation values: + * + * {{{ + * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75)) + * }}} + */ + def getPercentiles(percentages: Array[Double]): Array[Double] = { + if (!isCompressed) compress() + if (summaries.count == 0 || percentages.length == 0) { + Array.empty[Double] + } else { + val result = new Array[Double](percentages.length) + var i = 0 + while (i < percentages.length) { + result(i) = summaries.query(percentages(i)) + i += 1 + } + result + } + } + + private final def compress(): Unit = { + summaries = summaries.compress() + isCompressed = true + } + } + + /** + * Serializer for class [[PercentileDigest]] + * + * This class is NOT thread safe because usage of ExpressionEncoder is not threadsafe. + */ + class PercentileDigestSerializer { + + // In Scala 2.10, the creation of TypeTag is not thread safe. We need to use ScalaReflectionLock + // to protect the creation of this encoder. See SI-6240 for more details. 
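+    // The encoder maps a QuantileSummariesData instance to an UnsafeRow and back; serialize()
+    // and deserialize() below are thin wrappers around its toRow()/fromRow() methods.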
+ private[this] final val serializer = ScalaReflectionLock.synchronized { + ExpressionEncoder[QuantileSummariesData].resolveAndBind() + } + + // There are 4 fields in QuantileSummariesData + private[this] final val row = new UnsafeRow(4) + + final def serialize(obj: PercentileDigest): Array[Byte] = { + val data = new QuantileSummariesData(obj.quantileSummaries) + serializer.toRow(data).asInstanceOf[UnsafeRow].getBytes + } + + final def deserialize(bytes: Array[Byte]): PercentileDigest = { + row.pointTo(bytes, bytes.length) + val quantileSummaries = serializer.fromRow(row).toQuantileSummaries + new PercentileDigest(quantileSummaries, isCompressed = true) + } + } + + // An case class to wrap fields of QuantileSummaries, so that we can use the expression encoder + // to serialize it. + case class QuantileSummariesData( + val compressThreshold: Int, + val relativeError: Double, + val sampled: Array[Stats] = Array.empty, + val count: Long = 0L) { + def this(summary: QuantileSummaries) = { + this(summary.compressThreshold, summary.relativeError, summary.sampled, summary.count) + } + + def toQuantileSummaries: QuantileSummaries = { + new QuantileSummaries(compressThreshold, relativeError, sampled, count) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala new file mode 100644 index 0000000000000..8cc3d1de82e39 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericMutableRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
+import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType}
+import org.apache.spark.util.SizeEstimator
+
+
+class ApproximatePercentileSuite extends SparkFunSuite {
+
+  private val random = new java.util.Random()
+
+  private val data = (0 until 10000).map { _ =>
+    random.nextInt(10000)
+  }
+
+  test("serialize and de-serialize") {
+
+    val serializer = new PercentileDigestSerializer
+
+    // Check empty serialize and de-serialize
+    val emptyBuffer = new PercentileDigest(relativeError = 0.01)
+    assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer))))
+
+    val buffer = new PercentileDigest(relativeError = 0.01)
+    (1 to 100).foreach { value =>
+      buffer.add(value)
+    }
+    assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer))))
+
+    val agg = new ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5))
+    assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+  }
+
+  test("class PercentileDigest, basic operations") {
+    val valueCount = 10000
+    val percentages = Array(0.25, 0.5, 0.75)
+    Seq(0.0001, 0.001, 0.01, 0.1).foreach { relativeError =>
+      val buffer = new PercentileDigest(relativeError)
+      (1 to valueCount).grouped(10).foreach { group =>
+        val partialBuffer = new PercentileDigest(relativeError)
+        group.foreach(x => partialBuffer.add(x))
+        buffer.merge(partialBuffer)
+      }
+      val expectedPercentiles = percentages.map(_ * valueCount)
+      val approxPercentiles = buffer.getPercentiles(Array(0.25, 0.5, 0.75))
+      expectedPercentiles.zip(approxPercentiles).foreach { pair =>
+        val (expected, estimate) = pair
+        assert((estimate - expected) / valueCount <= relativeError)
+      }
+    }
+  }
+
+  test("class PercentileDigest, makes sure the memory footprint is bounded") {
+    val relativeError = 0.01
+    val memoryFootprintUpperBound = {
+      val headBufferSize =
+        SizeEstimator.estimate(new Array[Double](QuantileSummaries.defaultHeadSize))
+      val bufferSize = SizeEstimator.estimate(new Stats(0, 0, 0)) * (1 / relativeError) * 2
+      // A safe upper bound
+      (headBufferSize + bufferSize) * 2
+    }
+
+    Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count =>
+      val buffer = new PercentileDigest(relativeError)
+      // Worst case: the input data is already sorted
+      (0 until count).foreach(buffer.add(_))
+      assert(SizeEstimator.estimate(buffer) < memoryFootprintUpperBound)
+    }
+  }
+
+  test("class ApproximatePercentile, high level interface, update, merge, eval...") {
+    val count = 10000
+    val data = (1 until 10000).toSeq
+    val percentages = Array(0.25D, 0.5D, 0.75D)
+    val accuracy 
= 10000 + val expectedPercentiles = percentages.map(count * _) + val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) + val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) + val accuracyExpression = Literal(10000) + val agg = new ApproximatePercentile(childExpression, percentageExpression, accuracyExpression) + + assert(agg.nullable) + val group1 = (0 until data.length / 2) + val group1Buffer = agg.createAggregationBuffer() + group1.foreach { index => + val input = InternalRow(data(index)) + agg.update(group1Buffer, input) + } + + val group2 = (data.length / 2 until data.length) + val group2Buffer = agg.createAggregationBuffer() + group2.foreach { index => + val input = InternalRow(data(index)) + agg.update(group2Buffer, input) + } + + val mergeBuffer = agg.createAggregationBuffer() + agg.merge(mergeBuffer, group1Buffer) + agg.merge(mergeBuffer, group2Buffer) + + agg.eval(mergeBuffer) match { + case arrayData: ArrayData => + val error = count / accuracy + val percentiles = arrayData.toDoubleArray() + assert(percentiles.zip(expectedPercentiles) + .forall(pair => Math.abs(pair._1 - pair._2) < error)) + } + } + + test("class ApproximatePercentile, low level interface, update, merge, eval...") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val inputAggregationBufferOffset = 1 + val mutableAggregationBufferOffset = 2 + val percentage = 0.5D + + // Phase one, partial mode aggregation + val agg = new ApproximatePercentile(childExpression, Literal(percentage)) + .withNewInputAggBufferOffset(inputAggregationBufferOffset) + .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) + + val mutableAggBuffer = new GenericMutableRow(new Array[Any](mutableAggregationBufferOffset + 1)) + agg.initialize(mutableAggBuffer) + val dataCount = 10 + (1 to dataCount).foreach { data => + agg.update(mutableAggBuffer, InternalRow(data)) + } + agg.serializeAggregateBufferInPlace(mutableAggBuffer) + + // Serialize the aggregation buffer + val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) + val inputAggBuffer = new GenericMutableRow(Array[Any](null, serialized)) + + // Phase 2: final mode aggregation + // Re-initialize the aggregation buffer + agg.initialize(mutableAggBuffer) + agg.merge(mutableAggBuffer, inputAggBuffer) + val expectedPercentile = dataCount * percentage + assert(Math.abs(agg.eval(mutableAggBuffer).asInstanceOf[Double] - expectedPercentile) < 0.1) + } + + test("class ApproximatePercentile, sql string") { + // sql, single percentile + assertEqual( + "percentile_approx(`a`, 0.5D, 1000)", + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) + + // sql, array of percentile + assertEqual( + "percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), 1000)", + new ApproximatePercentile( + "a".attr, + percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) + ).sql: String) + + // sql(isDistinct = false), single percentile + assertEqual( + "percentile_approx(`a`, 0.5D, 1000)", + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + .sql(isDistinct = false)) + + // sql(isDistinct = false), array of percentile + assertEqual( + "percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), 1000)", + new ApproximatePercentile( + "a".attr, + percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) + ).sql(isDistinct = false)) + + // sql(isDistinct = true), single percentile + assertEqual( + 
"percentile_approx(DISTINCT `a`, 0.5D, 1000)", + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + .sql(isDistinct = true)) + + // sql(isDistinct = true), array of percentile + assertEqual( + "percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), 1000)", + new ApproximatePercentile( + "a".attr, + percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) + ).sql(isDistinct = true)) + } + + test("class ApproximatePercentile, fails analysis if parameters are invalid") { + val wrongAccuracy = new ApproximatePercentile( + AttributeReference("a", DoubleType)(), + percentageExpression = Literal(0.5D), + accuracyExpression = Literal(-1)) + assertEqual( + wrongAccuracy.checkInputDataTypes(), + TypeCheckFailure( + "The accuracy provided must be a positive integer literal (current value = -1)")) + + val correctPercentageExpresions = Seq( + Literal(0D), + Literal(1D), + Literal(0.5D), + CreateArray(Seq(0D, 1D, 0.5D).map(Literal(_))) + ) + correctPercentageExpresions.foreach { percentageExpression => + val correctPercentage = new ApproximatePercentile( + AttributeReference("a", DoubleType)(), + percentageExpression = percentageExpression, + accuracyExpression = Literal(100)) + + // no exception should be thrown + correctPercentage.checkInputDataTypes() + } + + val wrongPercentageExpressions = Seq( + Literal(1.1D), + Literal(-0.5D), + CreateArray(Seq(0D, 0.5D, 1.1D).map(Literal(_))) + ) + + wrongPercentageExpressions.foreach { percentageExpression => + val wrongPercentage = new ApproximatePercentile( + AttributeReference("a", DoubleType)(), + percentageExpression = percentageExpression, + accuracyExpression = Literal(100)) + + val result = wrongPercentage.checkInputDataTypes() + assert( + wrongPercentage.checkInputDataTypes() match { + case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true + case _ => false + }) + } + } + + test("class ApproximatePercentile, automatically add type casting for parameters") { + val testRelation = LocalRelation('a.int) + val analyzer = SimpleAnalyzer + + // Compatible accuracy types: Long type and decimal type + val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D)) + // Compatible percentage types: float, decimal + val percentageExpressions = Seq(Literal(0.3f), DecimalLiteral(0.5), + CreateArray(Seq(Literal(0.3f), Literal(0.5D), DecimalLiteral(0.7)))) + + accuracyExpressions.foreach { accuracyExpression => + percentageExpressions.foreach { percentageExpression => + val agg = new ApproximatePercentile( + UnresolvedAttribute("a"), + percentageExpression, + accuracyExpression) + val analyzed = testRelation.select(agg).analyze.expressions.head + analyzed match { + case Alias(agg: ApproximatePercentile, _) => + assert(agg.resolved) + assert(agg.child.dataType == DoubleType) + assert(agg.percentageExpression.dataType == DoubleType || + agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false)) + assert(agg.accuracyExpression.dataType == IntegerType) + case _ => fail() + } + } + } + } + + test("class ApproximatePercentile, null handling") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val agg = new ApproximatePercentile(childExpression, Literal(0.5D)) + val buffer = new GenericMutableRow(new Array[Any](1)) + agg.initialize(buffer) + // Empty aggregation buffer + assert(agg.eval(buffer) == null) + // Empty input row + agg.update(buffer, InternalRow(null)) + assert(agg.eval(buffer) == null) + + // Add some 
non-empty row + agg.update(buffer, InternalRow(0)) + assert(agg.eval(buffer) != null) + } + + private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = { + val leftSummary = left.quantileSummaries + val rightSummary = right.quantileSummaries + leftSummary.compressThreshold == rightSummary.compressThreshold && + leftSummary.relativeError == rightSummary.relativeError && + leftSummary.count == rightSummary.count && + leftSummary.sampled.sameElements(rightSummary.sampled) + } + + private def assertEqual[T](left: T, right: T): Unit = { + assert(left == right) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala new file mode 100644 index 0000000000000..37d7c442bbeb8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest +import org.apache.spark.sql.test.SharedSQLContext + +class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val table = "percentile_test" + + test("percentile_approx, single percentile value") { + withTempView(table) { + (1 to 1000).toDF("col").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s""" + |SELECT + | percentile_approx(col, 0.25), + | percentile_approx(col, 0.5), + | percentile_approx(col, 0.75d), + | percentile_approx(col, 0.0), + | percentile_approx(col, 1.0), + | percentile_approx(col, 0), + | percentile_approx(col, 1) + |FROM $table + """.stripMargin), + Row(250D, 500D, 750D, 1D, 1000D, 1D, 1000D) + ) + } + } + + test("percentile_approx, array of percentile value") { + withTempView(table) { + (1 to 1000).toDF("col").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s"""SELECT + | percentile_approx(col, array(0.25, 0.5, 0.75D)), + | count(col), + | percentile_approx(col, array(0.0, 1.0)), + | sum(col) + |FROM $table + """.stripMargin), + Row(Seq(250D, 500D, 750D), 1000, Seq(1D, 1000D), 500500) + ) + } + } + + test("percentile_approx, with different accuracies") { + + withTempView(table) { + (1 to 1000).toDF("col").createOrReplaceTempView(table) + + // With different accuracies + val expectedPercentile = 250D + val accuracies = Array(1, 10, 100, 1000, 10000) + val errors = accuracies.map { accuracy => + val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table") + val approximatePercentile = 
df.collect().head.getDouble(0)
+        val error = Math.abs(approximatePercentile - expectedPercentile)
+        error
+      }
+
+      // The larger the accuracy value, the smaller the error
+      assert(errors.sorted.sameElements(errors.reverse))
+    }
+  }
+
+  test("percentile_approx, supports constant folding for parameter accuracy and percentages") {
+    withTempView(table) {
+      (1 to 1000).toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(s"SELECT percentile_approx(col, array(0.25 + 0.25D), 200 + 800D) FROM $table"),
+        Row(Seq(500D))
+      )
+    }
+  }
+
+  test("percentile_approx(), aggregation on empty input table, no group by") {
+    withTempView(table) {
+      Seq.empty[Int].toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(s"SELECT sum(col), percentile_approx(col, 0.5) FROM $table"),
+        Row(null, null)
+      )
+    }
+  }
+
+  test("percentile_approx(), aggregation on empty input table, with group by") {
+    withTempView(table) {
+      Seq.empty[Int].toDF("col").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(s"SELECT sum(col), percentile_approx(col, 0.5) FROM $table GROUP BY col"),
+        Seq.empty[Row]
+      )
+    }
+  }
+
+  test("percentile_approx(null), aggregation with group by") {
+    withTempView(table) {
+      (1 to 1000).map(x => (x % 3, x)).toDF("key", "value").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  key,
+             |  percentile_approx(null, 0.5)
+             |FROM $table
+             |GROUP BY key
+           """.stripMargin),
+        Seq(
+          Row(0, null),
+          Row(1, null),
+          Row(2, null))
+      )
+    }
+  }
+
+  test("percentile_approx(null), aggregation without group by") {
+    withTempView(table) {
+      (1 to 1000).map(x => (x % 3, x)).toDF("key", "value").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(null, 0.5),
+             |  sum(null),
+             |  percentile_approx(null, 0.5)
+             |FROM $table
+           """.stripMargin),
+        Row(null, null, null)
+      )
+    }
+  }
+
+  test("percentile_approx(col, ...), input rows contain null, without group by") {
+    withTempView(table) {
+      (1 to 1000).map(new Integer(_)).flatMap(Seq(null: Integer, _)).toDF("col")
+        .createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(col, 0.5),
+             |  sum(null),
+             |  percentile_approx(col, 0.5)
+             |FROM $table
+           """.stripMargin),
+        Row(500D, null, 500D))
+    }
+  }
+
+  test("percentile_approx(col, ...), input rows contain null, with group by") {
+    withTempView(table) {
+      (1 to 1000)
+        .map(new Integer(_))
+        .map(v => (new Integer(v % 2), v))
+        // Add some nulls
+        .flatMap(Seq(_, (null: Integer, null: Integer)))
+        .toDF("key", "value").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(value, 0.5),
+             |  sum(value),
+             |  percentile_approx(value, 0.5)
+             |FROM $table
+             |GROUP BY key
+           """.stripMargin),
+        Seq(
+          Row(499.0D, 250000, 499.0D),
+          Row(500.0D, 250500, 500.0D),
+          Row(null, null, null))
+      )
+    }
+  }
+
+  test("percentile_approx(col, ...) 
works in window function") { + withTempView(table) { + val data = (1 to 10).map(v => (v % 2, v)) + data.toDF("key", "value").createOrReplaceTempView(table) + + val query = spark.sql( + s""" + |SElECT percentile_approx(value, 0.5) + |OVER + | (PARTITION BY key ORDER BY value ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + | AS percentile + |FROM $table + """.stripMargin) + + val expected = data.groupBy(_._1).toSeq.flatMap { group => + val (key, values) = group + val sortedValues = values.map(_._2).sorted + + var outputRows = Seq.empty[Row] + var i = 0 + + val percentile = new PercentileDigest(1.0 / DEFAULT_PERCENTILE_ACCURACY) + sortedValues.foreach { value => + percentile.add(value) + outputRows :+= Row(percentile.getPercentiles(Array(0.5D)).head) + } + outputRows + } + + checkAnswer(query, expected) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index ca8c7347f23e9..1ab0d6dac94a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -235,7 +235,6 @@ private[sql] class HiveSessionCatalog( private val hiveFunctions = Seq( "hash", "histogram_numeric", - "percentile", - "percentile_approx" + "percentile" ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index b4eb50e331cf9..ee5d90b433e1f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -155,6 +155,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { test("aggregate functions") { checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile_approx(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key") checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key") From 3dda1ae90cea0eab8dd963d4af5708bff8846d65 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 30 Aug 2016 11:12:39 +0800 Subject: [PATCH 2/6] fix UT --- .../aggregate/ApproximatePercentile.scala | 18 ++++++++++-------- .../sql/catalyst/ExpressionToSQLSuite.scala | 6 +++++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 1d6d2cd7b6d7b..68c471cf03c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -91,8 +91,10 @@ case class ApproximatePercentile( // Rule ImplicitTypeCasts can cast other numeric types to double case (_, num: Double) => (false, Array(num)) case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toArray(baseType)(baseType.classTag) - (true, numericArray.map(baseType.numeric.toDouble)) + val numericArray = arrayData.toObjectArray(baseType) + (true, numericArray.map {x => + 
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) + }) case other => throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") } @@ -183,8 +185,8 @@ object ApproximatePercentile { * underlying quantileSummaries is compressed. */ class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { + private var summaries: QuantileSummaries, + private var isCompressed: Boolean) { // Trigger compression if the QuantileSummaries's buffer length exceeds // compressThresHoldBufferLength. The buffer length can be get by @@ -287,10 +289,10 @@ object ApproximatePercentile { // An case class to wrap fields of QuantileSummaries, so that we can use the expression encoder // to serialize it. case class QuantileSummariesData( - val compressThreshold: Int, - val relativeError: Double, - val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) { + compressThreshold: Int, + relativeError: Double, + sampled: Array[Stats], + count: Long) { def this(summary: QuantileSummaries) = { this(summary.compressThreshold, summary.relativeError, summary.sampled, summary.count) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index ee5d90b433e1f..fdd02821dfa29 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -155,7 +155,11 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { test("aggregate functions") { checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key") - checkSqlGeneration("SELECT percentile_approx(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile_approx(value, 0.25) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile_approx(value, array(0.25, 0.75)) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile_approx(value, 0.25, 100) FROM t1 GROUP BY key") + checkSqlGeneration( + "SELECT percentile_approx(value, array(0.25, 0.75), 100) FROM t1 GROUP BY key") checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key") checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key") From 96bd5a4062d81995e438d0e1a301b42b7360994c Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 30 Aug 2016 17:38:11 +0800 Subject: [PATCH 3/6] address comments --- .../expressions/aggregate/ApproximatePercentile.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 68c471cf03c6d..b7bde9261c29f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types._ * a array of percentage values. Each percentage value must be between * 0.0 and 1.0. * @param accuracyExpression Integer literal expression of approximation accuracy. Higher value - * yield better accuracy, the default value is + * yields better accuracy, the default value is * DEFAULT_PERCENTILE_ACCURACY. 
*/ @ExpressionDescription( @@ -51,14 +51,14 @@ import org.apache.spark.sql.types._ _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric column `col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which - controls approximation accuracy at the cost of memory. Higher value of `accuracy` yield + controls approximation accuracy at the cost of memory. Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is the relative error of the approximation. _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate percentile array of column `col` at the given percentage array. Each value of the percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which controls approximation accuracy at the cost of memory. - Higher value of `accuracy` yield better accuracy, `1.0/accuracy` is the relative error of + Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is the relative error of the approximation. """) case class ApproximatePercentile( @@ -92,7 +92,7 @@ case class ApproximatePercentile( case (_, num: Double) => (false, Array(num)) case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => val numericArray = arrayData.toObjectArray(baseType) - (true, numericArray.map {x => + (true, numericArray.map { x => baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) }) case other => From 539d693d7a569cc83019526ea2565de5a86fda83 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 30 Aug 2016 19:36:28 +0800 Subject: [PATCH 4/6] address comments --- .../expressions/aggregate/ApproximatePercentileSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 8cc3d1de82e39..41eff8ccd5ba2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -167,6 +167,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { } test("class ApproximatePercentile, sql string") { + val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY // sql, single percentile assertEqual( "percentile_approx(`a`, 0.5D, 1000)", @@ -174,7 +175,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql, array of percentile assertEqual( - "percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), 1000)", + s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", new ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) @@ -188,7 +189,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = false), array of percentile assertEqual( - "percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), 1000)", + s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", new ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) @@ -202,7 +203,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = true), array of percentile assertEqual( - 
"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), 1000)", + s"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", new ApproximatePercentile( "a".attr, percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) From 1965b1a918633f0d9b94b0a245605592cddbeb9b Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 31 Aug 2016 07:26:53 +0800 Subject: [PATCH 5/6] address comments --- .../expressions/aggregate/ApproximatePercentile.scala | 2 +- .../expressions/aggregate/ApproximatePercentileSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index b7bde9261c29f..41d0c1f4bac4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types._ * * @param child child expression that can produce column value with `child.eval(inputRow)` * @param percentageExpression Expression that represents a single percentage value or - * a array of percentage values. Each percentage value must be between + * an array of percentage values. Each percentage value must be between * 0.0 and 1.0. * @param accuracyExpression Integer literal expression of approximation accuracy. Higher value * yields better accuracy, the default value is diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 41eff8ccd5ba2..c3e00c13056d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -170,7 +170,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY // sql, single percentile assertEqual( - "percentile_approx(`a`, 0.5D, 1000)", + s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) // sql, array of percentile @@ -183,7 +183,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = false), single percentile assertEqual( - "percentile_approx(`a`, 0.5D, 1000)", + s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) .sql(isDistinct = false)) @@ -197,7 +197,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // sql(isDistinct = true), single percentile assertEqual( - "percentile_approx(DISTINCT `a`, 0.5D, 1000)", + s"percentile_approx(DISTINCT `a`, 0.5D, $defaultAccuracy)", new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) .sql(isDistinct = true)) From 3f08c027add03c59251583420c76582a085b3573 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 31 Aug 2016 09:09:58 +0800 Subject: [PATCH 6/6] address wenchen's comment --- .../aggregate/ApproximatePercentile.scala | 87 +++++++++++-------- .../ApproximatePercentileSuite.scala | 27 +++++- 2 files changed, 76 
insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 41d0c1f4bac4d..f91ff87fc1c01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import java.nio.ByteBuffer + +import com.google.common.primitives.{Doubles, Ints, Longs} + import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflectionLock} +import org.apache.spark.sql.catalyst.{InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} @@ -79,8 +82,6 @@ case class ApproximatePercentile( // Mark as lazy so that accuracyExpression is not evaluated during tree transformation. private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int] - private lazy val serializer: PercentileDigestSerializer = new PercentileDigestSerializer - override def inputTypes: Seq[AbstractDataType] = { Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType) } @@ -104,6 +105,8 @@ case class ApproximatePercentile( val defaultCheck = super.checkInputDataTypes() if (defaultCheck.isFailure) { defaultCheck + } else if (!percentageExpression.foldable || !accuracyExpression.foldable) { + TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal") } else if (accuracy <= 0) { TypeCheckFailure( s"The accuracy provided must be a positive integer literal (current value = $accuracy)") @@ -162,11 +165,11 @@ case class ApproximatePercentile( override def prettyName: String = "percentile_approx" override def serialize(obj: PercentileDigest): Array[Byte] = { - serializer.serialize(obj) + ApproximatePercentile.serializer.serialize(obj) } override def deserialize(bytes: Array[Byte]): PercentileDigest = { - serializer.deserialize(bytes) + ApproximatePercentile.serializer.deserialize(bytes) } } @@ -261,44 +264,58 @@ object ApproximatePercentile { /** * Serializer for class [[PercentileDigest]] * - * This class is NOT thread safe because usage of ExpressionEncoder is not threadsafe. + * This class is thread safe. */ class PercentileDigestSerializer { - // In Scala 2.10, the creation of TypeTag is not thread safe. We need to use ScalaReflectionLock - // to protect the creation of this encoder. See SI-6240 for more details. 
- private[this] final val serializer = ScalaReflectionLock.synchronized { - ExpressionEncoder[QuantileSummariesData].resolveAndBind() + private final def length(summaries: QuantileSummaries): Int = { + // summaries.compressThreshold, summary.relativeError, summary.count + Ints.BYTES + Doubles.BYTES + Longs.BYTES + + // length of summary.sampled + Ints.BYTES + + // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] + summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) } - // There are 4 fields in QuantileSummariesData - private[this] final val row = new UnsafeRow(4) - final def serialize(obj: PercentileDigest): Array[Byte] = { - val data = new QuantileSummariesData(obj.quantileSummaries) - serializer.toRow(data).asInstanceOf[UnsafeRow].getBytes + val summary = obj.quantileSummaries + val buffer = ByteBuffer.wrap(new Array(length(summary))) + buffer.putInt(summary.compressThreshold) + buffer.putDouble(summary.relativeError) + buffer.putLong(summary.count) + buffer.putInt(summary.sampled.length) + + var i = 0 + while (i < summary.sampled.length) { + val stat = summary.sampled(i) + buffer.putDouble(stat.value) + buffer.putInt(stat.g) + buffer.putInt(stat.delta) + i += 1 + } + buffer.array() } final def deserialize(bytes: Array[Byte]): PercentileDigest = { - row.pointTo(bytes, bytes.length) - val quantileSummaries = serializer.fromRow(row).toQuantileSummaries - new PercentileDigest(quantileSummaries, isCompressed = true) + val buffer = ByteBuffer.wrap(bytes) + val compressThreshold = buffer.getInt() + val relativeError = buffer.getDouble() + val count = buffer.getLong() + val sampledLength = buffer.getInt() + val sampled = new Array[Stats](sampledLength) + + var i = 0 + while (i < sampledLength) { + val value = buffer.getDouble() + val g = buffer.getInt() + val delta = buffer.getInt() + sampled(i) = Stats(value, g, delta) + i += 1 + } + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new PercentileDigest(summary, isCompressed = true) } } - // An case class to wrap fields of QuantileSummaries, so that we can use the expression encoder - // to serialize it. 
- case class QuantileSummariesData( - compressThreshold: Int, - relativeError: Double, - sampled: Array[Stats], - count: Long) { - def this(summary: QuantileSummaries) = { - this(summary.compressThreshold, summary.relativeError, summary.sampled, summary.count) - } - - def toQuantileSummaries: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) - } - } + val serializer: PercentileDigestSerializer = new PercentileDigestSerializer } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index c3e00c13056d4..61298a1b72d77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType} import org.apache.spark.util.SizeEstimator - class ApproximatePercentileSuite extends SparkFunSuite { private val random = new java.util.Random() @@ -42,7 +41,6 @@ class ApproximatePercentileSuite extends SparkFunSuite { } test("serialize and de-serialize") { - val serializer = new PercentileDigestSerializer // Check empty serialize and de-serialize @@ -50,7 +48,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer)))) val buffer = new PercentileDigest(relativeError = 0.01) - (1 to 100).foreach { value => + data.foreach { value => buffer.add(value) } assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer)))) @@ -210,6 +208,29 @@ class ApproximatePercentileSuite extends SparkFunSuite { ).sql(isDistinct = true)) } + test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") { + val attribute = AttributeReference("a", DoubleType)() + val wrongAccuracy = new ApproximatePercentile( + attribute, + percentageExpression = Literal(0.5D), + accuracyExpression = AttributeReference("b", IntegerType)()) + + assertEqual( + wrongAccuracy.checkInputDataTypes(), + TypeCheckFailure("The accuracy or percentage provided must be a constant literal") + ) + + val wrongPercentage = new ApproximatePercentile( + attribute, + percentageExpression = attribute, + accuracyExpression = Literal(10000)) + + assertEqual( + wrongPercentage.checkInputDataTypes(), + TypeCheckFailure("The accuracy or percentage provided must be a constant literal") + ) + } + test("class ApproximatePercentile, fails analysis if parameters are invalid") { val wrongAccuracy = new ApproximatePercentile( AttributeReference("a", DoubleType)(),