From 3c7dc1950fa52587dd3bf40ce9b553228e28bc25 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 1 Mar 2017 19:46:07 +0800 Subject: [PATCH 1/4] Implement one kind of streaming sampling, i.e. reservoir sampling --- .../plans/logical/basicLogicalOperators.scala | 13 ++ .../encoders/ExpressionEncoderSuite.scala | 4 +- .../scala/org/apache/spark/sql/Dataset.scala | 26 ++++ .../org/apache/spark/sql/SparkSession.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 18 ++- .../execution/basicPhysicalOperators.scala | 19 ++- .../streaming/IncrementalExecution.scala | 8 +- .../streaming/statefulOperators.scala | 104 +++++++++++++- .../sql/streaming/ReservoirSampleSuit.scala | 134 ++++++++++++++++++ 9 files changed, 315 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ccebae3cc270..4c56422ffd61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -827,6 +827,19 @@ case class Sample( override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil } +/** + * A logical plan for `reservoir`. + */ +case class ReservoirSample( + keys: Seq[Attribute], + child: LogicalPlan, + k: Int, + streaming: Boolean = false) + extends UnaryNode { + override def maxRows: Option[Long] = child.maxRows + override def output: Seq[Attribute] = child.output +} + /** * Returns a new logical plan that dedups input rows. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b76938..e14d1d78a483 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1b0462359607..e800c4fb7488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2016,6 +2016,32 @@ class Dataset[T] private[sql]( Deduplicate(groupCols, logicalPlan, isStreaming) } + /** + * :: Experimental :: + * (Scala-specific) Reservoir sampling implementation. + * + * @todo move this into sample operator. 
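+   *
+   * A minimal usage sketch (illustrative only; `ds` is assumed to be an existing
+   * `Dataset[String]`):
+   * {{{
+   *   // Keep a uniform random sample of at most 100 rows of `ds`.
+   *   val sampled: Dataset[String] = ds.reservoir(100)
+   * }}}
+   *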
+ * @group typedrel + * @since 2.0.0 + */ + @Experimental + @InterfaceStability.Evolving + def reservoir(k: Int): Dataset[T] = withTypedPlan { + val resolver = sparkSession.sessionState.analyzer.resolver + val allColumns = queryExecution.analyzed.output + val groupCols = this.columns.toSet.toSeq.flatMap { (colName: String) => + // It is possibly there are more than one columns with the same name, + // so we call filter instead of find. + val cols = allColumns.filter(col => resolver(col.name, colName)) + if (cols.isEmpty) { + throw new AnalysisException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + cols + } + ReservoirSample(groupCols, logicalPlan, k, isStreaming) + } + /** * Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index afc1827e7eec..69a1c490de84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 20bf4925dbec..9f5ac3b606e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Strategy +import org.apache.spark.sql.{execution, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchange @@ -256,6 +254,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the steaming reservoir sample operator. + */ + object ReservoirSampleStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReservoirSample(keys, child, k, true) => + StreamingReservoirSampleExec(keys, PlanLater(child), k) :: Nil + + case _ => Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. 
*/ @@ -408,6 +418,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil + case logical.ReservoirSample(keys, child, k, false) => + execution.ReservoirSampleExec(k, PlanLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScanExec(output, data) :: Nil case logical.LocalLimit(IntegerLiteral(limit), child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 87e90ed685cc..31653ae4939b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext} +import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.sql.types.LongType import org.apache.spark.util.ThreadUtils -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, SamplingUtils} /** Physical plan for Project. */ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) @@ -644,3 +644,18 @@ object SubqueryExec { private[execution] val executionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) } + +case class ReservoirSampleExec(k: Int, child: SparkPlan) extends UnaryExecNode { + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + protected override def doExecute(): RDD[InternalRow] = { + child.execute() + .mapPartitions(it => { + SamplingUtils.reservoirSampleAndCount(it, k)._1.iterator}) + .repartition(1) + .mapPartitions(it => { + SamplingUtils.reservoirSampleAndCount(it, k)._1.iterator}) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ffdcd9b19d05..b0938d202fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -46,6 +46,7 @@ class IncrementalExecution( sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: + sparkSession.sessionState.planner.ReservoirSampleStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies // Modified planner with stateful operations. 
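+  // For a streaming ReservoirSample node, planning happens in two steps: the
+  // ReservoirSampleStrategy registered above produces a StreamingReservoirSampleExec
+  // with no state id, and the rewrite rule below fills in the checkpoint location,
+  // operator id, batch id, watermark and output mode, mirroring how
+  // StateStoreSaveExec and StreamingDeduplicateExec are wired.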
@@ -83,7 +84,6 @@ class IncrementalExecution( StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StateStoreSaveExec( keys, Some(stateId), @@ -97,12 +97,16 @@ class IncrementalExecution( case StreamingDeduplicateExec(keys, child, None, None) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, Some(stateId), Some(currentEventTimeWatermark)) + case StreamingReservoirSampleExec(k, keys, child, None, None, None) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + StreamingReservoirSampleExec( + k, keys, child, Some(stateId), Some(currentEventTimeWatermark), Some(outputMode)) case MapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => val stateId = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d92529748b6a..b5a5f8a5a5cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.execution.streaming +import scala.util.Random + import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} @@ -29,8 +33,9 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{DataType, NullType, StructType} +import org.apache.spark.sql.types.{DataType, LongType, NullType, StructType} import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} /** Used to identify the state store for a given operator. */ @@ -116,8 +121,8 @@ case class StateStoreRestoreExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -397,3 +402,96 @@ object StreamingDeduplicateExec { private val EMPTY_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) } + +/** + * Physical operator for executing streaming Sampling. + * + * @param k random sample k elements. 
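+ *
+ * Each partition keeps its reservoir in the state store: the reserved key
+ * `NUM_RECORDS_IN_PARTITION` records how many rows the partition has seen across
+ * batches, and the remaining keys are string-encoded slot indices pointing at the
+ * sampled rows. An incoming row fills an empty slot while the reservoir is not full,
+ * and afterwards may replace a randomly chosen slot; the per-partition reservoirs are
+ * finally merged in a single partition and re-sampled down to k rows.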
+ */ +case class StreamingReservoirSampleExec( + keyExpressions: Seq[Attribute], + child: SparkPlan, + k: Int, + stateId: Option[OperatorStateId] = None, + eventTimeWatermark: Option[Long] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(keyExpressions) :: Nil + + private val enc = Encoders.STRING.asInstanceOf[ExpressionEncoder[String]] + private val NUM_RECORDS_IN_PARTITION = enc.toRow("NUM_RECORDS_IN_PARTITION") + .asInstanceOf[UnsafeRow] + + override protected def doExecute(): RDD[InternalRow] = { + metrics + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + + val numRecordsInPart = store.get(NUM_RECORDS_IN_PARTITION).map(value => { + value.get(0, LongType).asInstanceOf[Long] + }).getOrElse(0L) + + val seed = Random.nextLong() + val rand = new XORShiftRandom(seed) + var numSamples = numRecordsInPart + var count = 0 + + val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + baseIterator.foreach { r => + count += 1 + if (numSamples < k) { + numSamples += 1 + store.put(enc.toRow(numSamples.toString).asInstanceOf[UnsafeRow], + r.asInstanceOf[UnsafeRow]) + } else { + val randomIdx = (rand.nextDouble() * (numRecordsInPart + count)).toLong + if (randomIdx <= k) { + val replacementIdx = enc.toRow(randomIdx.toString).asInstanceOf[UnsafeRow] + store.put(replacementIdx, r.asInstanceOf[UnsafeRow]) + } + } + } + + val row = UnsafeProjection.create(Array[DataType](LongType)) + .apply(InternalRow.apply(numRecordsInPart + count)) + store.put(NUM_RECORDS_IN_PARTITION, row) + store.commit() + + outputMode match { + case Some(Complete) => + CompletionIterator[InternalRow, Iterator[InternalRow]]( + store.iterator().filter(kv => { + !kv._1.asInstanceOf[UnsafeRow].equals(NUM_RECORDS_IN_PARTITION) + }).map(kv => kv._2), {}) + case Some(Update) => + CompletionIterator[InternalRow, Iterator[InternalRow]]( + store.updates() + .filter(update => !update.key.equals(NUM_RECORDS_IN_PARTITION)) + .map(update => update.value), {}) + case _ => + throw new UnsupportedOperationException(s"Invalid output mode: $outputMode " + + s"for streaming sampling.") + } + }.repartition(1).mapPartitions(it => { + SamplingUtils.reservoirSampleAndCount(it, k)._1.iterator + }) + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala new file mode 100644 index 000000000000..78c3eb396537 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore + +class ReservoirSampleSuit extends StateStoreMetricsTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("streaming reservoir sample: reservoir size is larger than stream data size - update mode") { + val inputData = MemoryStream[String] + val result = inputData.toDS().reservoir(4) + + testStream(result, Update)( + AddData(inputData, "a", "b"), + CheckAnswer(Row("a"), Row("b")), + AddData(inputData, "a"), + CheckAnswer(Row("a"), Row("b"), Row("a")) + ) + } + + test("streaming reservoir sample: reservoir size is less than stream data size - update mode") { + val inputData = MemoryStream[String] + val result = inputData.toDS().reservoir(1) + + testStream(result, Update)( + AddData(inputData, "a", "a"), + CheckLastBatch(Row("a")), + AddData(inputData, "b", "b", "b", "b", "b", "b", "b", "b"), + CheckLastBatch(Row("b")) + ) + } + + test("streaming reservoir sample with aggregation - update mode") { + val inputData = MemoryStream[String] + val result = inputData.toDS().reservoir(3).groupBy("value").count() + + testStream(result, Update)( + AddData(inputData, "a"), + CheckAnswer(Row("a", 1)), + AddData(inputData, "b"), + CheckAnswer(Row("a", 1), Row("b", 1)) + ) + } + + test("streaming reservoir sample with watermark") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .reservoir(10) + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Update)( + AddData(inputData, (1 to 1).flatMap(_ => (11 to 15)): _*), + CheckLastBatch(11 to 15: _*), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(25), + AddData(inputData, 25), // Drop states less than watermark + CheckLastBatch(25), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + AddData(inputData, 45), // Advance watermark to 35 seconds + CheckLastBatch(45), + AddData(inputData, 25), // Should not emit anything as data less than watermark + CheckLastBatch() + ) + } + + test("streaming reservoir sample with aggregation - complete mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().select($"_1" as "key", $"_2" as "value") + .reservoir(3).groupBy("key").max("value") + + testStream(result, Complete)( + AddData(inputData, ("a", 1)), + CheckAnswer(Row("a", 1)), + AddData(inputData, ("b", 2)), + CheckAnswer(Row("a", 1), Row("b", 2)), + StopStream, + StartStream(), + AddData(inputData, ("a", 10)), + CheckAnswer(Row("a", 10), 
Row("b", 2)), + AddData(inputData, (1 to 10).map(e => ("c", 100)): _*), + CheckAnswer(Row("a", 10), Row("b", 2), Row("c", 100)) + ) + } + + test("batch reservoir sample") { + val df = spark.createDataset(Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 0)) + assert(df.reservoir(3).count() == 3, "") + } + + test("batch reservoir sample after aggregation") { + val df = spark.createDataset((1 to 10).map(e => (e, s"val_$e"))) + .select($"_1" as "key", $"_2" as "value") + .groupBy("value").count() + assert(df.reservoir(3).count() == 3, "") + } + + test("batch reservoir sample before aggregation") { + val df = spark.createDataset((1 to 10).map(e => (e, s"val_$e"))) + .select($"_1" as "key", $"_2" as "value") + .reservoir(3) + .groupBy("value").count() + assert(df.count() == 3, "") + } +} From 23738cf55c6ec6357e488eaaf808e7adecbf7fb0 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Thu, 2 Mar 2017 16:34:08 +0800 Subject: [PATCH 2/4] bug fix --- .../spark/util/random/SamplingUtils.scala | 41 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 15 +------ .../spark/sql/execution/SparkStrategies.scala | 2 +- .../execution/basicPhysicalOperators.scala | 6 ++- .../streaming/statefulOperators.scala | 26 +++++++++--- 5 files changed, 67 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index a7e0075debed..476e7f4a7b83 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -69,6 +69,47 @@ private[spark] object SamplingUtils { } } + /** + * Weight reservoir sampling implementation. + * + * @param input input size + * @param k reservoir size + * @param seed random seed + * @return samples + */ + def reservoirSampleWithWeight[T: ClassTag]( + input: Iterator[(T, Long)], + k: Int, + seed: Long = Random.nextLong()) + : Array[T] = { + val reservoir = new Array[T](k) + // Put the first k elements in the reservoir. + var i = 0 + while (i < k && input.hasNext) { + val item = input.next() + reservoir(i) = item._1 + i += 1 + } + + if (i < k) { + val trimReservoir = new Array[T](i) + System.arraycopy(reservoir, 0, trimReservoir, 0, i) + trimReservoir + } else { + var l = i.toLong + val rand = new XORShiftRandom(seed) + while (input.hasNext) { + val item = input.next() + l += 1 + val replacementIndex = Math.pow(rand.nextDouble(), 1 / item._2).toInt + if (replacementIndex < k) { + reservoir(replacementIndex.toInt) = item._1 + } + } + reservoir + } + } + /** * Returns a sampling rate that guarantees a sample of size greater than or equal to * sampleSizeLowerBound 99.99% of the time. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e800c4fb7488..3148d3599e36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2004,8 +2004,6 @@ class Dataset[T] private[sql]( val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) => - // It is possibly there are more than one columns with the same name, - // so we call filter instead of find. 
val cols = allColumns.filter(col => resolver(col.name, colName)) if (cols.isEmpty) { throw new AnalysisException( @@ -2027,19 +2025,8 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def reservoir(k: Int): Dataset[T] = withTypedPlan { - val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output - val groupCols = this.columns.toSet.toSeq.flatMap { (colName: String) => - // It is possibly there are more than one columns with the same name, - // so we call filter instead of find. - val cols = allColumns.filter(col => resolver(col.name, colName)) - if (cols.isEmpty) { - throw new AnalysisException( - s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") - } - cols - } - ReservoirSample(groupCols, logicalPlan, k, isStreaming) + ReservoirSample(allColumns, logicalPlan, k, isStreaming) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9f5ac3b606e2..b8934ccca2cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -255,7 +255,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Used to plan the steaming reservoir sample operator. + * Used to plan the streaming reservoir sample operator. */ object ReservoirSampleStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 31653ae4939b..ff510340eceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -653,9 +653,11 @@ case class ReservoirSampleExec(k: Int, child: SparkPlan) extends UnaryExecNode { protected override def doExecute(): RDD[InternalRow] = { child.execute() .mapPartitions(it => { - SamplingUtils.reservoirSampleAndCount(it, k)._1.iterator}) + val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, k) + sample.map((_, count)).toIterator + }) .repartition(1) .mapPartitions(it => { - SamplingUtils.reservoirSampleAndCount(it, k)._1.iterator}) + SamplingUtils.reservoirSampleWithWeight(it, k).iterator}) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b5a5f8a5a5cb..2344bea748ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -417,7 +417,6 @@ case class StreamingReservoirSampleExec( outputMode: Option[OutputMode] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(keyExpressions) :: Nil @@ -427,6 +426,8 @@ case class StreamingReservoirSampleExec( override protected def doExecute(): RDD[InternalRow] = { metrics + val fieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray + val withSumFieldTypes = (keyExpressions.map(_.dataType) ++ 
Seq(LongType)).toArray child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, @@ -466,9 +467,9 @@ case class StreamingReservoirSampleExec( } } - val row = UnsafeProjection.create(Array[DataType](LongType)) + val numRecordsTillNow = UnsafeProjection.create(Array[DataType](LongType)) .apply(InternalRow.apply(numRecordsInPart + count)) - store.put(NUM_RECORDS_IN_PARTITION, row) + store.put(NUM_RECORDS_IN_PARTITION, numRecordsTillNow) store.commit() outputMode match { @@ -476,18 +477,31 @@ case class StreamingReservoirSampleExec( CompletionIterator[InternalRow, Iterator[InternalRow]]( store.iterator().filter(kv => { !kv._1.asInstanceOf[UnsafeRow].equals(NUM_RECORDS_IN_PARTITION) - }).map(kv => kv._2), {}) + }).map(kv => { + UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq( + new JoinedRow(kv._2, numRecordsTillNow) + .toSeq(withSumFieldTypes))) + }), {}) case Some(Update) => CompletionIterator[InternalRow, Iterator[InternalRow]]( store.updates() .filter(update => !update.key.equals(NUM_RECORDS_IN_PARTITION)) - .map(update => update.value), {}) + .map(update => { + UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq( + new JoinedRow(update.value, numRecordsTillNow) + .toSeq(withSumFieldTypes))) + }), {}) case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode " + s"for streaming sampling.") } }.repartition(1).mapPartitions(it => { - SamplingUtils.reservoirSampleAndCount(it, k)._1.iterator + SamplingUtils.reservoirSampleWithWeight( + it.map(item => (item, item.getLong(keyExpressions.size))), k) + .map(row => + UnsafeProjection.create(fieldTypes) + .apply(InternalRow.fromSeq(row.toSeq(fieldTypes))) + ).iterator }) } From 288c124731901c0506833da05119fc131d1bf20e Mon Sep 17 00:00:00 2001 From: uncleGen Date: Fri, 3 Mar 2017 00:37:19 +0800 Subject: [PATCH 3/4] update --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3148d3599e36..e7d22a6a9fcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2004,6 +2004,8 @@ class Dataset[T] private[sql]( val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) => + // It is possibly there are more than one columns with the same name, + // so we call filter instead of find. 
val cols = allColumns.filter(col => resolver(col.name, colName)) if (cols.isEmpty) { throw new AnalysisException( From 02d44aa06f025dc1d69a7abbcf59691ce7ee0e4e Mon Sep 17 00:00:00 2001 From: uncleGen Date: Mon, 20 Mar 2017 10:55:24 +0800 Subject: [PATCH 4/4] bug fix --- .../plans/logical/basicLogicalOperators.scala | 2 +- .../main/scala/org/apache/spark/sql/Dataset.scala | 4 ++-- .../spark/sql/execution/SparkStrategies.scala | 8 ++++---- .../sql/execution/basicPhysicalOperators.scala | 6 +++--- .../execution/streaming/IncrementalExecution.scala | 13 +++++++++---- .../sql/execution/streaming/statefulOperators.scala | 10 +++++----- 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ea1145018b74..e08539607cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -848,7 +848,7 @@ case class Sample( case class ReservoirSample( keys: Seq[Attribute], child: LogicalPlan, - k: Int, + reservoirSize: Int, streaming: Boolean = false) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8b2a75b8ecf4..0957b761c9f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2029,9 +2029,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reservoir(k: Int): Dataset[T] = withTypedPlan { + def reservoir(reservoirSize: Int): Dataset[T] = withTypedPlan { val allColumns = queryExecution.analyzed.output - ReservoirSample(allColumns, logicalPlan, k, isStreaming) + ReservoirSample(allColumns, logicalPlan, reservoirSize, isStreaming) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6ec98b696dc4..c0de4b1c1e32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -259,8 +259,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object ReservoirSampleStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ReservoirSample(keys, child, k, true) => - StreamingReservoirSampleExec(keys, PlanLater(child), k) :: Nil + case ReservoirSample(keys, child, reservoirSize, true) => + StreamingReservoirSampleExec(keys, PlanLater(child), reservoirSize) :: Nil case _ => Nil } @@ -421,8 +421,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil - case logical.ReservoirSample(keys, child, k, false) => - execution.ReservoirSampleExec(k, PlanLater(child)) :: Nil + case logical.ReservoirSample(keys, child, reservoirSize, false) => + execution.ReservoirSampleExec(reservoirSize, PlanLater(child)) :: Nil case 
logical.LocalRelation(output, data) => LocalTableScanExec(output, data) :: Nil case logical.LocalLimit(IntegerLiteral(limit), child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index fb3347170d59..cb6bb7675f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -658,7 +658,7 @@ object SubqueryExec { ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) } -case class ReservoirSampleExec(k: Int, child: SparkPlan) extends UnaryExecNode { +case class ReservoirSampleExec(reservoirSize: Int, child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning @@ -666,11 +666,11 @@ case class ReservoirSampleExec(k: Int, child: SparkPlan) extends UnaryExecNode { protected override def doExecute(): RDD[InternalRow] = { child.execute() .mapPartitions(it => { - val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, k) + val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, reservoirSize) sample.map((_, count)).toIterator }) .repartition(1) .mapPartitions(it => { - SamplingUtils.reservoirSampleWithWeight(it, k).iterator}) + SamplingUtils.reservoirSampleWithWeight(it, reservoirSize).iterator}) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 5ae47012bca9..4ebb6b1da88a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -103,13 +103,18 @@ class IncrementalExecution( child, Some(stateId), Some(offsetSeqMetadata.batchWatermarkMs)) - - case StreamingReservoirSampleExec(k, keys, child, None, None, None) => + + case StreamingReservoirSampleExec(keys, child, reservoirSize, None, None, None) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) StreamingReservoirSampleExec( - k, keys, child, Some(stateId), Some(currentEventTimeWatermark), Some(outputMode)) - + keys, + child, + reservoirSize, + Some(stateId), + Some(offsetSeqMetadata.batchWatermarkMs), + Some(outputMode)) + case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 41a59be66d58..85ee9907af71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -330,12 +330,12 @@ object StreamingDeduplicateExec { /** * Physical operator for executing streaming Sampling. * - * @param k random sample k elements. + * @param reservoirSize number of random sample elements. 
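+ *
+ * The per-row update follows the classic reservoir sampling scheme; as a reference
+ * sketch with illustrative names (`seen`, `reservoir`, `rng`), not the exact
+ * state-store code below:
+ * {{{
+ *   // seen = rows observed so far; reservoir has reservoirSize slots
+ *   if (seen < reservoirSize) {
+ *     reservoir(seen) = row
+ *   } else {
+ *     val j = rng.nextInt(seen + 1)   // uniform index in [0, seen]
+ *     if (j < reservoirSize) reservoir(j) = row
+ *   }
+ *   seen += 1
+ * }}}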
*/ case class StreamingReservoirSampleExec( keyExpressions: Seq[Attribute], child: SparkPlan, - k: Int, + reservoirSize: Int, stateId: Option[OperatorStateId] = None, eventTimeWatermark: Option[Long] = None, outputMode: Option[OutputMode] = None) @@ -378,13 +378,13 @@ case class StreamingReservoirSampleExec( baseIterator.foreach { r => count += 1 - if (numSamples < k) { + if (numSamples < reservoirSize) { numSamples += 1 store.put(enc.toRow(numSamples.toString).asInstanceOf[UnsafeRow], r.asInstanceOf[UnsafeRow]) } else { val randomIdx = (rand.nextDouble() * (numRecordsInPart + count)).toLong - if (randomIdx <= k) { + if (randomIdx <= reservoirSize) { val replacementIdx = enc.toRow(randomIdx.toString).asInstanceOf[UnsafeRow] store.put(replacementIdx, r.asInstanceOf[UnsafeRow]) } @@ -421,7 +421,7 @@ case class StreamingReservoirSampleExec( } }.repartition(1).mapPartitions(it => { SamplingUtils.reservoirSampleWithWeight( - it.map(item => (item, item.getLong(keyExpressions.size))), k) + it.map(item => (item, item.getLong(keyExpressions.size))), reservoirSize) .map(row => UnsafeProjection.create(fieldTypes) .apply(InternalRow.fromSeq(row.toSeq(fieldTypes)))