diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f9ff4ea6ca157..022e2891559d7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -288,7 +288,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends +private[spark] class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { override def getPartitions = prev.partitions override def compute(split: Partition, context: TaskContext) = diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py new file mode 100644 index 0000000000000..cd2a8a73de63b --- /dev/null +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -0,0 +1,20 @@ +import sys + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.duration import * + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: wordcount " + exit(-1) + ssc = StreamingContext(appName="PythonStreamingNetworkWordCount", + duration=Seconds(1)) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .reduceByKey(lambda a,b: a+b) + counts.pyprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/wordcount.py b/examples/src/main/python/streaming/wordcount.py new file mode 100644 index 0000000000000..4c62835ed8025 --- /dev/null +++ b/examples/src/main/python/streaming/wordcount.py @@ -0,0 +1,21 @@ +import sys + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.duration import * + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: wordcount " + exit(-1) + + ssc = StreamingContext(appName="PythonStreamingWordCount", + duration=Seconds(1)) + + lines = ssc.textFileStream(sys.argv[1]) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda x: (x, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pyprint() + + ssc.start() + ssc.awaitTermination() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 9c70fa5c16d0c..c3fef42d118bd 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -108,6 +108,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.streaming.*") + java_import(gateway.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gateway.jvm, "org.apache.spark.streaming.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..d7a20caac1ee8 --- /dev/null +++ 
b/python/pyspark/streaming/context.py @@ -0,0 +1,186 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +from signal import signal, SIGTERM, SIGINT +import atexit +import time + +from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer +from pyspark.context import SparkContext +from pyspark.streaming.dstream import DStream +from pyspark.streaming.duration import Duration + +from py4j.java_collections import ListConverter + + +class StreamingContext(object): + """ + Main entry point for Spark Streaming functionality. A StreamingContext represents the + connection to a Spark cluster, and can be used to create L{DStream}s and + broadcast variables on that cluster. + """ + + def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, + environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + gateway=None, sparkContext=None, duration=None): + """ + Create a new StreamingContext. At least the master and app name and duration + should be set, either through the named parameters here or through C{conf}. + + @param master: Cluster URL to connect to + (e.g. mesos://host:port, spark://host:port, local[4]). + @param appName: A name for your job, to display on the cluster web UI. + @param sparkHome: Location where Spark is installed on cluster nodes. + @param pyFiles: Collection of .zip or .py files to send to the cluster + and add to PYTHONPATH. These can be paths on the local file + system or HDFS, HTTP, HTTPS, or FTP URLs. + @param environment: A dictionary of environment variables to set on + worker nodes. + @param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching or -1 to use an + unlimited batch size. + @param serializer: The serializer for RDDs. + @param conf: A L{SparkConf} object setting Spark properties. + @param gateway: Use an existing gateway and JVM, otherwise a new JVM + will be instatiated. + @param sparkContext: L{SparkContext} object. + @param duration: A L{Duration} object for SparkStreaming. + + """ + + if not isinstance(duration, Duration): + raise TypeError("Input should be pyspark.streaming.duration.Duration object") + + if sparkContext is None: + # Create the Python Sparkcontext + self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome, + pyFiles=pyFiles, environment=environment, batchSize=batchSize, + serializer=serializer, conf=conf, gateway=gateway) + else: + self._sc = sparkContext + + # Start py4j callback server. + # Callback sever is need only by SparkStreming; therefore the callback sever + # is started in StreamingContext. 
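As an illustrative usage sketch (not part of the patch), the constructor documented above is driven the same way as the bundled examples: give it a master, an app name, and a batch duration built with the helpers from pyspark.streaming.duration.

from pyspark.streaming.context import StreamingContext
from pyspark.streaming.duration import Seconds

ssc = StreamingContext(master="local[2]", appName="ConstructorSketch",
                       duration=Seconds(1))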
+ SparkContext._gateway.restart_callback_server() + self._set_clean_up_handler() + self._jvm = self._sc._jvm + self._jssc = self._initialize_context(self._sc._jsc, duration._jduration) + + # Initialize StremaingContext in function to allow subclass specific initialization + def _initialize_context(self, jspark_context, jduration): + return self._jvm.JavaStreamingContext(jspark_context, jduration) + + def _set_clean_up_handler(self): + """ set clean up hander using atexit """ + + def clean_up_handler(): + SparkContext._gateway.shutdown() + + atexit.register(clean_up_handler) + # atext is not called when the program is killed by a signal not handled by + # Python. + for sig in (SIGINT, SIGTERM): + signal(sig, clean_up_handler) + + @property + def sparkContext(self): + """ + Return SparkContext which is associated with this StreamingContext. + """ + return self._sc + + def start(self): + """ + Start the execution of the streams. + """ + self._jssc.start() + + def awaitTermination(self, timeout=None): + """ + Wait for the execution to stop. + @param timeout: time to wait in milliseconds + """ + if timeout is None: + self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(timeout) + + def remember(self, duration): + """ + Set each DStreams in this context to remember RDDs it generated in the last given duration. + DStreams remember RDDs only for a limited duration of time and releases them for garbage + collection. This method allows the developer to specify how to long to remember the RDDs ( + if the developer wishes to query old data outside the DStream computation). + @param duration pyspark.streaming.duration.Duration object. + Minimum duration that each DStream should remember its RDDs + """ + if not isinstance(duration, Duration): + raise TypeError("Input should be pyspark.streaming.duration.Duration object") + + self._jssc.remember(duration._jduration) + + # TODO: add storageLevel + def socketTextStream(self, hostname, port): + """ + Create an input from TCP source hostname:port. Data is received using + a TCP socket and receive byte is interpreted as UTF8 encoded '\n' delimited + lines. + """ + return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer()) + + def textFileStream(self, directory): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as text files. Files must be wrriten to the + monitored directory by "moving" them from another location within the same + file system. File names starting with . are ignored. + """ + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def stop(self, stopSparkContext=True, stopGraceFully=False): + """ + Stop the execution of the streams immediately (does not wait for all received data + to be processed). + """ + self._jssc.stop(stopSparkContext, stopGraceFully) + if stopSparkContext: + self._sc.stop() + + # Shutdown only callback server and all py3j client is shutdowned + # clean up handler + SparkContext._gateway._shutdown_callback_server() + + def _testInputStream(self, test_inputs, numSlices=None): + """ + This function is only for unittest. + It requires a list as input, and returns the i_th element at the i_th batch + under manual clock. 
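Continuing that sketch (illustrative only), the lifecycle methods defined above compose as follows; the host and port are placeholders, and pyprint() is the output operator added in dstream.py below.

lines = ssc.socketTextStream("localhost", 9999)   # placeholder host and port
lines.pyprint()                                   # register an output operator
ssc.start()
ssc.awaitTermination(10 * 1000)                   # argument is in milliseconds
ssc.stop(stopSparkContext=False)                  # keep the underlying SparkContext alive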
+ """ + test_rdds = list() + test_rdd_deserializers = list() + for test_input in test_inputs: + test_rdd = self._sc.parallelize(test_input, numSlices) + test_rdds.append(test_rdd._jrdd) + test_rdd_deserializers.append(test_rdd._jrdd_deserializer) + # All deserializers have to be the same. + # TODO: add deserializer validation + jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client) + jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream() + + return DStream(jinput_stream, self, test_rdd_deserializers[0]) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..bb137d09211bf --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,536 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import defaultdict +from itertools import chain, ifilter, imap +import operator + +from pyspark.serializers import NoOpSerializer,\ + BatchedSerializer, CloudPickleSerializer, pack_long,\ + CompressedSerializer +from pyspark.storagelevel import StorageLevel +from pyspark.resultiterable import ResultIterable +from pyspark.streaming.util import rddToFileName, RDDFunction +from pyspark.rdd import portable_hash, _parse_memory +from pyspark.traceback_utils import SCCallSiteSync + +from py4j.java_collections import ListConverter, MapConverter + +__all__ = ["DStream"] + + +class DStream(object): + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self.ctx = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + self._partitionFunc = None + + def context(self): + """ + Return the StreamingContext associated with this DStream + """ + return self._ssc + + def count(self): + """ + Return a new DStream which contains the number of elements in this DStream. + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)])._sum() + + def _sum(self): + """ + Add up the elements in this DStream. + """ + return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + + def print_(self, label=None): + """ + Since print is reserved name for python, we cannot define a "print" method function. + This function prints serialized data in RDD in DStream because Scala and Java cannot + deserialized pickled python object. Please use DStream.pyprint() to print results. + + Call DStream.print() and this function will print byte array in the DStream + """ + # a hack to call print function in DStream + getattr(self._jdstream, "print")(label) + + def filter(self, f): + """ + Return a new DStream containing only the elements that satisfy predicate. 
+ """ + def func(iterator): + return ifilter(f, iterator) + return self.mapPartitions(func) + + def flatMap(self, f, preservesPartitioning=False): + """ + Pass each value in the key-value pair DStream through flatMap function + without changing the keys: this also retains the original RDD's partition. + """ + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) + return self._mapPartitionsWithIndex(func, preservesPartitioning) + + def map(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each element of DStream. + """ + def func(iterator): + return imap(f, iterator) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each partition of this DStream. + """ + def func(s, iterator): + return f(iterator) + return self._mapPartitionsWithIndex(func, preservesPartitioning) + + def _mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each partition of this DStream, + while tracking the index of the original partition. + """ + return PipelinedDStream(self, f, preservesPartitioning) + + def reduce(self, func): + """ + Return a new DStream by reduceing the elements of this RDD using the specified + commutative and associative binary operator. + """ + return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) + + def reduceByKey(self, func, numPartitions=None): + """ + Merge the value for each key using an associative reduce function. + + This will also perform the merging locally on each mapper before + sending results to reducer, similarly to a "combiner" in MapReduce. + + Output will be hash-partitioned with C{numPartitions} partitions, or + the default parallelism level if C{numPartitions} is not specified. + """ + return self.combineByKey(lambda x: x, func, func, numPartitions) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions=None): + """ + Count the number of elements for each key, and return the result to the + master as a dictionary + """ + if numPartitions is None: + numPartitions = self._defaultReducePartitions() + + def combineLocally(iterator): + combiners = {} + for x in iterator: + (k, v) = x + if k not in combiners: + combiners[k] = createCombiner(v) + else: + combiners[k] = mergeValue(combiners[k], v) + return combiners.iteritems() + locally_combined = self.mapPartitions(combineLocally) + shuffled = locally_combined.partitionBy(numPartitions) + + def _mergeCombiners(iterator): + combiners = {} + for (k, v) in iterator: + if k not in combiners: + combiners[k] = v + else: + combiners[k] = mergeCombiners(combiners[k], v) + return combiners.iteritems() + + return shuffled.mapPartitions(_mergeCombiners) + + def partitionBy(self, numPartitions, partitionFunc=portable_hash): + """ + Return a copy of the DStream partitioned using the specified partitioner. + """ + if numPartitions is None: + numPartitions = self.ctx._defaultReducePartitions() + + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numPartitions) objects + # to Java. Each object is a (splitNumber, [objects]) pair. 
+ + outputSerializer = self.ctx._unbatched_serializer +# +# def add_shuffle_key(split, iterator): +# buckets = defaultdict(list) +# +# for (k, v) in iterator: +# buckets[partitionFunc(k) % numPartitions].append((k, v)) +# for (split, items) in buckets.iteritems(): +# yield pack_long(split) +# yield outputSerializer.dumps(items) +# keyed = PipelinedDStream(self, add_shuffle_key) + + limit = (_parse_memory(self.ctx._conf.get( + "spark.python.worker.memory", "512m")) / 2) + + def add_shuffle_key(split, iterator): + + buckets = defaultdict(list) + c, batch = 0, min(10 * numPartitions, 1000) + + for k, v in iterator: + buckets[partitionFunc(k) % numPartitions].append((k, v)) + c += 1 + + # check used memory and avg size of chunk of objects + if (c % 1000 == 0 and get_used_memory() > limit + or c > batch): + n, size = len(buckets), 0 + for split in buckets.keys(): + yield pack_long(split) + d = outputSerializer.dumps(buckets[split]) + del buckets[split] + yield d + size += len(d) + + avg = (size / n) >> 20 + # let 1M < avg < 10M + if avg < 1: + batch *= 1.5 + elif avg > 10: + batch = max(batch / 1.5, 1) + c = 0 + + for split, items in buckets.iteritems(): + yield pack_long(split) + yield outputSerializer.dumps(items) + + keyed = self._mapPartitionsWithIndex(add_shuffle_key) + + + + + keyed._bypass_serializer = True + with SCCallSiteSync(self.ctx) as css: + partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) + jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(), + partitioner).asJavaDStream() + dstream = DStream(jdstream, self._ssc, BatchedSerializer(outputSerializer)) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + dstream._partitionFunc = partitionFunc + return dstream + + def _defaultReducePartitions(self): + """ + Returns the default number of partitions to use during reduce tasks (e.g., groupBy). + If spark.default.parallelism is set, then we'll use the value from SparkContext + defaultParallelism, otherwise we'll use the number of partitions in this RDD + + This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce + the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will + be inherent. + """ + if self.ctx._conf.contains("spark.default.parallelism"): + return self.ctx.defaultParallelism + else: + return self.getNumPartitions() + + def getNumPartitions(self): + """ + Return the number of partitions in RDD + """ + # TODO: remove hard coding. RDD has NumPartitions. How do we get the number of partition + # through DStream? + return 2 + + def foreachRDD(self, func): + """ + Apply userdefined function to all RDD in a DStream. + This python implementation could be expensive because it uses callback server + in order to apply function to RDD in DStream. + This is an output operator, so this DStream will be registered as an output + stream and there materialized. + """ + wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func) + self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), wrapped_func) + + def pyprint(self): + """ + Print the first ten elements of each RDD generated in this DStream. This is an output + operator, so this DStream will be registered as an output stream and there materialized. + """ + def takeAndPrint(rdd, time): + """ + Closure to take element from RDD and print first 10 elements. + This closure is called by py4j callback server. 
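As a usage sketch for foreachRDD above: each generated batch is handed to a Python callable through the py4j callback server, and the callable receives a plain pyspark RDD plus the batch time. Here `counts` stands for any DStream built earlier in the job.

def dump_batch(rdd, time):
    # rdd is an ordinary pyspark RDD for this batch interval
    print("batch at %s: %s" % (time, rdd.take(5)))

counts.foreachRDD(dump_batch)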
+ """ + taken = rdd.take(11) + print "-------------------------------------------" + print "Time: %s" % (str(time)) + print "-------------------------------------------" + for record in taken[:10]: + print record + if len(taken) > 10: + print "..." + print + + self.foreachRDD(takeAndPrint) + + def mapValues(self, f): + """ + Pass each value in the key-value pair RDD through a map function + without changing the keys; this also retains the original RDD's + partitioning. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + def flatMapValues(self, f): + """ + Pass each value in the key-value pair RDD through a flatMap function + without changing the keys; this also retains the original RDD's + partitioning. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def glom(self): + """ + Return a new DStream in which RDD is generated by applying glom() to RDD of + this DStream. Applying glom() to an RDD coalesces all elements within each partition into + an list. + """ + def func(iterator): + yield list(iterator) + return self.mapPartitions(func) + + def cache(self): + """ + Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self.persist(StorageLevel.MEMORY_ONLY_SER) + return self + + def persist(self, storageLevel): + """ + Set this DStream's storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the DStream does not have a storage level set yet. + """ + self.is_cached = True + javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) + self._jdstream.persist(javaStorageLevel) + return self + + def checkpoint(self, interval): + """ + Mark this DStream for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} + + @param interval: Time interval after which generated RDD will be checkpointed + interval has to be pyspark.streaming.duration.Duration + """ + self.is_checkpointed = True + self._jdstream.checkpoint(interval._jduration) + return self + + def groupByKey(self, numPartitions=None): + """ + Return a new DStream which contains group the values for each key in the + DStream into a single sequence. + Hash-partitions the resulting RDD with into numPartitions partitions in + the DStream. + + Note: If you are grouping in order to perform an aggregation (such as a + sum or average) over each key, using reduceByKey will provide much + better performance. + + """ + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + a.extend(b) + return a + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numPartitions).mapValues(lambda x: ResultIterable(x)) + + def countByValue(self): + """ + Return new DStream which contains the count of each unique value in this + DStreeam as a (value, count) pairs. + """ + def countPartition(iterator): + counts = defaultdict(int) + for obj in iterator: + counts[obj] += 1 + yield counts + + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] += v + return m1 + + return self.mapPartitions(countPartition).reduce(mergeMaps).flatMap(lambda x: x.items()) + + def saveAsTextFiles(self, prefix, suffix=None): + """ + Save this DStream as a text file, using string representations of elements. 
+ """ + + def saveAsTextFile(rdd, time): + """ + Closure to save element in RDD in DStream as Pickled data in file. + This closure is called by py4j callback server. + """ + path = rddToFileName(prefix, suffix, time) + rdd.saveAsTextFile(path) + + return self.foreachRDD(saveAsTextFile) + + def saveAsPickleFiles(self, prefix, suffix=None): + """ + Save this DStream as a SequenceFile of serialized objects. The serializer + used is L{pyspark.serializers.PickleSerializer}, default batch size + is 10. + """ + + def saveAsPickleFile(rdd, time): + """ + Closure to save element in RDD in the DStream as Pickled data in file. + This closure is called by py4j callback server. + """ + path = rddToFileName(prefix, suffix, time) + rdd.saveAsPickleFile(path) + + return self.foreachRDD(saveAsPickleFile) + + def _test_output(self, result): + """ + This function is only for test case. + Store data in a DStream to result to verify the result in test case + """ + def get_output(rdd, time): + """ + Closure to get element in RDD in the DStream. + This closure is called by py4j callback server. + """ + collected = rdd.collect() + result.append(collected) + + self.foreachRDD(get_output) + + +# TODO: implement updateStateByKey +# TODO: implement slice + +# Window Operations +# TODO: implement window +# TODO: implement groupByKeyAndWindow +# TODO: implement reduceByKeyAndWindow +# TODO: implement countByValueAndWindow +# TODO: implement countByWindow +# TODO: implement reduceByWindow + +# Transform Operation +# TODO: implement transform +# TODO: implement transformWith +# Following operation has dependency with transform +# TODO: implement union +# TODO: implement repertitions +# TODO: implement cogroup +# TODO: implement join +# TODO: implement leftOuterJoin +# TODO: implement rightOuterJoin + + +class PipelinedDStream(DStream): + """ + Since PipelinedDStream is same to PipelindRDD, if PipliedRDD is changed, + this code should be changed in the same way. 
+ """ + def __init__(self, prev, func, preservesPartitioning=False): + if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jdstream = prev._jdstream + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: + prev_func = prev.func + + def pipeline_func(split, iterator): + return func(split, prev_func(split, iterator)) + self.func = pipeline_func + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jdstream = prev._prev_jdstream # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + self._ssc = prev._ssc + self.ctx = prev.ctx + self.prev = prev + self._jdstream_val = None + self._jrdd_deserializer = self.ctx.serializer + self._bypass_serializer = False + self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None + + @property + def _jdstream(self): + if self._jdstream_val: + return self._jdstream_val + if self._bypass_serializer: + self.jrdd_deserializer = NoOpSerializer() + command = (self.func, self._prev_jrdd_deserializer, + self._jrdd_deserializer) + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if pickled_command > (1 << 20): # 1M + broadcast = self.ctx.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx._gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_tag = self._prev_jdstream.classTag() + env = MapConverter().convert(self.ctx.environment, + self.ctx._gateway._gateway_client) + includes = ListConverter().convert(self.ctx._python_includes, + self.ctx._gateway._gateway_client) + python_dstream = self.ctx._jvm.PythonDStream(self._prev_jdstream.dstream(), + bytearray(pickled_command), + env, includes, self.preservesPartitioning, + self.ctx.pythonExec, + broadcast_vars, self.ctx._javaAccumulator, + class_tag) + self._jdstream_val = python_dstream.asJavaDStream() + return self._jdstream_val + + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/streaming/duration.py b/python/pyspark/streaming/duration.py new file mode 100644 index 0000000000000..495ac2edff198 --- /dev/null +++ b/python/pyspark/streaming/duration.py @@ -0,0 +1,376 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.streaming import util + + +class Duration(object): + """ + Duration for Spark Streaming application. 
Used to set duration + + Most of the time, you would create a Duration object with + C{Duration()}, which will load values from C{spark.streaming.*} Java system + properties as well. In this case, any parameters you set directly on + the C{Duration} object take priority over system properties. + + """ + def __init__(self, millis, _jvm=None): + """ + Create new Duration. + + @param millis: milisecond + + """ + self._millis = millis + + from pyspark.context import SparkContext + SparkContext._ensure_initialized() + _jvm = _jvm or SparkContext._jvm + self._jduration = _jvm.Duration(millis) + + def toString(self): + """ + Return duration as string + + >>> d_10 = Duration(10) + >>> d_10.toString() + '10 ms' + """ + return str(self._millis) + " ms" + + def isZero(self): + """ + Check if millis is zero + + >>> d_10 = Duration(10) + >>> d_10.isZero() + False + >>> d_0 = Duration(0) + >>> d_0.isZero() + True + """ + return self._millis == 0 + + def prettyPrint(self): + """ + Return a human-readable string representing a duration + + >>> d_10 = Duration(10) + >>> d_10.prettyPrint() + '10 ms' + >>> d_1sec = Duration(1000) + >>> d_1sec.prettyPrint() + '1.0 s' + >>> d_1min = Duration(60 * 1000) + >>> d_1min.prettyPrint() + '1.0 m' + >>> d_1hour = Duration(60 * 60 * 1000) + >>> d_1hour.prettyPrint() + '1.00 h' + """ + return util.msDurationToString(self._millis) + + def milliseconds(self): + """ + Return millisecond + + >>> d_10 = Duration(10) + >>> d_10.milliseconds() + 10 + + """ + return self._millis + + def toFormattedString(self): + """ + Return millisecond + + >>> d_10 = Duration(10) + >>> d_10.toFormattedString() + '10' + + """ + return str(self._millis) + + def max(self, other): + """ + Return higher Duration + + >>> d_10 = Duration(10) + >>> d_100 = Duration(100) + >>> d_max = d_10.max(d_100) + >>> print d_max + 100 ms + + """ + Duration._is_duration(other) + if self > other: + return self + else: + return other + + def min(self, other): + """ + Return lower Durattion + + >>> d_10 = Duration(10) + >>> d_100 = Duration(100) + >>> d_min = d_10.min(d_100) + >>> print d_min + 10 ms + + """ + Duration._is_duration(other) + if self < other: + return self + else: + return other + + def __str__(self): + """ + >>> d_10 = Duration(10) + >>> str(d_10) + '10 ms' + + """ + return self.toString() + + def __add__(self, other): + """ + Add Duration and Duration + + >>> d_10 = Duration(10) + >>> d_100 = Duration(100) + >>> d_110 = d_10 + d_100 + >>> print d_110 + 110 ms + """ + Duration._is_duration(other) + return Duration(self._millis + other._millis) + + def __sub__(self, other): + """ + Subtract Duration by Duration + + >>> d_10 = Duration(10) + >>> d_100 = Duration(100) + >>> d_90 = d_100 - d_10 + >>> print d_90 + 90 ms + + """ + Duration._is_duration(other) + return Duration(self._millis - other._millis) + + def __mul__(self, other): + """ + Multiple Duration by Duration + + >>> d_10 = Duration(10) + >>> d_100 = Duration(100) + >>> d_1000 = d_10 * d_100 + >>> print d_1000 + 1000 ms + + """ + Duration._is_duration(other) + return Duration(self._millis * other._millis) + + def __div__(self, other): + """ + Divide Duration by Duration + for Python 2.X + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_2 = d_20 / d_10 + >>> print d_2 + 2 ms + + """ + Duration._is_duration(other) + return Duration(self._millis / other._millis) + + def __truediv__(self, other): + """ + Divide Duration by Duration + for Python 3.0 + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_2 = d_20 / d_10 + >>> 
print d_2 + 2 ms + + """ + Duration._is_duration(other) + return Duration(self._millis / other._millis) + + def __floordiv__(self, other): + """ + Divide Duration by Duration + + >>> d_10 = Duration(10) + >>> d_3 = Duration(3) + >>> d_3 = d_10 // d_3 + >>> print d_3 + 3 ms + + """ + Duration._is_duration(other) + return Duration(self._millis // other._millis) + + def __lt__(self, other): + """ + Duration < Duration + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_10 < d_20 + True + >>> d_20 < d_10 + False + + """ + Duration._is_duration(other) + return self._millis < other._millis + + def __le__(self, other): + """ + Duration <= Duration + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_10 <= d_20 + True + >>> d_20 <= d_10 + False + + """ + Duration._is_duration(other) + return self._millis <= other._millis + + def __eq__(self, other): + """ + Duration == Duration + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_10 == d_20 + False + >>> other_d_10 = Duration(10) + >>> d_10 == other_d_10 + True + + """ + Duration._is_duration(other) + return self._millis == other._millis + + def __ne__(self, other): + """ + Duration != Duration + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_10 != d_20 + True + >>> other_d_10 = Duration(10) + >>> d_10 != other_d_10 + False + + """ + Duration._is_duration(other) + return self._millis != other._millis + + def __gt__(self, other): + """ + Duration > Duration + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_10 > d_20 + False + >>> d_20 > d_10 + True + + """ + Duration._is_duration(other) + return self._millis > other._millis + + def __ge__(self, other): + """ + Duration >= Duration + + >>> d_10 = Duration(10) + >>> d_20 = Duration(20) + >>> d_10 < d_20 + True + >>> d_20 < d_10 + False + + + """ + Duration._is_duration(other) + return self._millis >= other._millis + + @classmethod + def _is_duration(self, instance): + """ is instance Duration """ + if not isinstance(instance, Duration): + raise TypeError("This should be Duration") + + +def Milliseconds(milliseconds): + """ + Helper function that creates instance of [[pysparkstreaming.duration]] representing + a given number of milliseconds. + + >>> milliseconds = Milliseconds(1) + >>> d_1 = Duration(1) + >>> milliseconds == d_1 + True + + """ + return Duration(milliseconds) + + +def Seconds(seconds): + """ + Helper function that creates instance of [[pysparkstreaming.duration]] representing + a given number of seconds. + + >>> seconds = Seconds(1) + >>> d_1sec = Duration(1000) + >>> seconds == d_1sec + True + + """ + return Duration(seconds * 1000) + + +def Minutes(minutes): + """ + Helper function that creates instance of [[pysparkstreaming.duration]] representing + a given number of minutes. + + >>> minutes = Minutes(1) + >>> d_1min = Duration(60 * 1000) + >>> minutes == d_1min + True + + """ + return Duration(minutes * 60 * 1000) diff --git a/python/pyspark/streaming/jtime.py b/python/pyspark/streaming/jtime.py new file mode 100644 index 0000000000000..801b8871b3879 --- /dev/null +++ b/python/pyspark/streaming/jtime.py @@ -0,0 +1,136 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.streaming.duration import Duration + +""" +The name of this file, time is not a good naming for python +because if we do import time when we want to use native python time package, it does +not import python time package. +""" +# TODO: add doctest + + +class Time(object): + """ + Time for Spark Streaming application. Used to set Time + + Most of the time, you would create a Duration object with + C{Time()}, which will load values from C{spark.streaming.*} Java system + properties as well. In this case, any parameters you set directly on + the C{Time} object take priority over system properties. + + """ + def __init__(self, millis, _jvm=None): + """ + Create new Time. + + @param millis: milisecond + + @param _jvm: internal parameter used to pass a handle to the + Java VM; does not need to be set by users + + """ + self._millis = millis + + from pyspark.context import StreamingContext + StreamingContext._ensure_initialized() + _jvm = _jvm or StreamingContext._jvm + self._jtime = _jvm.Time(millis) + + def toString(self): + """ Return time as string """ + return str(self._millis) + " ms" + + def milliseconds(self): + """ Return millisecond """ + return self._millis + + def max(self, other): + """ Return higher Time """ + Time._is_time(other) + if self > other: + return self + else: + return other + + def min(self, other): + """ Return lower Time """ + Time._is_time(other) + if self < other: + return self + else: + return other + + def __add__(self, other): + """ Add Time and Time """ + Duration._is_duration(other) + return Time(self._millis + other._millis) + + def __sub__(self, other): + """ Subtract Time by Duration or Time """ + if isinstance(other, Duration): + return Time(self._millis - other._millis) + elif isinstance(other, Time): + return Duration(self._millis, other._millis) + else: + raise TypeError + + def __lt__(self, other): + """ Time < Time """ + Time._is_time(other) + return self._millis < other._millis + + def __le__(self, other): + """ Time <= Time """ + Time._is_time(other) + return self._millis <= other._millis + + def __eq__(self, other): + """ Time == Time """ + Time._is_time(other) + return self._millis == other._millis + + def __ne__(self, other): + """ Time != Time """ + Time._is_time(other) + return self._millis != other._millis + + def __gt__(self, other): + """ Time > Time """ + Time._is_time(other) + return self._millis > other._millis + + def __ge__(self, other): + """ Time >= Time """ + Time._is_time(other) + return self._millis >= other._millis + + def isMultipbleOf(self, duration): + """ is multiple by Duration """ + Duration._is_duration(duration) + return self._millis % duration._millis == 0 + + @classmethod + def _is_time(self, instance): + """ is instance Time """ + if not isinstance(instance, Time): + raise TypeError + +# TODO: implement until +# TODO: implement to + diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000000000..95cb76a15be07 --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,557 @@ +# +# Licensed to the Apache Software Foundation (ASF) 
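Two details in the Time class above look inconsistent with the rest of the patch: __init__ imports StreamingContext from pyspark.context, although that module defines SparkContext (which also owns _ensure_initialized and _jvm), and the Time - Time branch of __sub__ builds Duration(self._millis, other._millis), which passes the second value into Duration's _jvm parameter. A hedged sketch of what both spots appear to intend, not a definitive fix:

from pyspark.context import SparkContext            # pyspark.context defines SparkContext
from pyspark.streaming.duration import Duration

def _jvm_time(millis):
    # what Time.__init__ appears to intend: reuse SparkContext's gateway
    SparkContext._ensure_initialized()
    return SparkContext._jvm.Time(millis)

def _time_minus_time(self_millis, other_millis):
    # Time - Time: a Duration spanning the difference between the two instants
    return Duration(self_millis - other_millis)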
under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for Python SparkStreaming; additional tests are implemented as doctests in +individual modules. + +Callback server is sometimes unstable sometimes, which cause error in test case. +But this is very rare case. +""" +from itertools import chain +import time +import operator +import sys + +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest + +from pyspark.context import SparkContext +from pyspark.conf import SparkConf +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.duration import * + + +class PySparkStreamingTestCase(unittest.TestCase): + def setUp(self): + class_name = self.__class__.__name__ + self.ssc = StreamingContext(appName=class_name, duration=Seconds(1)) + + def tearDown(self): + # Do not call pyspark.streaming.context.StreamingContext.stop directly because + # we do not wait to shutdown py4j client. + self.ssc._jssc.stop() + self.ssc._sc.stop() + time.sleep(1) + + @classmethod + def tearDownClass(cls): + # Make sure tp shutdown the callback server + SparkContext._gateway._shutdown_callback_server() + + +class TestBasicOperationsSuite(PySparkStreamingTestCase): + """ + 2 tests for each function for batach deserializer and unbatch deserilizer because + the deserializer is not changed dunamically after streaming process starts. + Default numInputPartitions is 2. + If the number of input element is over 3, that DStream use batach deserializer. + If not, that DStream use unbatch deserializer. + + All tests input should have list of lists(3 lists are default). This list represents stream. + Every batch interval, the first object of list are chosen to make DStream. + e.g The first list in the list is input of the first batch. + Please see the BasicTestSuits in Scala which is close to this implementation. 
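Following the pattern this docstring describes, a new basic-operation test needs three pieces: a list-of-lists input (one inner list per batch), a DStream-to-DStream function, and the expected per-batch output. A hypothetical extra case in that style, relying on the _run_stream helper defined later in this suite:

def test_map_times_ten_sketch(self):
    """Hypothetical extra case written in the style described above."""
    test_input = [range(1, 5), range(5, 9), range(9, 13)]   # >3 elements per batch -> batched deserializer

    def test_func(dstream):
        return dstream.map(lambda x: x * 10)
    expected_output = map(lambda batch: map(lambda y: y * 10, batch), test_input)
    output = self._run_stream(test_input, test_func, expected_output)
    self.assertEqual(expected_output, output)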
+ """ + def setUp(self): + PySparkStreamingTestCase.setUp(self) + self.timeout = 10 # seconds + self.numInputPartitions = 2 + + def tearDown(self): + PySparkStreamingTestCase.tearDown(self) + + @classmethod + def tearDownClass(cls): + PySparkStreamingTestCase.tearDownClass() + + def test_map_batch(self): + """Basic operation test for DStream.map with batch deserializer.""" + test_input = [range(1, 5), range(5, 9), range(9, 13)] + + def test_func(dstream): + return dstream.map(lambda x: str(x)) + expected_output = map(lambda x: map(lambda y: str(y), x), test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_map_unbatach(self): + """Basic operation test for DStream.map with unbatch deserializer.""" + test_input = [range(1, 4), range(4, 7), range(7, 10)] + + def test_func(dstream): + return dstream.map(lambda x: str(x)) + expected_output = map(lambda x: map(lambda y: str(y), x), test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_flatMap_batch(self): + """Basic operation test for DStream.faltMap with batch deserializer.""" + test_input = [range(1, 5), range(5, 9), range(9, 13)] + + def test_func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected_output = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_flatMap_unbatch(self): + """Basic operation test for DStream.faltMap with unbatch deserializer.""" + test_input = [range(1, 4), range(4, 7), range(7, 10)] + + def test_func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected_output = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_filter_batch(self): + """Basic operation test for DStream.filter with batch deserializer.""" + test_input = [range(1, 5), range(5, 9), range(9, 13)] + + def test_func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected_output = map(lambda x: filter(lambda y: y % 2 == 0, x), test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_filter_unbatch(self): + """Basic operation test for DStream.filter with unbatch deserializer.""" + test_input = [range(1, 4), range(4, 7), range(7, 10)] + + def test_func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected_output = map(lambda x: filter(lambda y: y % 2 == 0, x), test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_count_batch(self): + """Basic operation test for DStream.count with batch deserializer.""" + test_input = [range(1, 5), range(1, 10), range(1, 20)] + + def test_func(dstream): + return dstream.count() + expected_output = map(lambda x: [len(x)], test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_count_unbatch(self): + """Basic operation test for DStream.count with unbatch deserializer.""" + test_input = [[], [1], range(1, 3), range(1, 4)] + + def test_func(dstream): + return dstream.count() + expected_output = map(lambda x: [len(x)], test_input) + output = 
self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_reduce_batch(self): + """Basic operation test for DStream.reduce with batch deserializer.""" + test_input = [range(1, 5), range(5, 9), range(9, 13)] + + def test_func(dstream): + return dstream.reduce(operator.add) + expected_output = map(lambda x: [reduce(operator.add, x)], test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_reduce_unbatch(self): + """Basic operation test for DStream.reduce with unbatch deserializer.""" + test_input = [[1], range(1, 3), range(1, 4)] + + def test_func(dstream): + return dstream.reduce(operator.add) + expected_output = map(lambda x: [reduce(operator.add, x)], test_input) + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_reduceByKey_batch(self): + """Basic operation test for DStream.reduceByKey with batch deserializer.""" + test_input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def test_func(dstream): + return dstream.reduceByKey(operator.add) + expected_output = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_reduceByKey_unbatch(self): + """Basic operation test for DStream.reduceByKey with unbatch deserializer.""" + test_input = [[("a", 1), ("a", 1), ("b", 1)], [("", 1), ("", 1)], []] + + def test_func(dstream): + return dstream.reduceByKey(operator.add) + expected_output = [[("a", 2), ("b", 1)], [("", 2)], []] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_mapValues_batch(self): + """Basic operation test for DStream.mapValues with batch deserializer.""" + test_input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def test_func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected_output = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_mapValues_unbatch(self): + """Basic operation test for DStream.mapValues with unbatch deserializer.""" + test_input = [[("a", 2), ("b", 1)], [("", 2)], [], [(1, 1), (2, 2)]] + + def test_func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected_output = [[("a", 12), ("b", 11)], [("", 12)], [], [(1, 11), (2, 12)]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_flatMapValues_batch(self): + """Basic operation test for DStream.flatMapValues with batch deserializer.""" + test_input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def test_func(dstream): + return 
dstream.flatMapValues(lambda x: (x, x + 10)) + expected_output = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_flatMapValues_unbatch(self): + """Basic operation test for DStream.flatMapValues with unbatch deserializer.""" + test_input = [[("a", 2), ("b", 1)], [("", 2)], []] + + def test_func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected_output = [[("a", 2), ("a", 12), ("b", 1), ("b", 11)], [("", 2), ("", 12)], []] + output = self._run_stream(test_input, test_func, expected_output) + self.assertEqual(expected_output, output) + + def test_glom_batch(self): + """Basic operation test for DStream.glom with batch deserializer.""" + test_input = [range(1, 5), range(5, 9), range(9, 13)] + numSlices = 2 + + def test_func(dstream): + return dstream.glom() + expected_output = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + output = self._run_stream(test_input, test_func, expected_output, numSlices) + self.assertEqual(expected_output, output) + + def test_glom_unbatach(self): + """Basic operation test for DStream.glom with unbatch deserializer.""" + test_input = [range(1, 4), range(4, 7), range(7, 10)] + numSlices = 2 + + def test_func(dstream): + return dstream.glom() + expected_output = [[[1], [2, 3]], [[4], [5, 6]], [[7], [8, 9]]] + output = self._run_stream(test_input, test_func, expected_output, numSlices) + self.assertEqual(expected_output, output) + + def test_mapPartitions_batch(self): + """Basic operation test for DStream.mapPartitions with batch deserializer.""" + test_input = [range(1, 5), range(5, 9), range(9, 13)] + numSlices = 2 + + def test_func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected_output = [[3, 7], [11, 15], [19, 23]] + output = self._run_stream(test_input, test_func, expected_output, numSlices) + self.assertEqual(expected_output, output) + + def test_mapPartitions_unbatch(self): + """Basic operation test for DStream.mapPartitions with unbatch deserializer.""" + test_input = [range(1, 4), range(4, 7), range(7, 10)] + numSlices = 2 + + def test_func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected_output = [[1, 5], [4, 11], [7, 17]] + output = self._run_stream(test_input, test_func, expected_output, numSlices) + self.assertEqual(expected_output, output) + + def test_countByValue_batch(self): + """Basic operation test for DStream.countByValue with batch deserializer.""" + test_input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def test_func(dstream): + return dstream.countByValue() + expected_output = [[(1, 2), (2, 2), (3, 2), (4, 2)], + [(5, 2), (6, 2), (7, 1), (8, 1)], + [("a", 2), ("b", 1), ("", 1)]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_countByValue_unbatch(self): + """Basic operation test for DStream.countByValue with unbatch deserializer.""" + test_input = [range(1, 4), [1, 1, ""], ["a", "a", "b"]] + + def test_func(dstream): + return dstream.countByValue() + expected_output = [[(1, 1), (2, 1), (3, 1)], + [(1, 2), ("", 1)], + [("a", 
2), ("b", 1)]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_groupByKey_batch(self): + """Basic operation test for DStream.groupByKey with batch deserializer.""" + test_input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def test_func(dstream): + return dstream.groupByKey() + expected_output = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + scattered_output = self._run_stream(test_input, test_func, expected_output) + output = self._convert_iter_value_to_list(scattered_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_groupByKey_unbatch(self): + """Basic operation test for DStream.groupByKey with unbatch deserializer.""" + test_input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), ("", 1)], + [("a", 1), ("a", 1), ("b", 1)]] + + def test_func(dstream): + return dstream.groupByKey() + expected_output = [[(1, [1]), (2, [1]), (3, [1])], + [(1, [1, 1]), ("", [1])], + [("a", [1, 1]), ("b", [1])]] + scattered_output = self._run_stream(test_input, test_func, expected_output) + output = self._convert_iter_value_to_list(scattered_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_combineByKey_batch(self): + """Basic operation test for DStream.combineByKey with batch deserializer.""" + test_input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def test_func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected_output = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_combineByKey_unbatch(self): + """Basic operation test for DStream.combineByKey with unbatch deserializer.""" + test_input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), ("", 1)], + [("a", 1), ("a", 1), ("b", 1)]] + + def test_func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected_output = [[(1, "1"), (2, "1"), (3, "1")], + [(1, "11"), ("", "1")], + [("a", "11"), ("b", "1")]] + output = self._run_stream(test_input, test_func, expected_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def _convert_iter_value_to_list(self, outputs): + """Return key value pair list. Value is converted to iterator to list.""" + result = list() + for output in outputs: + result.append(map(lambda (x, y): (x, list(y)), output)) + return result + + def _sort_result_based_on_key(self, outputs): + """Sort the list base onf first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + def _run_stream(self, test_input, test_func, expected_output, numSlices=None): + """ + Start stream and return the result. 
+ @param test_input: dataset for the test. This should be list of lists. + @param test_func: wrapped test_function. This function should return PythonDStream object. + @param expected_output: expected output for this testcase. + @param numSlices: the number of slices in the rdd in the dstream. + """ + # Generate input stream with user-defined input. + numSlices = numSlices or self.numInputPartitions + test_input_stream = self.ssc._testInputStream(test_input, numSlices) + # Apply test function to stream. + test_stream = test_func(test_input_stream) + # Add job to get output from stream. + result = list() + test_stream._test_output(result) + self.ssc.start() + + start_time = time.time() + # Loop until get the expected the number of the result from the stream. + while True: + current_time = time.time() + # Check time out. + if (current_time - start_time) > self.timeout: + break + # StreamingContext.awaitTermination is not used to wait because + # if py4j server is called every 50 milliseconds, it gets an error. + time.sleep(0.05) + # Check if the output is the same length of expected output. + if len(expected_output) == len(result): + break + + return result + + +class TestStreamingContextSuite(unittest.TestCase): + """ + Should we have conf property in SparkContext? + @property + def conf(self): + return self._conf + + """ + def setUp(self): + self.master = "local[2]" + self.appName = self.__class__.__name__ + self.batachDuration = Milliseconds(500) + self.sparkHome = "SomeDir" + self.envPair = {"key": "value"} + self.ssc = None + self.sc = None + + def tearDown(self): + # Do not call pyspark.streaming.context.StreamingContext.stop directly because + # we do not wait to shutdown py4j client. + # We need change this simply calll streamingConxt.Stop + #self.ssc._jssc.stop() + if self.ssc is not None: + self.ssc.stop() + if self.sc is not None: + self.sc.stop() + # Why does it long time to terminate StremaingContext and SparkContext? + # Should we change the sleep time if this depends on machine spec? + time.sleep(1) + + @classmethod + def tearDownClass(cls): + # Make sure tp shutdown the callback server + SparkContext._gateway._shutdown_callback_server() + + def test_from_no_conf_constructor(self): + self.ssc = StreamingContext(master=self.master, appName=self.appName, + duration=self.batachDuration) + # Alternative call master: ssc.sparkContext.master + # I try to make code close to Scala. 
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.master"), self.master)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.app.name"), self.appName)
+
+    def test_from_no_conf_plus_spark_home(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, duration=self.batchDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+
+    def test_from_no_conf_plus_spark_home_plus_env(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, environment=self.envPair,
+                                    duration=self.batchDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.executorEnv.key"), self.envPair["key"])
+
+    def test_from_existing_spark_context(self):
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batchDuration)
+
+    def test_existing_spark_context_with_settings(self):
+        conf = SparkConf()
+        conf.set("spark.cleaner.ttl", "10")
+        self.sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batchDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
+
+    def test_from_conf_with_settings(self):
+        conf = SparkConf()
+        conf.set("spark.cleaner.ttl", "10")
+        conf.setMaster(self.master)
+        conf.setAppName(self.appName)
+        self.ssc = StreamingContext(conf=conf, duration=self.batchDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
+
+    def test_stop_only_streaming_context(self):
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batchDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop(False)
+        self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+    def test_stop_multiple_times(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batchDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop()
+        self.ssc.stop()
+
+    def _addInputStream(self, s):
+        # Make sure the length of each input is over 3, and numSlices is 2,
+        # due to a deserializer problem in pyspark.streaming.
+        test_inputs = map(lambda x: range(1, x), range(5, 101))
+        test_stream = s._testInputStream(test_inputs, 2)
+        # Register a fake output operation.
+        result = list()
+        test_stream._test_output(result)
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
new file mode 100644
index 0000000000000..cf90952543fc0
--- /dev/null
+++ b/python/pyspark/streaming/util.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.rdd import RDD
+
+
+class RDDFunction():
+    """
+    This class is for the py4j callback. It is related to
+    org.apache.spark.streaming.api.python.PythonRDDFunction.
+    """
+    def __init__(self, ctx, jrdd_deserializer, func):
+        self.ctx = ctx
+        self.deserializer = jrdd_deserializer
+        self.func = func
+
+    def call(self, jrdd, time):
+        # Wrap the JavaRDD into Python's RDD class
+        rdd = RDD(jrdd, self.ctx, self.deserializer)
+        # Call the user-defined RDD function
+        self.func(rdd, time)
+
+    def __str__(self):
+        return "%s, %s" % (str(self.deserializer), str(self.func))
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+
+
+def msDurationToString(ms):
+    """
+    Returns a human-readable string representing a duration such as "35ms"
+
+    >>> msDurationToString(10)
+    '10 ms'
+    >>> msDurationToString(1000)
+    '1.0 s'
+    >>> msDurationToString(60000)
+    '1.0 m'
+    >>> msDurationToString(3600000)
+    '1.00 h'
+    """
+    second = 1000
+    minute = 60 * second
+    hour = 60 * minute
+
+    if ms < second:
+        return "%d ms" % ms
+    elif ms < minute:
+        return "%.1f s" % (float(ms) / second)
+    elif ms < hour:
+        return "%.1f m" % (float(ms) / minute)
+    else:
+        return "%.2f h" % (float(ms) / hour)
+
+
+def rddToFileName(prefix, suffix, time):
+    """
+    Return a string of the form prefix-time(.suffix)
+
+    >>> rddToFileName("spark", None, 12345678910)
+    'spark-12345678910'
+    >>> rddToFileName("spark", "tmp", 12345678910)
+    'spark-12345678910.tmp'
+
+    """
+    if suffix is None:
+        return prefix + "-" + str(time)
+    else:
+        return prefix + "-" + str(time) + "." + suffix
diff --git a/python/pyspark/streaming/utils.py b/python/pyspark/streaming/utils.py
new file mode 100644
index 0000000000000..5ba179cae7f9c
--- /dev/null
+++ b/python/pyspark/streaming/utils.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.rdd import RDD
+
+
+class RDDFunction():
+    """
+    This class is for the py4j callback. It is related to org.apache.spark.streaming.api.python.PythonRDDFunction.
+    """
+    def __init__(self, ctx, jrdd_deserializer, func):
+        self.ctx = ctx
+        self.deserializer = jrdd_deserializer
+        self.func = func
+
+    def call(self, jrdd, time):
+        # Wrap the JavaRDD into Python's RDD class
+        rdd = RDD(jrdd, self.ctx, self.deserializer)
+        # Call the user-defined RDD function
+        self.func(rdd, time)
+
+    def __str__(self):
+        return "%s, %s" % (str(self.deserializer), str(self.func))
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+
+
+def msDurationToString(ms):
+    # TODO: add doctest
+    """
+    Returns a human-readable string representing a duration such as "35ms"
+    """
+    second = 1000
+    minute = 60 * second
+    hour = 60 * minute
+
+    if ms < second:
+        return "%d ms" % ms
+    elif ms < minute:
+        return "%.1f s" % (float(ms) / second)
+    elif ms < hour:
+        return "%.1f m" % (float(ms) / minute)
+    else:
+        return "%.2f h" % (float(ms) / hour)
+
+
+def rddToFileName(prefix, suffix, time):
+    # TODO: add doctest
+    if suffix is None:
+        return prefix + "-" + str(time)
+    else:
+        return prefix + "-" + str(time) + "." + suffix
diff --git a/python/run-tests b/python/run-tests
index a7ec270c7da21..79d7602ccbc87 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -70,6 +70,8 @@ export PYSPARK_DOC_TEST=1
 run_test "pyspark/broadcast.py"
 run_test "pyspark/accumulators.py"
 run_test "pyspark/serializers.py"
+run_test "pyspark/streaming/duration.py"
+run_test "pyspark/streaming/util.py"
 unset PYSPARK_DOC_TEST
 run_test "pyspark/shuffle.py"
 run_test "pyspark/tests.py"
@@ -83,6 +85,9 @@ run_test "pyspark/mllib/stat.py"
 run_test "pyspark/mllib/tests.py"
 run_test "pyspark/mllib/tree.py"
 run_test "pyspark/mllib/util.py"
+if [ -n "$_RUN_STREAMING_TESTS" ]; then
+    run_test "pyspark/streaming/tests.py"
+fi
 # Try to test with PyPy
 if [ $(which pypy) ]; then
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 9dc26dc6b32a1..662cd8d22c6a5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -549,6 +549,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
  * JavaStreamingContext object contains a number of utility functions.
  */
 object JavaStreamingContext {
+  implicit def fromStreamingContext(ssc: StreamingContext):
+    JavaStreamingContext = new JavaStreamingContext(ssc)
+
+  implicit def toStreamingContext(jssc: JavaStreamingContext): StreamingContext = jssc.ssc

   /**
    * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
new file mode 100644
index 0000000000000..720823d78a110
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.python
+
+import java.io._
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap}
+
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java._
+import org.apache.spark.api.python._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.streaming.{StreamingContext, Duration, Time}
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.api.java._
+
+
+class PythonDStream[T: ClassTag](
+    parent: DStream[T],
+    command: Array[Byte],
+    envVars: JMap[String, String],
+    pythonIncludes: JList[String],
+    preservePartitioning: Boolean,
+    pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
+  extends DStream[Array[Byte]](parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  // Compute the PythonDStream for the given batch time.
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    parent.getOrCompute(validTime) match {
+      case Some(rdd) =>
+        // Create a PythonRDD to compute the Python function on this batch.
+        val pythonRDD = new PythonRDD(rdd, command, envVars, pythonIncludes,
+          preservePartitioning, pythonExec, broadcastVars, accumulator)
+        Some(pythonRDD.asJavaRDD.rdd)
+      case None => None
+    }
+  }
+
+  def foreachRDD(foreachFunc: PythonRDDFunction) {
+    new PythonForeachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+
+private class PythonPairwiseDStream(
+    prev: DStream[Array[Byte]],
+    partitioner: Partitioner
+  ) extends DStream[Array[Byte]](prev.ssc) {
+  override def dependencies = List(prev)
+
+  override def slideDuration: Duration = prev.slideDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    prev.getOrCompute(validTime) match {
+      case Some(rdd) =>
+        val pairwiseRDD = new PairwiseRDD(rdd)
+        /*
+         * Since the Python function is executed on the Scala side after StreamingContext.start(),
+         * what PythonPairwiseDStream does is equivalent to the following Python code in PySpark:
+         *
+         * with _JavaStackTrace(self.context) as st:
+         *     pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+         *     partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+         *                                                   id(partitionFunc))
+         *     jrdd = pairRDD.partitionBy(partitioner).values()
+         *     rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
+         */
+        Some(pairwiseRDD.asJavaPairRDD.partitionBy(partitioner).values().rdd)
+      case None => None
+    }
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+
+class PythonForeachDStream(
+    prev: DStream[Array[Byte]],
+    foreachFunction: PythonRDDFunction
+  ) extends ForEachDStream[Array[Byte]](
+    prev,
+    (rdd: RDD[Array[Byte]], time: Time) => {
+      foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
+    }
+  ) {
+
+  this.register()
+}
+
+
+/**
+ * This is an input stream used only for unit tests. This is equivalent to a checkpointable,
+ * replayable, reliable message queue like Kafka. It requires a JArrayList of JavaRDD,
+ * and returns the i-th element at the i-th batch under a manual clock.
+ */
+
+class PythonTestInputStream(
+    ssc_ : JavaStreamingContext,
+    inputRDDs: JArrayList[JavaRDD[Array[Byte]]]
+  ) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) {
+
+  def start() {}
+
+  def stop() {}
+
+  def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val emptyRDD = ssc.sparkContext.emptyRDD[Array[Byte]]
+    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+    val selectedRDD = {
+      if (inputRDDs.isEmpty) {
+        emptyRDD
+      } else if (index < inputRDDs.size()) {
+        inputRDDs.get(index).rdd
+      } else {
+        emptyRDD
+      }
+    }
+
+    Some(selectedRDD)
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonRDDFunction.java b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonRDDFunction.java
new file mode 100644
index 0000000000000..eacff4b0e6006
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonRDDFunction.java
@@ -0,0 +1,12 @@
+package org.apache.spark.streaming.api.python;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.streaming.Time;
+
+/*
+ * Interface for the py4j callback function.
+ * This interface is related to pyspark.streaming.dstream.DStream.foreachRDD.
+ */
+public interface PythonRDDFunction {
+  JavaRDD<byte[]> call(JavaRDD<byte[]> rdd, long time);
+}
\ No newline at end of file
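
Reviewer note, not part of the patch: the callback path added here works as follows. PythonRDDFunction is the JVM-side interface, RDDFunction in python/pyspark/streaming/util.py implements it through py4j's `class Java: implements = [...]` convention, and PythonForeachDStream invokes `call(rdd, time)` once per generated batch. The sketch below illustrates how a per-batch Python callback could be wired up; it assumes an already-created StreamingContext `ssc` and DStream `stream`, and the attribute names `ssc._sc`, `stream._jrdd_deserializer` and `stream._jdstream` are assumptions made for illustration (following the `_jrdd` naming convention in pyspark.rdd), not names taken from this patch.

# Sketch only: illustrates the py4j callback wiring, with assumed attribute names.
from pyspark.streaming.util import RDDFunction

def print_first_ten(rdd, time):
    # User-defined per-batch action; `rdd` wraps the JavaRDD for this batch and
    # `time` is the batch time in milliseconds (see PythonForeachDStream).
    print "batch time: %d ms" % time
    for record in rdd.take(10):
        print record

# Wrap the Python function so the JVM can call back into it through the
# PythonRDDFunction interface exposed by py4j.
wrapped = RDDFunction(ssc._sc, stream._jrdd_deserializer, print_first_ten)
# Assumed handle to the underlying PythonDStream; its foreachRDD registers a
# PythonForeachDStream that fires the callback for every batch.
stream._jdstream.foreachRDD(wrapped)

Registering through foreachRDD on the Scala side is what makes the Python callback run on the driver once per batch, which appears to be how _test_output collects results in tests.py.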
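For reference, the test helpers in tests.py (_testInputStream, _test_output and _run_stream) are combined in every test in the same way. A minimal additional test following that pattern might look like the sketch below; it assumes the PySparkStreamingTestCase-style fixture defined earlier in tests.py (providing self.ssc, self.timeout and self.numInputPartitions), and a similar test may already exist in the part of the file not shown here.

    # Sketch only: mirrors the structure of the tests added in this patch.
    def test_map(self):
        """Basic operation test for DStream.map (sketch)."""
        test_input = [range(1, 5), range(5, 9), range(9, 13)]

        def test_func(dstream):
            return dstream.map(str)
        expected_output = map(lambda v: map(str, v), test_input)
        output = self._run_stream(test_input, test_func, expected_output)
        self.assertEqual(expected_output, output)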