diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 96fb12ce5e0b9..997de9511ca3e 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -878,6 +878,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
{% endhighlight %}
+The update function will be called for each word, with `newValues` having a sequence of 1's (from
+the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
+Scala code, take a look at the example
+[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala).
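+
+For reference, a minimal sketch of a matching update function (the same shape as the
+`updateFunction` defined earlier in this section):
+
+{% highlight scala %}
+def updateFunction(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
+  // add the new 1's from this batch to the previous running count, starting from 0 for a new word
+  Some(runningCount.getOrElse(0) + newValues.sum)
+}
+{% endhighlight %}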
+
@@ -899,6 +905,13 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
JavaPairDStream<String, Integer> runningCounts = pairs.updateStateByKey(updateFunction);
{% endhighlight %}
+The update function will be called for each word, with `newValues` having a sequence of 1's (from
+the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
+Java code, take a look at the example
+[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java).
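+
+For reference, a minimal sketch of a matching update function (mirroring the one used in the
+example file, with Guava's `Optional`):
+
+{% highlight java %}
+Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
+  new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
+    @Override
+    public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+      Integer newSum = state.or(0);   // previous count, or 0 for a word seen for the first time
+      for (Integer value : values) {
+        newSum += value;              // add each new 1 from this batch
+      }
+      return Optional.of(newSum);
+    }
+  };
+{% endhighlight %}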
+
@@ -916,14 +929,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
runningCounts = pairs.updateStateByKey(updateFunction)
{% endhighlight %}
-
-
-
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
-Scala code, take a look at the example
+Python code, take a look at the example
[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).
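+
+For reference, a minimal sketch of a matching Python update function (the same shape as the one
+defined earlier in this section):
+
+{% highlight python %}
+def updateFunction(newValues, runningCount):
+    if runningCount is None:
+        runningCount = 0
+    # add the new 1's from this batch to the previous running count
+    return sum(newValues, runningCount)
+{% endhighlight %}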
+
+
+
Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is
discussed in detail in the [checkpointing](#checkpointing) section.
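+
+A minimal sketch of that configuration (the checkpoint path below is a hypothetical example; any
+fault-tolerant storage such as an HDFS directory works):
+
+{% highlight scala %}
+val ssc = new StreamingContext(sparkConf, Seconds(1))
+// Must be set before stateful transformations such as updateStateByKey are used
+ssc.checkpoint("hdfs:///tmp/streaming-checkpoint")
+{% endhighlight %}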
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
new file mode 100644
index 0000000000000..d46c7107c7a21
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.StorageLevels;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+/**
+ * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
+ * second, starting from an initial state of word counts.
+ * Usage: JavaStatefulNetworkWordCount <hostname> <port>
+ *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
+ *   data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ bin/run-example
+ * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999`
+ */
+public class JavaStatefulNetworkWordCount {
+  private static final Pattern SPACE = Pattern.compile(" ");
+
+  public static void main(String[] args) {
+    if (args.length < 2) {
+      System.err.println("Usage: JavaStatefulNetworkWordCount <hostname> <port>");
+      System.exit(1);
+    }
+
+    StreamingExamples.setStreamingLogLevels();
+
+    // Update the cumulative count function
+    final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
+        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
+          @Override
+          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+            Integer newSum = state.or(0);
+            for (Integer value : values) {
+              newSum += value;
+            }
+            return Optional.of(newSum);
+          }
+        };
+
+    // Create the context with a 1 second batch size
+    SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
+    JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
+    ssc.checkpoint(".");
+
+    // Initial RDD input to updateStateByKey
+    List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
+        new Tuple2<String, Integer>("world", 1));
+    JavaPairRDD<String, Integer> initialRDD = ssc.sc().parallelizePairs(tuples);
+
+    JavaReceiverInputDStream<String> lines = ssc.socketTextStream(
+        args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2);
+
+    JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+      @Override
+      public Iterable<String> call(String x) {
+        return Lists.newArrayList(SPACE.split(x));
+      }
+    });
+
+    JavaPairDStream<String, Integer> wordsDstream = words.mapToPair(
+        new PairFunction<String, String, Integer>() {
+          @Override
+          public Tuple2<String, Integer> call(String s) {
+            return new Tuple2<String, Integer>(s, 1);
+          }
+        });
+
+    // This will give a DStream made of state (which is the cumulative count of the words)
+    JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
+        new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD);
+
+    stateDstream.print();
+    ssc.start();
+    ssc.awaitTermination();
+  }
+}
diff --git a/pom.xml b/pom.xml
index 53372d5cfc624..6810d71be4230 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1083,6 +1083,12 @@
        <artifactId>scala-maven-plugin</artifactId>
        <version>3.2.0</version>
        <executions>
+          <execution>
+            <id>eclipse-add-source</id>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+          </execution>
          <execution>
            <id>scala-compile-first</id>
            <phase>process-resources</phase>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index 24848634de9cf..e63042ccac7b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -67,7 +67,7 @@ private[sql] class DefaultSource
case SaveMode.Append =>
sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
case SaveMode.Overwrite =>
- fs.delete(filesystemPath, true)
+ // The existing data is deleted inside JSONRelation.insert() when overwriting.
true
case SaveMode.ErrorIfExists =>
sys.error(s"path $path already exists.")
@@ -76,12 +76,18 @@ private[sql] class DefaultSource
} else {
true
}
- if (doSave) {
+ val relation = if (doSave) {
// Only save data when the save mode is not ignore.
- data.toJSON.saveAsTextFile(path)
+ // Write via JSONRelation.insert() rather than saving the JSON text directly.
+ val createdRelation = createRelation(sqlContext, parameters, data.schema)
+ createdRelation.asInstanceOf[JSONRelation].insert(data, true)
+
+ createdRelation
+ } else {
+ createRelation(sqlContext, parameters, data.schema)
}
- createRelation(sqlContext, parameters, data.schema)
+ relation
}
}
@@ -92,7 +98,15 @@ private[sql] case class JSONRelation(
@transient val sqlContext: SQLContext)
extends TableScan with InsertableRelation {
// TODO: Support partitioned JSON relation.
- private def baseRDD = sqlContext.sparkContext.textFile(path)
+ val filesystemPath = new Path(path)
+ val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ // TableScan can work over an ordinary file, but InsertableRelation requires the path to be a directory.
+ val newPath = if (fs.exists(filesystemPath) && fs.getFileStatus(filesystemPath).isFile()) {
+ filesystemPath
+ } else {
+ new Path(filesystemPath.toUri.toString, "*")
+ }
+ private def baseRDD = sqlContext.sparkContext.textFile(newPath.toUri.toString)
override val schema = userSpecifiedSchema.getOrElse(
JsonRDD.nullTypeToStringType(
@@ -104,21 +118,35 @@ private[sql] case class JSONRelation(
override def buildScan() =
JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord)
+ private def isTemporaryFile(file: Path): Boolean = {
+ file.getName == "_temporary"
+ }
+
override def insert(data: DataFrame, overwrite: Boolean) = {
- val filesystemPath = new Path(path)
- val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+
+ // If the path exists, it must be a directory; otherwise a directory with that name is created.
+ if (fs.exists(filesystemPath) && !fs.getFileStatus(filesystemPath).isDirectory) {
+ sys.error("a CREATE [TEMPORARY] TABLE AS SELECT statement requires the path to be a directory")
+ }
if (overwrite) {
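+ // Write the new data into a _temporary subdirectory first, so that an INSERT OVERWRITE
+ // whose source is this same table does not delete its input before it has been fully read.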
+ val temporaryPath = new Path(path, "_temporary")
+ val dataPath = new Path(path, "data")
+ // Write the data.
+ data.toJSON.saveAsTextFile(temporaryPath.toUri.toString)
+ val pathsToDelete = fs.listStatus(filesystemPath).filter(
+ f => !isTemporaryFile(f.getPath)).map(_.getPath)
+
try {
- fs.delete(filesystemPath, true)
+ pathsToDelete.foreach(fs.delete(_, true))
} catch {
case e: IOException =>
throw new IOException(
- s"Unable to clear output directory ${filesystemPath.toString} prior"
- + s" to INSERT OVERWRITE a JSON table:\n${e.toString}")
+ s"Unable to delete original data in directory ${filesystemPath.toString} when"
+ + s" run INSERT OVERWRITE a JSON table:\n${e.toString}")
}
- // Write the data.
- data.toJSON.saveAsTextFile(path)
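+ // Move the newly written data from the _temporary directory into its final location.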
+ fs.rename(temporaryPath, dataPath)
// Right now, we assume that the schema is not changed. We will not update the schema.
// schema = data.schema
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 29caed9337ff6..6e58d8155a350 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -29,10 +29,13 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensisitiveContext._
var path: File = null
+ var existPath: File = null
override def beforeAll(): Unit = {
path = util.getTempFilePath("jsonCTAS").getCanonicalFile
+ existPath = util.getTempFilePath("existJsonCTAS").getCanonicalFile
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
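+ // Pre-populate existPath with JSON text so a table can be defined over an existing directory.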
+ rdd.saveAsTextFile(existPath.toURI.toString)
jsonRDD(rdd).registerTempTable("jt")
}
@@ -62,6 +65,34 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
dropTempTable("jsonTable")
}
+ test("INSERT OVERWRITE with the source and destination point to the same table") {
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable1
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${existPath.toString}'
+ |)
+ """.stripMargin)
+
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable2
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${existPath.toString}'
+ |) AS
+ |SELECT a, b FROM jsonTable1
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable2"),
+ sql("SELECT a, b FROM jt").collect())
+
+ dropTempTable("jsonTable1")
+ dropTempTable("jsonTable2")
+ }
+
test("create a table, drop it and create another one with the same name") {
sql(
s"""