diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 0ebdb5a27308..97c9fa1b4584 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -498,11 +498,7 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @export #' @examples #' \dontrun{ -#' # nolint start -#' # An example "path/to/file" can be -#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") -#' # nolint end -#' text <- read.df("path/to/file", source = "libsvm") +#' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") #' model <- spark.lda(data = text, optimizer = "em") #' #' # get a summary of the model diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index 96251b2c7c19..dfcb45a1b66c 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -27,6 +27,10 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' FP-growth #' #' A parallel FP-growth algorithm to mine frequent itemsets. +#' \code{spark.fpGrowth} fits an FP-growth model on a SparkDataFrame. Users can call +#' \code{spark.freqItemsets} to get frequent itemsets, \code{spark.associationRules} to get +#' association rules, \code{predict} to make predictions on new data based on generated association +#' rules, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' For more details, see #' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ #' FP-growth}. diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js index 1782b4f209c0..b5c43e5788bc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/log-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -51,13 +51,26 @@ function noNewAlert() { window.setTimeout(function () {alert.css("display", "none");}, 4000); } + +function getRESTEndPoint() { + // If the worker is served from the master through a proxy (see doc on spark.ui.reverseProxy), + // we need to retain the leading ../proxy/<workerId>/ part of the URL when making REST requests. + // Similar logic is contained in executorspage.js function createRESTEndPoint.
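+  // Illustrative walkthrough (hypothetical worker id): with document.baseURI
+  // 'http://master:8080/proxy/worker-20170101-0001/logPage/', split('/') yields
+  // ['http:', '', 'master:8080', 'proxy', 'worker-20170101-0001', ...], so indexOf('proxy') is 3
+  // and slice(0, 5).join('/') rebuilds 'http://master:8080/proxy/worker-20170101-0001';
+  // the endpoint used below then becomes '.../proxy/worker-20170101-0001/log'.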
+ var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + return words.slice(0, ind + 2).join('/') + "/log"; + } + return "/log" +} + function loadMore() { var offset = Math.max(startByte - byteLength, 0); var moreByteLength = Math.min(byteLength, startByte); $.ajax({ type: "GET", - url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + url: getRESTEndPoint() + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, success: function (data) { var oldHeight = $(".log-content")[0].scrollHeight; var newlineIndex = data.indexOf('\n'); @@ -83,14 +96,14 @@ function loadMore() { function loadNew() { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=0", + url: getRESTEndPoint() + baseParams + "&byteLength=0", success: function (data) { var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); var newDataLen = dataInfo[2] - totalLogLength; if (newDataLen != 0) { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=" + newDataLen, + url: getRESTEndPoint() + baseParams + "&byteLength=" + newDataLen, success: function (data) { var newlineIndex = data.indexOf('\n'); var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b9921138cc6c..e5d60a7ef098 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -147,6 +147,14 @@ private[spark] class TypedConfigBuilder[T]( } } + /** Creates a [[ConfigEntry]] with a function to determine the default value */ + def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_ (entry)) + entry + } + /** * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a * [[String]] and must be a valid value for the entry. 
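(Reviewer note) A minimal sketch of how the new builder method is meant to be used, mirroring the SQLConf change later in this patch; the entry name is the real one being migrated, the surrounding code is illustrative:

    // The default is evaluated at read time, so it tracks the current JVM default time zone
    val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone")
      .stringConf
      .createWithDefaultFunction(() => java.util.TimeZone.getDefault.getID)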
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index 4f3e42bb3c94..e86712e84d6a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -78,7 +78,24 @@ private class ConfigEntryWithDefault[T] ( def readFrom(reader: ConfigReader): T = { reader.get(key).map(valueConverter).getOrElse(_defaultValue) } +} + +private class ConfigEntryWithDefaultFunction[T] ( + key: String, + _defaultFunction: () => T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + + override def defaultValue: Option[T] = Some(_defaultFunction()) + override def defaultValueString: String = stringConverter(_defaultFunction()) + + def readFrom(reader: ConfigReader): T = { + reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + } } private class ConfigEntryWithDefaultString[T] ( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index c1f25114371f..181465bdf960 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -42,7 +42,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" val stageId = Option(request.getParameter("id")).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id)) + sc.foreach(_.cancelStage(id, "killed via the Web UI")) // Do a quick pause here to give Spark time to kill the stage so it shows up as // killed after the refresh. Note that this will block the serving thread so the // time should be limited in duration. diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 3ff7e84d73bd..e2ba0d2a53d0 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -251,4 +251,13 @@ class ConfigEntrySuite extends SparkFunSuite { .createWithDefault(null) testEntryRef(nullConf, ref(nullConf)) } + + test("conf entry : default function") { + var data = 0 + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("intval")).intConf.createWithDefaultFunction(() => data) + assert(conf.get(iConf) === 0) + data = 2 + assert(conf.get(iConf) === 2) + } } diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index b5cf9f164498..37a1d6189a42 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Structured Streaming Programming Guide [Alpha] +displayTitle: Structured Streaming Programming Guide [Experimental] title: Structured Streaming Programming Guide --- @@ -871,6 +871,65 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat +### Streaming Deduplication +You can deduplicate records in data streams using a unique identifier in the events. This is exactly the same as deduplication on static DataFrames using a unique identifier column.
The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. + +- *With watermark* - If there is an upper bound on how late a duplicate record may arrive, then you can define a watermark on an event time column and deduplicate using both the guid and the event time columns. The query will use the watermark to remove old state data from past records that are not expected to get any duplicates any more. This bounds the amount of state the query has to maintain. + +- *Without watermark* - Since there are no bounds on when a duplicate record may arrive, the query stores the data from all the past records as state. + +<div class="codetabs">
+<div data-lang="scala"  markdown="1">
+ +{% highlight scala %} +val streamingDf = spark.readStream. ... // columns: guid, eventTime, ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid") + +// With watermark using guid and eventTime columns +streamingDf + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("guid", "eventTime") +{% endhighlight %} + +</div>
+<div data-lang="java"  markdown="1">
+ +{% highlight java %} +Dataset<Row> streamingDf = spark.readStream. ...; // columns: guid, eventTime, ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid"); + +// With watermark using guid and eventTime columns +streamingDf + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("guid", "eventTime"); +{% endhighlight %} + + +</div>
+<div data-lang="python"  markdown="1">
+ +{% highlight python %} +streamingDf = spark.readStream. ... + +# Without watermark using guid column +streamingDf.dropDuplicates("guid") + +# With watermark using guid and eventTime columns +streamingDf \ + .withWatermark("eventTime", "10 seconds") \ + .dropDuplicates("guid", "eventTime") +{% endhighlight %} + +</div>
+</div>
+ +### Arbitrary Stateful Operations +Many use cases require more advanced stateful operations than aggregations. For example, in many use cases, you have to track sessions from data streams of events. To do such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)). + ### Unsupported Operations There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets. Some of them are as follows. @@ -891,7 +950,7 @@ Some of them are as follows. + Right outer join with a streaming Dataset on the left is not supported -- Any kind of joins between two streaming Datasets are not yet supported. +- Any kind of joins between two streaming Datasets is not yet supported. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -951,13 +1010,6 @@ Here is the compatibility matrix. Supported Output Modes Notes - - Queries without aggregation - Append, Update - - Complete mode not supported as it is infeasible to keep all data in the Result Table. - - Queries with aggregation Aggregation on event-time with watermark @@ -986,6 +1038,33 @@ Here is the compatibility matrix. this mode. + + Queries with mapGroupsWithState + Update + + + + Queries with flatMapGroupsWithState + Append operation mode + Append + + Aggregations are allowed after flatMapGroupsWithState. + + + + Update operation mode + Update + + Aggregations not allowed after flatMapGroupsWithState. + + + + Other queries + Append, Update + + Complete mode not supported as it is infeasible to keep all unaggregated data in the Result Table. + + @@ -994,6 +1073,7 @@ Here is the compatibility matrix. + #### Output Sinks There are a few types of built-in output sinks. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java new file mode 100644 index 000000000000..da3a5dfe8628 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapGroupsWithStateFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.streaming.GroupState; +import org.apache.spark.sql.streaming.GroupStateTimeout; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.io.Serializable; +import java.sql.Timestamp; +import java.util.*; + +import scala.Tuple2; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + *

+ * Usage: JavaStructuredSessionization <hostname> <port> + * <hostname> and <port> describe the TCP server that Structured Streaming + * would connect to receive data. + *

+ * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.JavaStructuredSessionization + * localhost 9999` + */ +public final class JavaStructuredSessionization { + + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaStructuredSessionization <hostname> <port>"); + System.exit(1); + } + + String host = args[0]; + int port = Integer.parseInt(args[1]); + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredSessionization") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to host:port + Dataset<Row> lines = spark + .readStream() + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load(); + + FlatMapFunction<LineWithTimestamp, Event> linesToEvents = + new FlatMapFunction<LineWithTimestamp, Event>() { + @Override + public Iterator<Event> call(LineWithTimestamp lineWithTimestamp) throws Exception { + ArrayList<Event> eventList = new ArrayList<>(); + for (String word : lineWithTimestamp.getLine().split(" ")) { + eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); + } + System.out.println( + "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size()); + return eventList.iterator(); + } + }; + + // Split the lines into words, treat words as sessionId of events + Dataset<Event> events = lines + .withColumnRenamed("value", "line") + .as(Encoders.bean(LineWithTimestamp.class)) + .flatMap(linesToEvents, Encoders.bean(Event.class)); + + // Sessionize the events. Track number of events, start and end timestamps of session, + // and report session updates. + // + // Step 1: Define the state update function + MapGroupsWithStateFunction<String, Event, SessionInfo, SessionUpdate> stateUpdateFunc = + new MapGroupsWithStateFunction<String, Event, SessionInfo, SessionUpdate>() { + @Override public SessionUpdate call( + String sessionId, Iterator<Event> events, GroupState<SessionInfo> state) + throws Exception { + // If timed out, then remove session and send final update + if (state.hasTimedOut()) { + SessionUpdate finalUpdate = new SessionUpdate( + sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true); + state.remove(); + return finalUpdate; + + } else { + // Find max and min timestamps in events + long maxTimestampMs = Long.MIN_VALUE; + long minTimestampMs = Long.MAX_VALUE; + int numNewEvents = 0; + while (events.hasNext()) { + Event e = events.next(); + long timestampMs = e.getTimestamp().getTime(); + maxTimestampMs = Math.max(timestampMs, maxTimestampMs); + minTimestampMs = Math.min(timestampMs, minTimestampMs); + numNewEvents += 1; + } + SessionInfo updatedSession = new SessionInfo(); + + // Update start and end timestamps in session + if (state.exists()) { + SessionInfo oldSession = state.get(); + updatedSession.setNumEvents(oldSession.numEvents + numNewEvents); + updatedSession.setStartTimestampMs(oldSession.startTimestampMs); + updatedSession.setEndTimestampMs(Math.max(oldSession.endTimestampMs, maxTimestampMs)); + } else { + updatedSession.setNumEvents(numNewEvents); + updatedSession.setStartTimestampMs(minTimestampMs); + updatedSession.setEndTimestampMs(maxTimestampMs); + } + state.update(updatedSession); + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds"); + return new SessionUpdate( + sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false); + } + } + }; + + // Step 2: Apply the state update function to the events streaming Dataset grouped by sessionId +
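+    // (Reviewer note) ProcessingTimeTimeout below means state.setTimeoutDuration() in the
+    // function above is interpreted against processing time; a timed-out key is invoked with
+    // an empty event iterator and state.hasTimedOut() returning true.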
Dataset<SessionUpdate> sessionUpdates = events + .groupByKey( + new MapFunction<Event, String>() { + @Override public String call(Event event) throws Exception { + return event.getSessionId(); + } + }, Encoders.STRING()) + .mapGroupsWithState( + stateUpdateFunc, + Encoders.bean(SessionInfo.class), + Encoders.bean(SessionUpdate.class), + GroupStateTimeout.ProcessingTimeTimeout()); + + // Start running the query that prints the session updates to the console + StreamingQuery query = sessionUpdates + .writeStream() + .outputMode("update") + .format("console") + .start(); + + query.awaitTermination(); + } + + /** + * User-defined data type representing the raw lines with timestamps. + */ + public static class LineWithTimestamp implements Serializable { + private String line; + private Timestamp timestamp; + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getLine() { return line; } + public void setLine(String line) { this.line = line; } + } + + /** + * User-defined data type representing the input events + */ + public static class Event implements Serializable { + private String sessionId; + private Timestamp timestamp; + + public Event() { } + public Event(String sessionId, Timestamp timestamp) { + this.sessionId = sessionId; + this.timestamp = timestamp; + } + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getSessionId() { return sessionId; } + public void setSessionId(String sessionId) { this.sessionId = sessionId; } + } + + /** + * User-defined data type for storing session information as state in mapGroupsWithState. + */ + public static class SessionInfo implements Serializable { + private int numEvents = 0; + private long startTimestampMs = -1; + private long endTimestampMs = -1; + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public long getStartTimestampMs() { return startTimestampMs; } + public void setStartTimestampMs(long startTimestampMs) { + this.startTimestampMs = startTimestampMs; + } + + public long getEndTimestampMs() { return endTimestampMs; } + public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } + + public long getDurationMs() { return endTimestampMs - startTimestampMs; } + @Override public String toString() { + return "SessionInfo(numEvents = " + numEvents + + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; + } + } + + /** + * User-defined data type representing the update information returned by mapGroupsWithState.
+ */ + public static class SessionUpdate implements Serializable { + private String id; + private long durationMs; + private int numEvents; + private boolean expired; + + public SessionUpdate() { } + + public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) { + this.id = id; + this.durationMs = durationMs; + this.numEvents = numEvents; + this.expired = expired; + } + + public String getId() { return id; } + public void setId(String id) { this.id = id; } + + public long getDurationMs() { return durationMs; } + public void setDurationMs(long durationMs) { this.durationMs = durationMs; } + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public boolean isExpired() { return expired; } + public void setExpired(boolean expired) { this.expired = expired; } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala new file mode 100644 index 000000000000..2ce792c00849 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import java.sql.Timestamp + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.streaming._ + + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + * + * Usage: StructuredSessionization <hostname> <port> + * <hostname> and <port> describe the TCP server that Structured Streaming + * would connect to receive data.
+ * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.StructuredSessionization + * localhost 9999` + */ +object StructuredSessionization { + + def main(args: Array[String]): Unit = { + if (args.length < 2) { + System.err.println("Usage: StructuredSessionization <hostname> <port>") + System.exit(1) + } + + val host = args(0) + val port = args(1).toInt + + val spark = SparkSession + .builder + .appName("StructuredSessionization") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to host:port + val lines = spark.readStream + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load() + + // Split the lines into words, treat words as sessionId of events + val events = lines + .as[(String, Timestamp)] + .flatMap { case (line, timestamp) => + line.split(" ").map(word => Event(sessionId = word, timestamp)) + } + + // Sessionize the events. Track number of events, start and end timestamps of session, + // and report session updates. + val sessionUpdates = events + .groupByKey(event => event.sessionId) + .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { + + case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => + + // If timed out, then remove session and send final update + if (state.hasTimedOut) { + val finalUpdate = + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) + state.remove() + finalUpdate + } else { + // Update start and end timestamps in session + val timestamps = events.map(_.timestamp.getTime).toSeq + val updatedSession = if (state.exists) { + val oldSession = state.get + SessionInfo( + oldSession.numEvents + timestamps.size, + oldSession.startTimestampMs, + math.max(oldSession.endTimestampMs, timestamps.max)) + } else { + SessionInfo(timestamps.size, timestamps.min, timestamps.max) + } + state.update(updatedSession) + + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds") + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) + } + } + + // Start running the query that prints the session updates to the console + val query = sessionUpdates + .writeStream + .outputMode("update") + .format("console") + .start() + + query.awaitTermination() + } +} +/** User-defined data type representing the input events */ +case class Event(sessionId: String, timestamp: Timestamp) + +/** + * User-defined data type for storing session information as state in mapGroupsWithState. + * + * @param numEvents total number of events received in the session + * @param startTimestampMs timestamp of first event received in the session when it started + * @param endTimestampMs timestamp of last event received in the session before it expired + */ +case class SessionInfo( + numEvents: Int, + startTimestampMs: Long, + endTimestampMs: Long) { + + /** Duration of the session, between the first and last events */ + def durationMs: Long = endTimestampMs - startTimestampMs +} + +/** + * User-defined data type representing the update information returned by mapGroupsWithState.
+ * + * @param id Id of the session + * @param durationMs Duration the session was active, that is, from first event to its expiry + * @param numEvents Number of events received by the session while it was active + * @param expired Is the session active or expired + */ +case class SessionUpdate( + id: String, + durationMs: Long, + numEvents: Int, + expired: Boolean) + +// scalastyle:on println + diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 6391d6269c5a..0046ba7e43d1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 65cc80619569..d604c1ac001a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -218,13 +218,28 @@ class FPGrowthModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) /** - * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * Cache minConfidence and associationRules to avoid redundant computation for association rules + * during transform. The associationRules will only be re-computed when minConfidence changed. + */ + @transient private var _cachedMinConf: Double = Double.NaN + + @transient private var _cachedRules: DataFrame = _ + + /** + * Get association rules fitted using the minConfidence. Returns a dataframe * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and * "consequent" are Array[T] and "confidence" is Double. 
*/ @Since("2.2.0") - @transient lazy val associationRules: DataFrame = { - AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + @transient def associationRules: DataFrame = { + if ($(minConfidence) == _cachedMinConf) { + _cachedRules + } else { + _cachedRules = AssociationRules + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + _cachedMinConf = $(minConfidence) + _cachedRules + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 4c63a2a88c6c..c763a4cef1af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -164,7 +164,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC]( - dataset.as[LabeledPoint], estimator, 2, modelEquals) + dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals, 42L) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 1b6448037349..f0648d0936a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1874,7 +1874,7 @@ class LogisticRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression]( - dataset.as[LabeledPoint], estimator, numClasses, modelEquals) + dataset.as[LabeledPoint], estimator, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 4d5d299d1408..d41c5b533ded 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes]( - dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals) + dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 4603a618d2f9..6bec057511cd 
100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ @@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("items") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } + + test("FPGrowthModel setMinConfidence should affect rules generation and transform") { + val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset) + val oldRulesNum = model.associationRules.count() + val oldPredict = model.transform(dataset) + + model.setMinConfidence(0.8765) + assert(oldRulesNum > model.associationRules.count()) + assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + + // association rules should stay the same for same minConfidence + model.setMinConfidence(0.1) + assert(oldRulesNum === model.associationRules.count()) + assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + } + test("FPGrowth parameter check") { val fpGrowth = new FPGrowth().setMinSupport(0.4567) val model = fpGrowth.fit(dataset) .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + MLTestingUtils.checkCopy(model) } test("read/write") { def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { - assert(model.freqItemsets.sort("items").collect() === - model2.freqItemsets.sort("items").collect()) + assert(model.freqItemsets.collect().toSet.equals( + model2.freqItemsets.collect().toSet)) + assert(model.associationRules.collect().toSet.equals( + model2.associationRules.collect().toSet)) + assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals( + model2.setMinConfidence(0.9).associationRules.collect().toSet)) } val fPGrowth = new FPGrowth() testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, FPGrowthSuite.allParamSettings, checkModelData) } - - test("FPGrowth prediction should not contain duplicates") { - // This should generate rule 1 -> 3, 2 -> 3 - val dataset = spark.createDataFrame(Seq( - Array("1", "3"), - Array("2", "3") - ).map(Tuple1(_))).toDF("items") - val model = new FPGrowth().fit(dataset) - - val prediction = model.transform( - spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") - ).first().getAs[Seq[String]]("prediction") - - assert(prediction === Seq("3")) - } } object FPGrowthSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 6a51e75e12a3..c6a267b7283d 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -842,7 +842,8 @@ class LinearRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( - datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals) + datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals, + outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index f1ed568d5e60..578f31c8e7db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -260,12 +260,13 @@ object MLTestingUtils extends SparkFunSuite { data: Dataset[LabeledPoint], estimator: E with HasWeightCol, numClasses: Int, - modelEquals: (M, M) => Unit): Unit = { + modelEquals: (M, M) => Unit, + outlierRatio: Int): Unit = { import data.sqlContext.implicits._ val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap { case Instance(l, w, f) => val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1 - List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) + List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) } val trueModel = estimator.set(estimator.weightCol, "").fit(data) val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a24512f53c52..774caf53f3a4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,6 +25,8 @@ else: from itertools import imap as map +import warnings + from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -1281,7 +1283,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. 
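(Reviewer note) A quick sketch of the call shapes the reworked signature is meant to accept — value can now be omitted when to_replace is a dict, and list arguments must have matching lengths (the data and column values here are illustrative, taken from the tests below):

    df.replace({'Alice': 'Bob'}).first()          # dict-based, no value argument needed
    df.replace([10, 80], [20, 90.5]).first()      # parallel lists of the same length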
@@ -1326,43 +1328,72 @@ def replace(self, to_replace, value, subset=None): |null| null|null| +----+------+----+ """ - if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + # Helper functions + def all_of(types): + """Given a type or tuple of types and a sequence of xs + check if each x is instance of type(s) + + >>> all_of(bool)([True, False]) + True + >>> all_of(basestring)(["a", 1]) + False + """ + def all_of_(xs): + return all(isinstance(x, types) for x in xs) + return all_of_ + + all_of_bool = all_of(bool) + all_of_str = all_of(basestring) + all_of_numeric = all_of((float, int, long)) + + # Validate input types + valid_types = (bool, float, int, long, basestring, list, tuple) + if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict") + "to_replace should be a float, int, long, string, list, tuple, or dict. " + "Got {0}".format(type(to_replace))) - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + raise ValueError("If to_replace is not a dict, value should be " + "a float, int, long, string, list, or tuple. " + "Got {0}".format(type(value))) + + if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length. " + "Got {0} and {1}".format(len(to_replace), len(value))) - rep_dict = dict() + if not (subset is None or isinstance(subset, (list, tuple, basestring))): + raise ValueError("subset should be a list or tuple of column names, " + "column name or None. Got {0}".format(type(subset))) + # Reshape input arguments if necessary if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(to_replace, tuple): - to_replace = list(to_replace) + if isinstance(value, (float, int, long, basestring)): + value = [value for _ in range(len(to_replace))] - if isinstance(value, tuple): - value = list(value) - - if isinstance(to_replace, list) and isinstance(value, list): - if len(to_replace) != len(value): - raise ValueError("to_replace and value lists should be of the same length") - rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): - rep_dict = dict([(tr, value) for tr in to_replace]) - elif isinstance(to_replace, dict): + if isinstance(to_replace, dict): rep_dict = to_replace + if value is not None: + warnings.warn("to_replace is a dict and value is not None. value will be ignored.") + else: + rep_dict = dict(zip(to_replace, value)) - if subset is None: - return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) - elif isinstance(subset, basestring): + if isinstance(subset, basestring): subset = [subset] - if not isinstance(subset, (list, tuple)): - raise ValueError("subset should be a list or tuple of column names") + # Verify we were not passed in mixed type generics." 
+ if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): + raise ValueError("Mixed type replacements are not supported") - return DataFrame( - self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + else: + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) @since(2.0) def approxQuantile(self, col, probabilities, relativeError): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index db41b4edb6dd..2b2444304e04 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1779,6 +1779,78 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + # replace with lists + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() + self.assertTupleEqual(row, (u'Ann', 10, 80.1)) + + # replace with dict + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() + self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + + # test backward compatibility with dummy value + dummy_value = 1 + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # test dict with mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', -10, 90.5)) + + # replace with tuples + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # replace multiple columns + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.0)) + + # test for mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + # replace with boolean + row = (self + .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) + .selectExpr("name = 'Bob'", 'age <= 15') + .replace(False, True).first()) + self.assertTupleEqual(row, (True, True)) + + # should fail if subset is not list, tuple or None + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() + + # should fail if to_replace and value have different length + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() + + # should fail if when received unexpected type + with self.assertRaises(ValueError): + from datetime import datetime + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() + + # should fail if provided mixed type replacements + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() + + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], 
schema).replace({u"Alice": u"Bob", 10: 20}).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5b5d547f8fe5..e685c2bed50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -752,7 +752,7 @@ object SQLConf { buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""") .stringConf - .createWithDefault(TimeZone.getDefault().getID()) + .createWithDefaultFunction(() => TimeZone.getDefault.getID) val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.windowExec.buffer.spill.threshold") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 022c2f5629e8..cb42e9e4560c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -347,7 +347,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U]( + mapGroupsWithState[S, U](timeoutConf)( (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index 02996ac854f6..d188566f822b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -47,21 +47,22 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = extends TriggerExecutor with Logging { private val intervalMs = processingTime.intervalMs + require(intervalMs >= 0) - override def execute(batchRunner: () => Boolean): Unit = { + override def execute(triggerHandler: () => Boolean): Unit = { while (true) { - val batchStartTimeMs = clock.getTimeMillis() - val terminated = !batchRunner() + val triggerTimeMs = clock.getTimeMillis + val nextTriggerTimeMs = nextBatchTime(triggerTimeMs) + val terminated = !triggerHandler() if (intervalMs > 0) { - val batchEndTimeMs = clock.getTimeMillis() - val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs + val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs if (batchElapsedTimeMs > intervalMs) { notifyBatchFallingBehind(batchElapsedTimeMs) } if (terminated) { return } - clock.waitTillTime(nextBatchTime(batchEndTimeMs)) + clock.waitTillTime(nextTriggerTimeMs) } else { if (terminated) { return @@ -70,7 +71,7 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = } } - /** Called when a batch falls behind. Expose for test only */ + /** Called when a batch falls behind */ def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { logWarning("Current batch is falling behind. 
The trigger interval is " + s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") @@ -83,6 +84,6 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`). */ def nextBatchTime(now: Long): Long = { - now / intervalMs * intervalMs + intervalMs + if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 58be2d1da281..d11045fb6ac8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -150,7 +150,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { /** - * Find out duplicated exchanges in the spark plan, then use the same exchange for all the + * Find out duplicated subqueries in the spark plan, then use the same subquery result for all the * references. */ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { @@ -159,7 +159,7 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { if (!conf.exchangeReuseEnabled) { return plan } - // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. + // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]() plan transformAllExpressions { case sub: ExecSubqueryExpression => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 15df906ca7b1..c659ac7fcf3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `streaming.StreamingQuery`, + * That is, in every batch of the `StreamingQuery`, * the function will be invoked once for each group that has data in the trigger. Furthermore, * if timeout is set, then the function will invoked on timed out groups (more detail below). * @@ -42,12 +42,23 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * - The key of the group. * - An iterator containing all the values for this group. * - A user-defined state object set by previous invocations of the given function. + * * In case of a batch Dataset, there is only one invocation and state object will be empty as * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have * no effect. * - * Important points to note about the function. + * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the + * former allows the function to return one and only one record, whereas the latter + * allows the function to return any number of records (including no records). 
Furthermore, the + * `flatMapGroupsWithState` is associated with an operation output mode, which can be either + * `Append` or `Update`. Semantically, this defines whether the output records of one trigger + * are effectively replacing the previously output records (from previous triggers) or are appending + * to the list of previously output records. Essentially, this defines how the Result Table (refer + * to the semantics in the programming guide) is updated, and allows us to reason about the + * semantics of later operations. + * + * Important points to note about the function (both mapGroupsWithState and flatMapGroupsWithState). + * - In a trigger, the function will be called only for the groups present in the batch. So do not * assume that the function will be called in every trigger for every group that has state. * - There is no guaranteed ordering of values in the iterator in the function, neither with diff --git a/sql/core/src/test/resources/tpcds/q77.sql b/sql/core/src/test/resources/tpcds/q77.sql index 7830f96e7651..a69df9fbcd36 100755 --- a/sql/core/src/test/resources/tpcds/q77.sql +++ b/sql/core/src/test/resources/tpcds/q77.sql @@ -36,7 +36,7 @@ WITH ss AS sum(cr_net_loss) AS profit_loss FROM catalog_returns, date_dim WHERE cr_returned_date_sk = d_date_sk - AND d_date BETWEEN cast('2000-08-03]' AS DATE) AND + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND (cast('2000-08-03' AS DATE) + INTERVAL 30 days)), ws AS (SELECT diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 00d5e051de35..007554a83f54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.sql.execution.streaming -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import org.eclipse.jetty.util.ConcurrentHashSet +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{Clock, ManualClock, SystemClock} +import org.apache.spark.sql.streaming.util.StreamManualClock class ProcessingTimeExecutorSuite extends SparkFunSuite { + val timeout = 10.seconds + test("nextBatchTime") { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) assert(processingTimeExecutor.nextBatchTime(0) === 100) @@ -35,6 +45,57 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { assert(processingTimeExecutor.nextBatchTime(150) === 200) } + test("trigger timing") { + val triggerTimes = new ConcurrentHashSet[Int] + val clock = new StreamManualClock() + @volatile var continueExecuting = true + @volatile var clockIncrementInTrigger = 0L + val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), clock) + val executorThread = new Thread() { + override def run(): Unit = { + executor.execute(() => { + // Record the trigger time, increment the clock if needed, and return whether to continue + triggerTimes.add(clock.getTimeMillis.toInt) + clock.advance(clockIncrementInTrigger) + clockIncrementInTrigger = 0 // reset this so that there are no runaway
triggers + continueExecuting + }) + } + } + executorThread.start() + // First batch should execute immediately, then executor should wait for next one + eventually { + assert(triggerTimes.contains(0)) + assert(clock.isStreamWaitingAt(0)) + assert(clock.isStreamWaitingFor(1000)) + } + + // Second batch should execute when clock reaches the next trigger time. + // If next trigger takes less than the trigger interval, executor should wait for next one + clockIncrementInTrigger = 500 + clock.setTime(1000) + eventually { + assert(triggerTimes.contains(1000)) + assert(clock.isStreamWaitingAt(1500)) + assert(clock.isStreamWaitingFor(2000)) + } + + // If next trigger takes longer than the trigger interval, executor should immediately execute + // another one + clockIncrementInTrigger = 1500 + clock.setTime(2000) // allow another trigger by setting clock to 2000 + eventually { + // Since the next trigger will take 1500 (which is more than trigger interval of 1000) + // executor will immediately execute another trigger + assert(triggerTimes.contains(2000) && triggerTimes.contains(3500)) + assert(clock.isStreamWaitingAt(3500)) + assert(clock.isStreamWaitingFor(4000)) + } + continueExecuting = false + clock.advance(1000) + waitForThreadJoin(executorThread) + } + test("calling nextBatchTime with the result of a previous call should return the next interval") { val intervalMS = 100 val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS)) @@ -54,7 +115,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) processingTimeExecutor.execute(() => { batchCounts += 1 - // If the batch termination works well, batchCounts should be 3 after `execute` + // If the batch termination works correctly, batchCounts should be 3 after `execute` batchCounts < 3 }) assert(batchCounts === 3) @@ -66,9 +127,8 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } test("notifyBatchFallingBehind") { - val clock = new ManualClock() + val clock = new StreamManualClock() @volatile var batchFallingBehindCalled = false - val latch = new CountDownLatch(1) val t = new Thread() { override def run(): Unit = { val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { @@ -77,7 +137,6 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } } processingTimeExecutor.execute(() => { - latch.countDown() clock.waitTillTime(200) false }) @@ -85,9 +144,17 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } t.start() // Wait until the batch is running so that we don't call `advance` too early - assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds") + eventually { assert(clock.isStreamWaitingFor(200)) } clock.advance(200) - t.join() + waitForThreadJoin(t) assert(batchFallingBehindCalled === true) } + + private def eventually(body: => Unit): Unit = { + Eventually.eventually(Timeout(timeout)) { body } + } + + private def waitForThreadJoin(thread: Thread): Unit = { + failAfter(timeout) { thread.join() } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 171877abe6e9..26967782f77c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -33,6 +33,7 @@ import
org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index c8e31e3ca2e0..85aa7dbe9ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,8 +21,6 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually.eventually -import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -35,6 +33,7 @@ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 388f15405e70..5ab9dc2bc776 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 951ff2ca0d68..03aa45b61688 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -214,24 +214,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { AssertOnQuery(query => { func(query); true }) } - class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { - private var waitStartTime: Option[Long] = None - - override def waitTillTime(targetTime: Long): Long = synchronized { - try { - waitStartTime = Some(getTimeMillis()) - super.waitTillTime(targetTime) - } finally { - waitStartTime = None - } - } - - def isStreamWaitingAt(time: Long): Boolean = synchronized { - waitStartTime == Some(time) - } - } - - /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. 
@@ -242,6 +224,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { + import org.apache.spark.sql.streaming.util.StreamManualClock + // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently // because this method assumes there is only one active query in its `StreamingQueryListener` // and it may not work correctly when multiple `testStream`s run concurrently. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 600c039cd0b9..e5d5b4f32882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.sql.streaming.util.StreamManualClock object FailureSinglton { var firstTime = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 03dad8a6ddbc..b8a694c17731 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1172531fe998..2ebbfcd22b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.util.ManualClock @@ -207,46 +207,53 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // Wait for manual clock to be 100 first time there is data + // getOffset should take 50 ms the first time it is called override def getOffset: Option[Offset] = { val offset = super.getOffset if (offset.nonEmpty) { - clock.waitTillTime(300) + clock.waitTillTime(1050) } offset } - // Wait for manual clock to be 300 first time there is data + // getBatch should take 100 ms the 
first time it is called override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - clock.waitTillTime(600) + if (start.isEmpty) clock.waitTillTime(1150) super.getBatch(start, end) } } - // This is to make sure thatquery waits for manual clock to be 600 first time there is data - val mapped = inputData.toDS().as[Long].map { x => - clock.waitTillTime(1100) + // query execution should take 350 ms the first time it is called + val mapped = inputData.toDS.coalesce(1).as[Long].map { x => + clock.waitTillTime(1500) // this will only wait the first time when clock < 1500 10 / x }.agg(count("*")).as[Long] - case class AssertStreamExecThreadToWaitForClock() + case class AssertStreamExecThreadIsWaitingForTime(targetTime: Long) extends AssertOnQuery(q => { eventually(Timeout(streamingTimeout)) { if (q.exception.isEmpty) { - assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + assert(clock.isStreamWaitingFor(targetTime)) } } if (q.exception.isDefined) { throw q.exception.get } true - }, "") + }, "") { + override def toString: String = s"AssertStreamExecThreadIsWaitingForTime($targetTime)" + } + + case class AssertClockTime(time: Long) + extends AssertOnQuery(q => clock.getTimeMillis() === time, "") { + override def toString: String = s"AssertClockTime($time)" + } var lastProgressBeforeStop: StreamingQueryProgress = null testStream(mapped, OutputMode.Complete)( - StartStream(ProcessingTime(100), triggerClock = clock), - AssertStreamExecThreadToWaitForClock(), + StartStream(ProcessingTime(1000), triggerClock = clock), + AssertStreamExecThreadIsWaitingForTime(1000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -254,33 +261,37 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // Test status and progress while offset is being fetched AddData(inputData, 1, 2), - AdvanceManualClock(100), // time = 100 to start new trigger, will block on getOffset - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being fetched - AdvanceManualClock(200), // time = 300 to unblock getOffset, will block on getBatch - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(50), // time = 1050 to unblock getOffset + AssertClockTime(1050), + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being processed - AdvanceManualClock(300), // time = 600 to unblock getBatch, will block in Spark job + AdvanceManualClock(100), // time = 1150 to unblock getBatch + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), 
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(500), // time = 1100 to unblock job - AssertOnQuery { _ => clock.getTimeMillis() === 1100 }, + AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, + AdvanceManualClock(350), // time = 1500 to unblock job + AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadToWaitForClock(), + AssertStreamExecThreadIsWaitingForTime(2000), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -293,21 +304,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.id === query.id) assert(progress.name === query.name) assert(progress.batchId === 0) - assert(progress.timestamp === "1970-01-01T00:00:00.100Z") // 100 ms in UTC + assert(progress.timestamp === "1970-01-01T00:00:01.000Z") // 1000 ms in UTC assert(progress.numInputRows === 2) - assert(progress.processedRowsPerSecond === 2.0) + assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 200) - assert(progress.durationMs.get("getBatch") === 300) + assert(progress.durationMs.get("getOffset") === 50) + assert(progress.durationMs.get("getBatch") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("triggerExecution") === 1000) + assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") assert(progress.sources(0).startOffset === null) assert(progress.sources(0).endOffset !== null) - assert(progress.sources(0).processedRowsPerSecond === 2.0) + assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) assert(progress.stateOperators(0).numRowsUpdated === 1) @@ -317,9 +328,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi true }, + // Test whether input rate is updated after two batches + AssertStreamExecThreadIsWaitingForTime(2000), // blocked waiting for next trigger time AddData(inputData, 1, 2), - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(500), // allow another trigger + AssertClockTime(2000), + AssertStreamExecThreadIsWaitingForTime(3000), // will block waiting for next trigger time CheckAnswer(4), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), @@ -327,13 +341,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery { query => assert(query.recentProgress.last.eq(query.lastProgress)) assert(query.lastProgress.batchId === 1) - assert(query.lastProgress.sources(0).inputRowsPerSecond === 1.818) + assert(query.lastProgress.inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) true }, // Test status and progress after data is not available for a trigger - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // allow another trigger + AssertStreamExecThreadIsWaitingForTime(4000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for
next trigger"), @@ -350,10 +365,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Stopped"), // Test status and progress after query terminated with error - StartStream(ProcessingTime(100), triggerClock = clock), - AdvanceManualClock(100), // ensure initial trigger completes before AddData + StartStream(ProcessingTime(1000), triggerClock = clock), + AdvanceManualClock(1000), // ensure initial trigger completes before AddData AddData(inputData, 0), - AdvanceManualClock(100), // allow another trigger + AdvanceManualClock(1000), // allow another trigger ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), @@ -678,5 +693,5 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi object StreamingQuerySuite { // Singleton reference to clock that does not get serialized in task closures - var clock: ManualClock = null + var clock: StreamManualClock = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala new file mode 100644 index 000000000000..c769a790a416 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.util.ManualClock + +/** + * ManualClock used for streaming tests that allows checking whether the stream is waiting + * on the clock at expected times. + */ +class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { + private var waitStartTime: Option[Long] = None + private var waitTargetTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + waitTargetTime = Some(targetTime) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + waitTargetTime = None + } + } + + /** Is the streaming thread waiting for the clock to advance when it is at the given time */ + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } + + /** Is the streaming thread waiting for clock to advance to the given time */ + def isStreamWaitingFor(target: Long): Boolean = synchronized { + waitTargetTime == Some(target) + } +} +