diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index a7e0075debed..476e7f4a7b83 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -69,6 +69,53 @@ private[spark] object SamplingUtils {
     }
   }
 
+  /**
+   * Weighted reservoir sampling implementation (the A-Res scheme): each element draws
+   * the key rand^(1 / weight) and the k elements with the largest keys are kept.
+   *
+   * @param input input data as (element, weight) pairs
+   * @param k reservoir size
+   * @param seed random seed
+   * @return samples
+   */
+  def reservoirSampleWithWeight[T: ClassTag](
+      input: Iterator[(T, Long)],
+      k: Int,
+      seed: Long = Random.nextLong())
+    : Array[T] = {
+    val reservoir = new Array[T](k)
+    val keys = new Array[Double](k)
+    val rand = new XORShiftRandom(seed)
+    // Put the first k elements in the reservoir, each with its key rand^(1 / weight).
+    var i = 0
+    while (i < k && input.hasNext) {
+      val item = input.next()
+      reservoir(i) = item._1
+      keys(i) = Math.pow(rand.nextDouble(), 1.0 / item._2)
+      i += 1
+    }
+
+    if (i < k) {
+      // The input size is less than k, return a trimmed copy of the reservoir.
+      val trimReservoir = new Array[T](i)
+      System.arraycopy(reservoir, 0, trimReservoir, 0, i)
+      trimReservoir
+    } else {
+      // Replace the element holding the smallest key whenever a new element draws a larger one.
+      var minIndex = keys.zipWithIndex.minBy(_._1)._2
+      while (input.hasNext) {
+        val item = input.next()
+        val key = Math.pow(rand.nextDouble(), 1.0 / item._2)
+        if (key > keys(minIndex)) {
+          reservoir(minIndex) = item._1
+          keys(minIndex) = key
+          minIndex = keys.zipWithIndex.minBy(_._1)._2
+        }
+      }
+      reservoir
+    }
+  }
+
   /**
    * Returns a sampling rate that guarantees a sample of size greater than or equal to
    * sampleSizeLowerBound 99.99% of the time.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 5cbf263d1ce4..e08539607cc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -842,6 +842,19 @@ case class Sample(
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
 }
 
+/**
+ * A logical plan for `reservoir`.
+ */
+case class ReservoirSample(
+    keys: Seq[Attribute],
+    child: LogicalPlan,
+    reservoirSize: Int,
+    streaming: Boolean = false)
+  extends UnaryNode {
+  override def maxRows: Option[Long] = child.maxRows
+  override def output: Seq[Attribute] = child.output
+}
+
 /**
  * Returns a new logical plan that dedups input rows.
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b76938..e14d1d78a483 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 520663f62440..0957b761c9f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2019,6 +2019,21 @@ class Dataset[T] private[sql]( Deduplicate(groupCols, logicalPlan, isStreaming) } + /** + * :: Experimental :: + * (Scala-specific) Reservoir sampling implementation. + * + * @todo move this into sample operator. + * @group typedrel + * @since 2.0.0 + */ + @Experimental + @InterfaceStability.Evolving + def reservoir(reservoirSize: Int): Dataset[T] = withTypedPlan { + val allColumns = queryExecution.analyzed.output + ReservoirSample(allColumns, logicalPlan, reservoirSize, isStreaming) + } + /** * Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e58e8ce3d5f..c0de4b1c1e32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Strategy +import org.apache.spark.sql.{execution, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchange @@ -256,6 +254,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming reservoir sample operator. 
+ */ + object ReservoirSampleStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReservoirSample(keys, child, reservoirSize, true) => + StreamingReservoirSampleExec(keys, PlanLater(child), reservoirSize) :: Nil + + case _ => Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ @@ -411,6 +421,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil + case logical.ReservoirSample(keys, child, reservoirSize, false) => + execution.ReservoirSampleExec(reservoirSize, PlanLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScanExec(output, data) :: Nil case logical.LocalLimit(IntegerLiteral(limit), child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d876688a8aab..cb6bb7675f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext} +import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.sql.types.LongType import org.apache.spark.util.ThreadUtils -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, SamplingUtils} /** Physical plan for Project. 
*/ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) @@ -657,3 +657,20 @@ object SubqueryExec { private[execution] val executionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) } + +case class ReservoirSampleExec(reservoirSize: Int, child: SparkPlan) extends UnaryExecNode { + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + protected override def doExecute(): RDD[InternalRow] = { + child.execute() + .mapPartitions(it => { + val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, reservoirSize) + sample.map((_, count)).toIterator + }) + .repartition(1) + .mapPartitions(it => { + SamplingUtils.reservoirSampleWithWeight(it, reservoirSize).iterator}) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a934c75a0245..4ebb6b1da88a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -46,6 +46,7 @@ class IncrementalExecution( sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: + sparkSession.sessionState.planner.ReservoirSampleStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies // Modified planner with stateful operations. @@ -83,7 +84,6 @@ class IncrementalExecution( StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StateStoreSaveExec( keys, Some(stateId), @@ -98,13 +98,23 @@ class IncrementalExecution( case StreamingDeduplicateExec(keys, child, None, None) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, Some(stateId), Some(offsetSeqMetadata.batchWatermarkMs)) + case StreamingReservoirSampleExec(keys, child, reservoirSize, None, None, None) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + StreamingReservoirSampleExec( + keys, + child, + reservoirSize, + Some(stateId), + Some(offsetSeqMetadata.batchWatermarkMs), + Some(outputMode)) + case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6d2de441eb44..85ee9907af71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution.streaming +import scala.util.Random + import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} @@ -32,6 +35,7 @@ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} /** Used to identify the state store for a given operator. */ @@ -127,8 +131,8 @@ case class StateStoreRestoreExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -322,3 +326,110 @@ object StreamingDeduplicateExec { private val EMPTY_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) } + +/** + * Physical operator for executing streaming Sampling. + * + * @param reservoirSize number of random sample elements. + */ +case class StreamingReservoirSampleExec( + keyExpressions: Seq[Attribute], + child: SparkPlan, + reservoirSize: Int, + stateId: Option[OperatorStateId] = None, + eventTimeWatermark: Option[Long] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(keyExpressions) :: Nil + + private val enc = Encoders.STRING.asInstanceOf[ExpressionEncoder[String]] + private val NUM_RECORDS_IN_PARTITION = enc.toRow("NUM_RECORDS_IN_PARTITION") + .asInstanceOf[UnsafeRow] + + override protected def doExecute(): RDD[InternalRow] = { + metrics + val fieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray + val withSumFieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + + val numRecordsInPart = store.get(NUM_RECORDS_IN_PARTITION).map(value => { + value.get(0, LongType).asInstanceOf[Long] + }).getOrElse(0L) + + val seed = Random.nextLong() + val rand = new XORShiftRandom(seed) + var numSamples = numRecordsInPart + var count = 0 + + val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + baseIterator.foreach { r => + count += 1 + if (numSamples < reservoirSize) { + numSamples += 1 + store.put(enc.toRow(numSamples.toString).asInstanceOf[UnsafeRow], + r.asInstanceOf[UnsafeRow]) + } else { + val randomIdx = (rand.nextDouble() * (numRecordsInPart + count)).toLong + if (randomIdx <= reservoirSize) { + val replacementIdx = enc.toRow(randomIdx.toString).asInstanceOf[UnsafeRow] + store.put(replacementIdx, r.asInstanceOf[UnsafeRow]) + } + } + } + + val numRecordsTillNow = UnsafeProjection.create(Array[DataType](LongType)) + .apply(InternalRow.apply(numRecordsInPart + count)) + store.put(NUM_RECORDS_IN_PARTITION, numRecordsTillNow) + store.commit() + + outputMode match { + case Some(Complete) => + CompletionIterator[InternalRow, Iterator[InternalRow]]( + store.iterator().filter(kv => { + !kv._1.asInstanceOf[UnsafeRow].equals(NUM_RECORDS_IN_PARTITION) + }).map(kv => { + 
UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq( + new JoinedRow(kv._2, numRecordsTillNow) + .toSeq(withSumFieldTypes))) + }), {}) + case Some(Update) => + CompletionIterator[InternalRow, Iterator[InternalRow]]( + store.updates() + .filter(update => !update.key.equals(NUM_RECORDS_IN_PARTITION)) + .map(update => { + UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq( + new JoinedRow(update.value, numRecordsTillNow) + .toSeq(withSumFieldTypes))) + }), {}) + case _ => + throw new UnsupportedOperationException(s"Invalid output mode: $outputMode " + + s"for streaming sampling.") + } + }.repartition(1).mapPartitions(it => { + SamplingUtils.reservoirSampleWithWeight( + it.map(item => (item, item.getLong(keyExpressions.size))), reservoirSize) + .map(row => + UnsafeProjection.create(fieldTypes) + .apply(InternalRow.fromSeq(row.toSeq(fieldTypes))) + ).iterator + }) + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala new file mode 100644 index 000000000000..78c3eb396537 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReservoirSampleSuit.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore + +class ReservoirSampleSuit extends StateStoreMetricsTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("streaming reservoir sample: reservoir size is larger than stream data size - update mode") { + val inputData = MemoryStream[String] + val result = inputData.toDS().reservoir(4) + + testStream(result, Update)( + AddData(inputData, "a", "b"), + CheckAnswer(Row("a"), Row("b")), + AddData(inputData, "a"), + CheckAnswer(Row("a"), Row("b"), Row("a")) + ) + } + + test("streaming reservoir sample: reservoir size is less than stream data size - update mode") { + val inputData = MemoryStream[String] + val result = inputData.toDS().reservoir(1) + + testStream(result, Update)( + AddData(inputData, "a", "a"), + CheckLastBatch(Row("a")), + AddData(inputData, "b", "b", "b", "b", "b", "b", "b", "b"), + CheckLastBatch(Row("b")) + ) + } + + test("streaming reservoir sample with aggregation - update mode") { + val inputData = MemoryStream[String] + val result = inputData.toDS().reservoir(3).groupBy("value").count() + + testStream(result, Update)( + AddData(inputData, "a"), + CheckAnswer(Row("a", 1)), + AddData(inputData, "b"), + CheckAnswer(Row("a", 1), Row("b", 1)) + ) + } + + test("streaming reservoir sample with watermark") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .reservoir(10) + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Update)( + AddData(inputData, (1 to 1).flatMap(_ => (11 to 15)): _*), + CheckLastBatch(11 to 15: _*), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(25), + AddData(inputData, 25), // Drop states less than watermark + CheckLastBatch(25), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + AddData(inputData, 45), // Advance watermark to 35 seconds + CheckLastBatch(45), + AddData(inputData, 25), // Should not emit anything as data less than watermark + CheckLastBatch() + ) + } + + test("streaming reservoir sample with aggregation - complete mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().select($"_1" as "key", $"_2" as "value") + .reservoir(3).groupBy("key").max("value") + + testStream(result, Complete)( + AddData(inputData, ("a", 1)), + CheckAnswer(Row("a", 1)), + AddData(inputData, ("b", 2)), + CheckAnswer(Row("a", 1), Row("b", 2)), + StopStream, + StartStream(), + AddData(inputData, ("a", 10)), + CheckAnswer(Row("a", 10), Row("b", 2)), + AddData(inputData, (1 to 10).map(e => ("c", 100)): _*), + CheckAnswer(Row("a", 10), Row("b", 2), Row("c", 100)) + ) + } + + test("batch reservoir sample") { + val df = spark.createDataset(Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 0)) + assert(df.reservoir(3).count() == 3, "") + } + + test("batch reservoir sample after aggregation") { + val df = spark.createDataset((1 to 10).map(e => (e, s"val_$e"))) + .select($"_1" as "key", $"_2" as "value") + .groupBy("value").count() + assert(df.reservoir(3).count() == 3, "") + } + + test("batch reservoir sample before aggregation") { + val 
df = spark.createDataset((1 to 10).map(e => (e, s"val_$e"))) + .select($"_1" as "key", $"_2" as "value") + .reservoir(3) + .groupBy("value").count() + assert(df.count() == 3, "") + } +}
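
For readers following the batch path, the sketch below (plain Scala, hypothetical names, no Spark dependencies) illustrates the two-phase strategy that ReservoirSampleExec appears to use: take an ordinary reservoir sample of size k in each partition, tag every sampled element with its partition's row count as a weight, then merge the tagged samples with weighted reservoir sampling in which each element draws the key rand^(1 / weight) and the k largest keys win. It is only an illustration of the intended algorithm under those assumptions, not code from the patch.

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object ReservoirSketch {

  // Unweighted reservoir sample of size k that also returns the number of elements seen,
  // mirroring what SamplingUtils.reservoirSampleAndCount provides per partition.
  def sampleAndCount[T](input: Iterator[T], k: Int, rand: Random): (Vector[T], Long) = {
    val reservoir = ArrayBuffer.empty[T]
    var n = 0L
    input.foreach { elem =>
      n += 1
      if (reservoir.size < k) {
        reservoir += elem
      } else {
        // Replace a slot with probability k / n (classic Algorithm R).
        val j = (rand.nextDouble() * n).toLong
        if (j < k) reservoir(j.toInt) = elem
      }
    }
    (reservoir.toVector, n)
  }

  // Weighted merge in the A-Res style: each element draws the key rand^(1 / weight)
  // and the k elements with the largest keys form the final sample.
  def weightedMerge[T](weighted: Seq[(T, Long)], k: Int, rand: Random): Seq[T] =
    weighted
      .map { case (elem, weight) => (math.pow(rand.nextDouble(), 1.0 / weight), elem) }
      .sortBy(-_._1)
      .take(k)
      .map(_._2)

  def main(args: Array[String]): Unit = {
    val rand = new Random(42L)
    val k = 5
    // Two "partitions" of very different sizes; the larger one should dominate the merge.
    val partitions = Seq((1 to 1000).iterator, (1001 to 1100).iterator)
    val tagged = partitions
      .map(it => sampleAndCount(it, k, rand))
      .flatMap { case (sample, count) => sample.map((_, count)) }
    println(weightedMerge(tagged, k, rand))
  }
}

Sorting the merged candidates is acceptable here because only numPartitions * k elements survive the first phase; the per-partition pass stays O(n) for each partition.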