diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 96fb12ce5e0b9..997de9511ca3e 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -878,6 +878,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
 val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
 {% endhighlight %}
 
+The update function will be called for each word, with `newValues` having a sequence of 1's (from
+the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
+Scala code, take a look at the example
+[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala).
+
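For reference, the `updateFunction` referenced above has roughly this shape in Scala; the body shown
here (folding the batch's new 1's into the running count) is an illustrative sketch rather than the
exact code of the linked example:

{% highlight scala %}
def updateFunction(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
  // newValues holds the 1's produced for this word in the current batch;
  // runningCount holds the count accumulated so far (None on the word's first occurrence).
  Some(runningCount.getOrElse(0) + newValues.sum)
}
{% endhighlight %}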
@@ -899,6 +905,13 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
 JavaPairDStream<String, Integer> runningCounts = pairs.updateStateByKey(updateFunction);
 {% endhighlight %}
 
+The update function will be called for each word, with `newValues` having a sequence of 1's (from
+the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
+Java code, take a look at the example
+[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java).
+
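Before the full Java example added below, here is a rough end-to-end sketch in Scala of how such a
stateful word count is wired together. The host, port, checkpoint directory, and batch interval are
illustrative, and `updateFunction` is the one sketched above:

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("StatefulNetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))
// updateStateByKey needs a checkpoint directory to persist the per-key state.
ssc.checkpoint(".")

val lines = ssc.socketTextStream("localhost", 9999)
val pairs = lines.flatMap(_.split(" ")).map(word => (word, 1))
val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
runningCounts.print()

ssc.start()
ssc.awaitTermination()
{% endhighlight %}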
@@ -916,14 +929,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
 runningCounts = pairs.updateStateByKey(updateFunction)
 {% endhighlight %}
 
-</div>
-</div>
-
 The update function will be called for each word, with `newValues` having a sequence of 1's (from
 the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
-Scala code, take a look at the example
+Python code, take a look at the example
 [stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).
 
+</div>
+</div>
+
 Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is
 discussed in detail in the [checkpointing](#checkpointing) section.
 
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
new file mode 100644
index 0000000000000..d46c7107c7a21
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.StorageLevels;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+/**
+ * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
+ * second starting with initial value of word count.
+ * Usage: JavaStatefulNetworkWordCount <hostname> <port>
+ *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
+ *   data.
+ * <p>
+ * To run this on your local machine, you need to first run a Netcat server
+ *    `$ nc -lk 9999`
+ * and then run the example
+ *    `$ bin/run-example
+ *      org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999`
+ */
+public class JavaStatefulNetworkWordCount {
+  private static final Pattern SPACE = Pattern.compile(" ");
+
+  public static void main(String[] args) {
+    if (args.length < 2) {
+      System.err.println("Usage: JavaStatefulNetworkWordCount <hostname> <port>");
+      System.exit(1);
+    }
+
+    StreamingExamples.setStreamingLogLevels();
+
+    // Update the cumulative count function
+    final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
+        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
+          @Override
+          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+            Integer newSum = state.or(0);
+            for (Integer value : values) {
+              newSum += value;
+            }
+            return Optional.of(newSum);
+          }
+        };
+
+    // Create the context with a 1 second batch size
+    SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
+    JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
+    ssc.checkpoint(".");
+
+    // Initial RDD input to updateStateByKey
+    List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
+        new Tuple2<String, Integer>("world", 1));
+    JavaPairRDD<String, Integer> initialRDD = ssc.sc().parallelizePairs(tuples);
+
+    JavaReceiverInputDStream<String> lines = ssc.socketTextStream(
+        args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2);
+
+    JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+      @Override
+      public Iterable<String> call(String x) {
+        return Lists.newArrayList(SPACE.split(x));
+      }
+    });
+
+    JavaPairDStream<String, Integer> wordsDstream = words.mapToPair(
+        new PairFunction<String, String, Integer>() {
+          @Override
+          public Tuple2<String, Integer> call(String s) {
+            return new Tuple2<String, Integer>(s, 1);
+          }
+        });
+
+    // This will give a DStream made of state (which is the cumulative count of the words)
+    JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
+        new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD);
+
+    stateDstream.print();
+    ssc.start();
+    ssc.awaitTermination();
+  }
+}
diff --git a/pom.xml b/pom.xml
index 53372d5cfc624..6810d71be4230 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1083,6 +1083,12 @@
           <artifactId>scala-maven-plugin</artifactId>
           <version>3.2.0</version>
           <executions>
+            <execution>
+              <id>eclipse-add-source</id>
+              <goals>
+                <goal>add-source</goal>
+              </goals>
+            </execution>
             <execution>
               <id>scala-compile-first</id>
               <phase>process-resources</phase>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index 24848634de9cf..e63042ccac7b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -67,7 +67,7 @@ private[sql] class DefaultSource
         case SaveMode.Append =>
           sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
         case SaveMode.Overwrite =>
-          fs.delete(filesystemPath, true)
+          //fs.delete(filesystemPath, true)
           true
         case SaveMode.ErrorIfExists =>
           sys.error(s"path $path already exists.")
@@ -76,12 +76,18 @@ private[sql] class DefaultSource
       } else {
         true
       }
-    if (doSave) {
+    val relation = if (doSave) {
       // Only save data when the save mode is not ignore.
-      data.toJSON.saveAsTextFile(path)
+      //data.toJSON.saveAsTextFile(path)
+      val createdRelation = createRelation(sqlContext, parameters, data.schema)
+      createdRelation.asInstanceOf[JSONRelation].insert(data, true)
+
+      createdRelation
+    } else {
+      createRelation(sqlContext, parameters, data.schema)
     }
-    createRelation(sqlContext, parameters, data.schema)
+    relation
   }
 }
 
@@ -92,7 +98,15 @@ private[sql] case class JSONRelation(
     @transient val sqlContext: SQLContext)
   extends TableScan with InsertableRelation {
 
   // TODO: Support partitioned JSON relation.
-  private def baseRDD = sqlContext.sparkContext.textFile(path)
+  val filesystemPath = new Path(path)
+  val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+  // TableScan can work on an ordinary file, but InsertableRelation requires a directory.
+  val newPath = if (fs.exists(filesystemPath) && fs.getFileStatus(filesystemPath).isFile()) {
+    filesystemPath
+  } else {
+    new Path(filesystemPath.toUri.toString, "*")
+  }
+  private def baseRDD = sqlContext.sparkContext.textFile(newPath.toUri.toString)
 
   override val schema = userSpecifiedSchema.getOrElse(
     JsonRDD.nullTypeToStringType(
@@ -104,21 +118,35 @@ private[sql] case class JSONRelation(
   override def buildScan() =
     JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord)
 
+  private def isTemporaryFile(file: Path): Boolean = {
+    file.getName == "_temporary"
+  }
+
   override def insert(data: DataFrame, overwrite: Boolean) = {
-    val filesystemPath = new Path(path)
-    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+
+    // If the path exists, it must be a directory; raise an error if it is not.
+    // If it does not exist, a directory with that name will be created.
+    if (fs.exists(filesystemPath) && !fs.getFileStatus(filesystemPath).isDirectory) {
+      sys.error("CREATE [TEMPORARY] TABLE AS SELECT requires the path to be a directory")
+    }
 
     if (overwrite) {
+      val temporaryPath = new Path(path, "_temporary")
+      val dataPath = new Path(path, "data")
+      // Write the data.
+      data.toJSON.saveAsTextFile(temporaryPath.toUri.toString)
+      val pathsToDelete = fs.listStatus(filesystemPath).filter(
+        f => !isTemporaryFile(f.getPath)).map(_.getPath)
+
       try {
-        fs.delete(filesystemPath, true)
+        pathsToDelete.foreach(fs.delete(_, true))
       } catch {
         case e: IOException =>
           throw new IOException(
-            s"Unable to clear output directory ${filesystemPath.toString} prior"
-              + s" to INSERT OVERWRITE a JSON table:\n${e.toString}")
+            s"Unable to delete original data in directory ${filesystemPath.toString} when"
+              + s" running INSERT OVERWRITE on a JSON table:\n${e.toString}")
       }
-      // Write the data.
-      data.toJSON.saveAsTextFile(path)
+      fs.rename(temporaryPath, dataPath)
       // Right now, we assume that the schema is not changed. We will not update the schema.
       // schema = data.schema
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 29caed9337ff6..6e58d8155a350 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -29,10 +29,13 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
   import caseInsensisitiveContext._
 
   var path: File = null
+  var existPath: File = null
 
   override def beforeAll(): Unit = {
     path = util.getTempFilePath("jsonCTAS").getCanonicalFile
+    existPath = util.getTempFilePath("existJsonCTAS").getCanonicalFile
     val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+    rdd.saveAsTextFile(existPath.toURI.toString)
     jsonRDD(rdd).registerTempTable("jt")
   }
 
@@ -62,6 +65,34 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
     dropTempTable("jsonTable")
   }
 
+  test("INSERT OVERWRITE where the source and destination point to the same table") {
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE jsonTable1
+        |USING org.apache.spark.sql.json.DefaultSource
+        |OPTIONS (
+        |  path '${existPath.toString}'
+        |)
+      """.stripMargin)
+
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE jsonTable2
+        |USING org.apache.spark.sql.json.DefaultSource
+        |OPTIONS (
+        |  path '${existPath.toString}'
+        |) AS
+        |SELECT a, b FROM jsonTable1
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable2"),
+      sql("SELECT a, b FROM jt").collect())
+
+    dropTempTable("jsonTable1")
+    dropTempTable("jsonTable2")
+  }
+
   test("create a table, drop it and create another one with the same name") {
     sql(
       s"""
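Stepping back from the patch: the essence of the `JSONRelation.insert` change above is a
write-then-swap overwrite, which is what makes `INSERT OVERWRITE` safe when the source and the
destination are the same JSON table. A standalone sketch of that pattern using plain Hadoop
`FileSystem` calls follows; the helper name `overwriteDirectory` and the `writeTo` callback are
illustrative only, not part of the patch.

{% highlight scala %}
import java.io.IOException

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// Overwrite `dir` by first materializing the new output under a temporary
// subdirectory, then deleting the old files and moving the new data into place.
// `writeTo` stands in for an action such as data.toJSON.saveAsTextFile(...).
def overwriteDirectory(dir: Path, conf: Configuration)(writeTo: Path => Unit): Unit = {
  val fs: FileSystem = dir.getFileSystem(conf)
  val temporaryPath = new Path(dir, "_temporary")
  val dataPath = new Path(dir, "data")

  // 1. Write the new data first, so the original data is untouched if this step fails.
  writeTo(temporaryPath)

  // 2. Remove everything except the freshly written temporary directory.
  val pathsToDelete = fs.listStatus(dir).map(_.getPath).filter(_.getName != "_temporary")
  try {
    pathsToDelete.foreach(fs.delete(_, true))
  } catch {
    case e: IOException =>
      throw new IOException(s"Unable to delete original data in directory $dir: ${e.getMessage}", e)
  }

  // 3. Move the new data into its final location.
  fs.rename(temporaryPath, dataPath)
}
{% endhighlight %}

Because the old files are only removed after the new output has been fully written, a query that
reads and overwrites the same path no longer deletes its own input before it has been consumed.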